From 9f8b5e5c11b0aad8fe6987f3869ac8f5cf0ac013 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Mon, 31 Oct 2022 13:55:32 -0700
Subject: [PATCH] Add allow_list/deny_list to gesture recognizer options.

PiperOrigin-RevId: 485141209
---
 .../tasks/cc/vision/gesture_recognizer/BUILD  |   1 +
 .../gesture_recognizer/gesture_recognizer.cc  |  26 ++--
 .../gesture_recognizer/gesture_recognizer.h   |  19 ++-
 .../hand_gesture_recognizer_graph.cc          |  11 +-
 .../com/google/mediapipe/tasks/vision/BUILD   |   1 +
 .../gesturerecognizer/GestureRecognizer.java  |  52 +++++---
 .../GestureRecognizerTest.java                | 113 +++++++++++++++++-
 7 files changed, 184 insertions(+), 39 deletions(-)

diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD
index a88d0d72c..f32d4cc58 100644
--- a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD
@@ -141,6 +141,7 @@ cc_library(
         "//mediapipe/tasks/cc:common",
         "//mediapipe/tasks/cc/components:image_preprocessing",
         "//mediapipe/tasks/cc/components/containers:gesture_recognition_result",
+        "//mediapipe/tasks/cc/components/processors:classifier_options",
         "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
         "//mediapipe/tasks/cc/core:base_options",
         "//mediapipe/tasks/cc/core:base_task_api",
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc
index fa1fc69ce..38cb5169d 100644
--- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc
@@ -141,16 +141,22 @@ ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) {
   // Configure hand gesture recognizer options.
   auto* hand_gesture_recognizer_graph_options =
       options_proto->mutable_hand_gesture_recognizer_graph_options();
-  if (options->min_gesture_confidence >= 0) {
-    hand_gesture_recognizer_graph_options
-        ->mutable_canned_gesture_classifier_graph_options()
-        ->mutable_classifier_options()
-        ->set_score_threshold(options->min_gesture_confidence);
-    hand_gesture_recognizer_graph_options
-        ->mutable_custom_gesture_classifier_graph_options()
-        ->mutable_classifier_options()
-        ->set_score_threshold(options->min_gesture_confidence);
-  }
+  auto canned_gestures_classifier_options_proto =
+      std::make_unique<components::processors::proto::ClassifierOptions>(
+          components::processors::ConvertClassifierOptionsToProto(
+              &(options->canned_gestures_classifier_options)));
+  hand_gesture_recognizer_graph_options
+      ->mutable_canned_gesture_classifier_graph_options()
+      ->mutable_classifier_options()
+      ->Swap(canned_gestures_classifier_options_proto.get());
+  auto custom_gestures_classifier_options_proto =
+      std::make_unique<components::processors::proto::ClassifierOptions>(
+          components::processors::ConvertClassifierOptionsToProto(
+              &(options->custom_gestures_classifier_options)));
+  hand_gesture_recognizer_graph_options
+      ->mutable_custom_gesture_classifier_graph_options()
+      ->mutable_classifier_options()
+      ->Swap(custom_gestures_classifier_options_proto.get());
 
   return options_proto;
 }
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h
index 3e281b26e..3f3d7acfe 100644
--- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h
@@ -18,12 +18,14 @@ limitations under the License.
 
 #include 
 #include 
+#include 
 
 #include "absl/status/statusor.h"
 #include "mediapipe/framework/formats/classification.pb.h"
 #include "mediapipe/framework/formats/image.h"
 #include "mediapipe/framework/formats/landmark.pb.h"
 #include "mediapipe/tasks/cc/components/containers/gesture_recognition_result.h"
+#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
 #include "mediapipe/tasks/cc/core/base_options.h"
 #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
 #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
@@ -64,12 +66,17 @@ struct GestureRecognizerOptions {
   // successful.
   float min_tracking_confidence = 0.5;
 
-  // The minimum confidence score for the gestures to be considered
-  // successful. If < 0, the gesture confidence thresholds in the model
-  // metadata are used.
-  // TODO Note this option is subject to change, after scoring
-  // merging calculator is implemented.
-  float min_gesture_confidence = -1;
+  // TODO Note this option is subject to change.
+  // Options for configuring the canned gestures classifier, such as score
+  // threshold, allow list and deny list of gestures. The categories for canned
+  // gesture classifiers are: ["None", "Closed_Fist", "Open_Palm",
+  // "Pointing_Up", "Thumb_Down", "Thumb_Up", "Victory", "ILoveYou"]
+  components::processors::ClassifierOptions canned_gestures_classifier_options;
+
+  // TODO Note this option is subject to change.
+  // Options for configuring the custom gestures classifier, such as score
+  // threshold, allow list and deny list of gestures.
+  components::processors::ClassifierOptions custom_gestures_classifier_options;
 
   // The user-defined result callback for processing live stream data.
   // The result callback should only be specified when the running mode is set
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc
index 6e83f9eec..7b6a8c79d 100644
--- a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc
@@ -403,11 +403,11 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
       const core::ModelResources* model_resources,
      const proto::GestureClassifierGraphOptions& options,
      Source<std::vector<Tensor>>& embedding_tensors, Graph& graph) {
-    auto& custom_gesture_classifier_inference = AddInference(
+    auto& gesture_classifier_inference = AddInference(
         *model_resources, options.base_options().acceleration(), graph);
-    embedding_tensors >> custom_gesture_classifier_inference.In(kTensorsTag);
-    auto custom_gesture_inference_out_tensors =
-        custom_gesture_classifier_inference.Out(kTensorsTag);
+    embedding_tensors >> gesture_classifier_inference.In(kTensorsTag);
+    auto gesture_inference_out_tensors =
+        gesture_classifier_inference.Out(kTensorsTag);
     auto& tensors_to_classification =
         graph.AddNode("TensorsToClassificationCalculator");
     MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator(
@@ -415,8 +415,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
         0,
         &tensors_to_classification.GetOptions<
             mediapipe::TensorsToClassificationCalculatorOptions>()));
-    custom_gesture_inference_out_tensors >>
-        tensors_to_classification.In(kTensorsTag);
+    gesture_inference_out_tensors >> tensors_to_classification.In(kTensorsTag);
     return tensors_to_classification.Out("CLASSIFICATIONS")
         .Cast<ClassificationList>();
   }
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD
index ed65fbcac..d15040ae7 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD
@@ -137,6 +137,7 @@ android_library(
         "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
+        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
         "//third_party:autovalue",
         "@maven//:com_google_guava_guava",
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java
index 8e5a30eab..d6faf5986 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java
@@ -26,7 +26,7 @@ import com.google.mediapipe.framework.Packet;
 import com.google.mediapipe.framework.PacketGetter;
 import com.google.mediapipe.framework.image.BitmapImageBuilder;
 import com.google.mediapipe.framework.image.MPImage;
-import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
+import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
 import com.google.mediapipe.tasks.core.BaseOptions;
 import com.google.mediapipe.tasks.core.ErrorListener;
 import com.google.mediapipe.tasks.core.OutputHandler;
@@ -398,13 +398,26 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
       public abstract Builder setMinTrackingConfidence(Float value);
 
       /**
-       * Sets the minimum confidence score for the gestures to be considered successful. If < 0, the
-       * gesture confidence threshold=0.5 for the model is used.
+       * Sets the optional {@link ClassifierOptions} controlling the canned gestures classifier, such
+       * as score threshold, allow list and deny list of gestures. The categories for canned gesture
+       * classifiers are: ["None", "Closed_Fist", "Open_Palm", "Pointing_Up", "Thumb_Down",
+       * "Thumb_Up", "Victory", "ILoveYou"]
        *
        * <p>TODO Note this option is subject to change, after scoring merging
        * calculator is implemented.
        */
-      public abstract Builder setMinGestureConfidence(Float value);
+      public abstract Builder setCannedGesturesClassifierOptions(
+          ClassifierOptions classifierOptions);
+
+      /**
+       * Sets the optional {@link ClassifierOptions} controlling the custom gestures classifier, such
+       * as score threshold, allow list and deny list of gestures.
+       *
+       * <p>TODO Note this option is subject to change, after scoring merging
+       * calculator is implemented.
+       */
+      public abstract Builder setCustomGesturesClassifierOptions(
+          ClassifierOptions classifierOptions);
 
       /**
        * Sets the result listener to receive the detection results asynchronously when the gesture
@@ -454,8 +467,9 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
 
     abstract Optional<Float> minTrackingConfidence();
 
-    // TODO update gesture confidence options after score merging calculator is ready.
-    abstract Optional<Float> minGestureConfidence();
+    abstract Optional<ClassifierOptions> cannedGesturesClassifierOptions();
+
+    abstract Optional<ClassifierOptions> customGesturesClassifierOptions();
 
     abstract Optional<OutputHandler.ResultListener<GestureRecognitionResult, MPImage>>
         resultListener();
@@ -467,8 +481,7 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
           .setNumHands(1)
          .setMinHandDetectionConfidence(0.5f)
          .setMinHandPresenceConfidence(0.5f)
-          .setMinTrackingConfidence(0.5f)
-          .setMinGestureConfidence(-1f);
+          .setMinTrackingConfidence(0.5f);
     }
 
     /**
@@ -511,13 +524,22 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
       HandGestureRecognizerGraphOptionsProto.HandGestureRecognizerGraphOptions.Builder
           handGestureRecognizerGraphOptionsBuilder =
              HandGestureRecognizerGraphOptionsProto.HandGestureRecognizerGraphOptions.newBuilder();
-      ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder =
-          ClassifierOptionsProto.ClassifierOptions.newBuilder();
-      minGestureConfidence().ifPresent(classifierOptionsBuilder::setScoreThreshold);
-      handGestureRecognizerGraphOptionsBuilder.setCannedGestureClassifierGraphOptions(
-          GestureClassifierGraphOptionsProto.GestureClassifierGraphOptions.newBuilder()
-              .setClassifierOptions(classifierOptionsBuilder.build()));
-
+      cannedGesturesClassifierOptions()
+          .ifPresent(
+              classifierOptions -> {
+                handGestureRecognizerGraphOptionsBuilder.setCannedGestureClassifierGraphOptions(
+                    GestureClassifierGraphOptionsProto.GestureClassifierGraphOptions.newBuilder()
+                        .setClassifierOptions(classifierOptions.convertToProto())
+                        .build());
+              });
+      customGesturesClassifierOptions()
+          .ifPresent(
+              classifierOptions -> {
+                handGestureRecognizerGraphOptionsBuilder.setCustomGestureClassifierGraphOptions(
+                    GestureClassifierGraphOptionsProto.GestureClassifierGraphOptions.newBuilder()
+                        .setClassifierOptions(classifierOptions.convertToProto())
+                        .build());
+              });
       taskOptionsBuilder
           .setHandLandmarkerGraphOptions(handLandmarkerGraphOptionsBuilder.build())
           .setHandGestureRecognizerGraphOptions(handGestureRecognizerGraphOptionsBuilder.build());
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java
index 2d4b3a50d..f76c4eaab 100644
--- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java
+++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java
@@ -30,6 +30,7 @@ import com.google.mediapipe.framework.image.MPImage;
 import com.google.mediapipe.tasks.components.containers.Category;
 import com.google.mediapipe.tasks.components.containers.Landmark;
 import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult;
+import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
 import com.google.mediapipe.tasks.core.BaseOptions;
 import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
 import com.google.mediapipe.tasks.vision.core.RunningMode;
@@ -106,14 +107,15 @@ public class GestureRecognizerTest {
   }
 
   @Test
-  public void recognize_successWithMinGestureConfidence() throws Exception {
+  public void recognize_successWithScoreThreshold() throws Exception {
     GestureRecognizerOptions options =
         GestureRecognizerOptions.builder()
             .setBaseOptions(
                 BaseOptions.builder()
                     .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
                     .build())
-            .setMinGestureConfidence(0.5f)
+            .setCannedGesturesClassifierOptions(
+                ClassifierOptions.builder().setScoreThreshold(0.5f).build())
             .build();
     GestureRecognizer gestureRecognizer =
         GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@@ -204,6 +206,113 @@ public class GestureRecognizerTest {
     assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
   }
 
+  @Test
+  public void recognize_successWithAllowGestureFist() throws Exception {
+    GestureRecognizerOptions options =
+        GestureRecognizerOptions.builder()
+            .setBaseOptions(
+                BaseOptions.builder()
+                    .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
+                    .build())
+            .setNumHands(1)
+            .setCannedGesturesClassifierOptions(
+                ClassifierOptions.builder()
+                    .setScoreThreshold(0.5f)
+                    .setCategoryAllowlist(Arrays.asList("Closed_Fist"))
+                    .build())
+            .build();
+    GestureRecognizer gestureRecognizer =
+        GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+    GestureRecognitionResult actualResult =
+        gestureRecognizer.recognize(getImageFromAsset(FIST_IMAGE));
+    GestureRecognitionResult expectedResult =
+        getExpectedGestureRecognitionResult(FIST_LANDMARKS, FIST_LABEL);
+    assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
+  }
+
+  @Test
+  public void recognize_successWithDenyGestureFist() throws Exception {
+    GestureRecognizerOptions options =
+        GestureRecognizerOptions.builder()
+            .setBaseOptions(
+                BaseOptions.builder()
+                    .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
+                    .build())
+            .setNumHands(1)
+            .setCannedGesturesClassifierOptions(
+                ClassifierOptions.builder()
+                    .setScoreThreshold(0.5f)
+                    .setCategoryDenylist(Arrays.asList("Closed_Fist"))
+                    .build())
+            .build();
+    GestureRecognizer gestureRecognizer =
+        GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+    GestureRecognitionResult actualResult =
+        gestureRecognizer.recognize(getImageFromAsset(FIST_IMAGE));
+    assertThat(actualResult.landmarks()).isEmpty();
+    assertThat(actualResult.worldLandmarks()).isEmpty();
+    assertThat(actualResult.handednesses()).isEmpty();
+    assertThat(actualResult.gestures()).isEmpty();
+  }
+
+  @Test
+  public void recognize_successWithAllowAllGestureExceptFist() throws Exception {
+    GestureRecognizerOptions options =
+        GestureRecognizerOptions.builder()
+            .setBaseOptions(
+                BaseOptions.builder()
+                    .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
+                    .build())
+            .setNumHands(1)
+            .setCannedGesturesClassifierOptions(
+                ClassifierOptions.builder()
+                    .setScoreThreshold(0.5f)
+                    .setCategoryAllowlist(
+                        Arrays.asList(
+                            "None",
+                            "Open_Palm",
+                            "Pointing_Up",
+                            "Thumb_Down",
+                            "Thumb_Up",
+                            "Victory",
+                            "ILoveYou"))
+                    .build())
+            .build();
+    GestureRecognizer gestureRecognizer =
+        GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+    GestureRecognitionResult actualResult =
+        gestureRecognizer.recognize(getImageFromAsset(FIST_IMAGE));
+    assertThat(actualResult.landmarks()).isEmpty();
+    assertThat(actualResult.worldLandmarks()).isEmpty();
+    assertThat(actualResult.handednesses()).isEmpty();
+    assertThat(actualResult.gestures()).isEmpty();
+  }
+
+  @Test
+  public void recognize_successWithPreferAllowListOverDenyList() throws Exception {
+    GestureRecognizerOptions options =
+        GestureRecognizerOptions.builder()
+            .setBaseOptions(
+                BaseOptions.builder()
+                    .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
+                    .build())
+            .setNumHands(1)
+            .setCannedGesturesClassifierOptions(
+                ClassifierOptions.builder()
+                    .setScoreThreshold(0.5f)
+                    .setCategoryAllowlist(Arrays.asList("Closed_Fist"))
+                    .setCategoryDenylist(Arrays.asList("Closed_Fist"))
+                    .build())
+            .build();
+    GestureRecognizer gestureRecognizer =
+        GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+    GestureRecognitionResult actualResult =
+        gestureRecognizer.recognize(getImageFromAsset(FIST_IMAGE));
+    GestureRecognitionResult expectedResult =
+        getExpectedGestureRecognitionResult(FIST_LANDMARKS, FIST_LABEL);
+    assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
+  }
+
   @Test
   public void recognize_failsWithRegionOfInterest() throws Exception {
     GestureRecognizerOptions options =
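For reference, the snippet below is a minimal usage sketch of the options introduced by this patch, mirroring the Java tests above. The wrapper class, the Context/MPImage arguments, and the model asset name "gesture_recognizer.task" are illustrative assumptions rather than part of this change; the builder calls themselves (setCannedGesturesClassifierOptions, setScoreThreshold, setCategoryAllowlist/setCategoryDenylist) are the API added here.

import android.content.Context;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.gesturerecognizer.GestureRecognitionResult;
import com.google.mediapipe.tasks.vision.gesturerecognizer.GestureRecognizer;
import com.google.mediapipe.tasks.vision.gesturerecognizer.GestureRecognizer.GestureRecognizerOptions;
import java.util.Arrays;

/** Illustrative wrapper (not part of the patch): recognize only fist and open-palm gestures. */
final class GestureAllowlistExample {
  private GestureAllowlistExample() {}

  static GestureRecognitionResult recognizeFistOrPalm(Context context, MPImage image) {
    // Keep only two canned gesture categories and drop low-confidence classifications.
    ClassifierOptions cannedGesturesOptions =
        ClassifierOptions.builder()
            .setScoreThreshold(0.5f)
            .setCategoryAllowlist(Arrays.asList("Closed_Fist", "Open_Palm"))
            .build();

    GestureRecognizerOptions options =
        GestureRecognizerOptions.builder()
            // "gesture_recognizer.task" is an assumed asset name for the model bundle.
            .setBaseOptions(
                BaseOptions.builder().setModelAssetPath("gesture_recognizer.task").build())
            .setNumHands(1)
            .setCannedGesturesClassifierOptions(cannedGesturesOptions)
            .build();

    GestureRecognizer recognizer = GestureRecognizer.createFromOptions(context, options);
    return recognizer.recognize(image);
  }
}

The last test added above suggests the intended precedence when a category appears on both lists: with "Closed_Fist" in both the allow list and the deny list, the fist image is still recognized, i.e. the allow list wins.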