From 1bbe065647b30f7b457df56747b24510c225258d Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 9 Jan 2023 09:11:37 -0800 Subject: [PATCH] Simplify default options for GestureRecognize PiperOrigin-RevId: 500729643 --- mediapipe/tasks/testdata/vision/BUILD | 2 + .../gesture_recognizer/gesture_recognizer.ts | 39 +++++-------------- third_party/external_files.bzl | 6 +++ 3 files changed, 17 insertions(+), 30 deletions(-) diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index 95b721fdb..607245700 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -38,6 +38,7 @@ mediapipe_files(srcs = [ "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite", "deeplabv3.tflite", "fist.jpg", + "fist.png", "hand_landmark_full.tflite", "hand_landmark_lite.tflite", "hand_landmarker.task", @@ -95,6 +96,7 @@ filegroup( "cats_and_dogs_no_resizing.jpg", "cats_and_dogs_rotated.jpg", "fist.jpg", + "fist.png", "hand_landmark_full.tflite", "hand_landmark_lite.tflite", "left_hands.jpg", diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 48efc4855..1b7201b9a 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -54,7 +54,7 @@ const GESTURE_RECOGNIZER_GRAPH = 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph'; const DEFAULT_NUM_HANDS = 1; -const DEFAULT_SCORE_THRESHOLD = 0.5; +const DEFAULT_CONFIDENCE = 0.5; const DEFAULT_CATEGORY_INDEX = -1; /** Performs hand gesture recognition on images. */ @@ -143,8 +143,6 @@ export class GestureRecognizer extends VisionTaskRunner { new HandGestureRecognizerGraphOptions(); this.options.setHandGestureRecognizerGraphOptions( this.handGestureRecognizerGraphOptions); - - this.initDefaults(); } protected override get baseOptions(): BaseOptionsProto { @@ -165,22 +163,14 @@ export class GestureRecognizer extends VisionTaskRunner { * @param options The options for the gesture recognizer. */ override setOptions(options: GestureRecognizerOptions): Promise { - if ('numHands' in options) { - this.handDetectorGraphOptions.setNumHands( - options.numHands ?? DEFAULT_NUM_HANDS); - } - if ('minHandDetectionConfidence' in options) { - this.handDetectorGraphOptions.setMinDetectionConfidence( - options.minHandDetectionConfidence ?? DEFAULT_SCORE_THRESHOLD); - } - if ('minHandPresenceConfidence' in options) { - this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence( - options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD); - } - if ('minTrackingConfidence' in options) { - this.handLandmarkerGraphOptions.setMinTrackingConfidence( - options.minTrackingConfidence ?? DEFAULT_SCORE_THRESHOLD); - } + this.handDetectorGraphOptions.setNumHands( + options.numHands ?? DEFAULT_NUM_HANDS); + this.handDetectorGraphOptions.setMinDetectionConfidence( + options.minHandDetectionConfidence ?? DEFAULT_CONFIDENCE); + this.handLandmarkerGraphOptions.setMinTrackingConfidence( + options.minTrackingConfidence ?? DEFAULT_CONFIDENCE); + this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence( + options.minHandPresenceConfidence ?? DEFAULT_CONFIDENCE); if (options.cannedGesturesClassifierOptions) { // Note that we have to support both JSPB and ProtobufJS and cannot @@ -281,17 +271,6 @@ export class GestureRecognizer extends VisionTaskRunner { } } - /** Sets the default values for the graph. */ - private initDefaults(): void { - this.handDetectorGraphOptions.setNumHands(DEFAULT_NUM_HANDS); - this.handDetectorGraphOptions.setMinDetectionConfidence( - DEFAULT_SCORE_THRESHOLD); - this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence( - DEFAULT_SCORE_THRESHOLD); - this.handLandmarkerGraphOptions.setMinTrackingConfidence( - DEFAULT_SCORE_THRESHOLD); - } - /** Converts the proto data to a Category[][] structure. */ private toJsCategories(data: Uint8Array[], populateIndex = true): Category[][] { diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 72ca95e66..790486676 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -286,6 +286,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/fist_landmarks.pbtxt?generation=1666999360561864"], ) + http_file( + name = "com_google_mediapipe_fist_png", + sha256 = "4397b3d3f590c88a8de7d21c08d73a0df4a97fd93f92cbd086eef37fd246daaa", + urls = ["https://storage.googleapis.com/mediapipe-assets/fist.png?generation=1672952068696274"], + ) + http_file( name = "com_google_mediapipe_general_meta_json", sha256 = "b95363e4bae89b9c2af484498312aaad4efc7ff57c7eadcc4e5e7adca641445f",