Add allow_list/deny_list to gesture recognizer options.

PiperOrigin-RevId: 485141209
This commit is contained in:
MediaPipe Team 2022-10-31 13:55:32 -07:00 committed by Copybara-Service
parent 4717ac298c
commit 9f8b5e5c11
7 changed files with 184 additions and 39 deletions

View File

@ -141,6 +141,7 @@ cc_library(
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components/containers:gesture_recognition_result", "//mediapipe/tasks/cc/components/containers:gesture_recognition_result",
"//mediapipe/tasks/cc/components/processors:classifier_options",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:base_task_api", "//mediapipe/tasks/cc/core:base_task_api",

View File

@ -141,16 +141,22 @@ ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) {
// Configure hand gesture recognizer options. // Configure hand gesture recognizer options.
auto* hand_gesture_recognizer_graph_options = auto* hand_gesture_recognizer_graph_options =
options_proto->mutable_hand_gesture_recognizer_graph_options(); options_proto->mutable_hand_gesture_recognizer_graph_options();
if (options->min_gesture_confidence >= 0) { auto canned_gestures_classifier_options_proto =
std::make_unique<components::processors::proto::ClassifierOptions>(
components::processors::ConvertClassifierOptionsToProto(
&(options->canned_gestures_classifier_options)));
hand_gesture_recognizer_graph_options hand_gesture_recognizer_graph_options
->mutable_canned_gesture_classifier_graph_options() ->mutable_canned_gesture_classifier_graph_options()
->mutable_classifier_options() ->mutable_classifier_options()
->set_score_threshold(options->min_gesture_confidence); ->Swap(canned_gestures_classifier_options_proto.get());
auto custom_gestures_classifier_options_proto =
std::make_unique<components::processors::proto::ClassifierOptions>(
components::processors::ConvertClassifierOptionsToProto(
&(options->custom_gestures_classifier_options)));
hand_gesture_recognizer_graph_options hand_gesture_recognizer_graph_options
->mutable_custom_gesture_classifier_graph_options() ->mutable_custom_gesture_classifier_graph_options()
->mutable_classifier_options() ->mutable_classifier_options()
->set_score_threshold(options->min_gesture_confidence); ->Swap(custom_gestures_classifier_options_proto.get());
}
return options_proto; return options_proto;
} }

View File

@ -18,12 +18,14 @@ limitations under the License.
#include <memory> #include <memory>
#include <optional> #include <optional>
#include <vector>
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/tasks/cc/components/containers/gesture_recognition_result.h" #include "mediapipe/tasks/cc/components/containers/gesture_recognition_result.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
@ -64,12 +66,17 @@ struct GestureRecognizerOptions {
// successful. // successful.
float min_tracking_confidence = 0.5; float min_tracking_confidence = 0.5;
// The minimum confidence score for the gestures to be considered // TODO Note this option is subject to change.
// successful. If < 0, the gesture confidence thresholds in the model // Options for configuring the canned gestures classifier, such as score
// metadata are used. // threshold, allow list and deny list of gestures. The categories for canned
// TODO Note this option is subject to change, after scoring // gesture classifiers are: ["None", "Closed_Fist", "Open_Palm",
// merging calculator is implemented. // "Pointing_Up", "Thumb_Down", "Thumb_Up", "Victory", "ILoveYou"]
float min_gesture_confidence = -1; components::processors::ClassifierOptions canned_gestures_classifier_options;
// TODO Note this option is subject to change.
// Options for configuring the custom gestures classifier, such as score
// threshold, allow list and deny list of gestures.
components::processors::ClassifierOptions custom_gestures_classifier_options;
// The user-defined result callback for processing live stream data. // The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set // The result callback should only be specified when the running mode is set

View File

@ -403,11 +403,11 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
const core::ModelResources* model_resources, const core::ModelResources* model_resources,
const proto::GestureClassifierGraphOptions& options, const proto::GestureClassifierGraphOptions& options,
Source<Tensor>& embedding_tensors, Graph& graph) { Source<Tensor>& embedding_tensors, Graph& graph) {
auto& custom_gesture_classifier_inference = AddInference( auto& gesture_classifier_inference = AddInference(
*model_resources, options.base_options().acceleration(), graph); *model_resources, options.base_options().acceleration(), graph);
embedding_tensors >> custom_gesture_classifier_inference.In(kTensorsTag); embedding_tensors >> gesture_classifier_inference.In(kTensorsTag);
auto custom_gesture_inference_out_tensors = auto gesture_inference_out_tensors =
custom_gesture_classifier_inference.Out(kTensorsTag); gesture_classifier_inference.Out(kTensorsTag);
auto& tensors_to_classification = auto& tensors_to_classification =
graph.AddNode("TensorsToClassificationCalculator"); graph.AddNode("TensorsToClassificationCalculator");
MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator( MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator(
@ -415,8 +415,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
0, 0,
&tensors_to_classification.GetOptions< &tensors_to_classification.GetOptions<
mediapipe::TensorsToClassificationCalculatorOptions>())); mediapipe::TensorsToClassificationCalculatorOptions>()));
custom_gesture_inference_out_tensors >> gesture_inference_out_tensors >> tensors_to_classification.In(kTensorsTag);
tensors_to_classification.In(kTensorsTag);
return tensors_to_classification.Out("CLASSIFICATIONS") return tensors_to_classification.Out("CLASSIFICATIONS")
.Cast<ClassificationList>(); .Cast<ClassificationList>();
} }

View File

@ -137,6 +137,7 @@ android_library(
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue", "//third_party:autovalue",
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",

View File

@ -26,7 +26,7 @@ import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler; import com.google.mediapipe.tasks.core.OutputHandler;
@ -398,13 +398,26 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
public abstract Builder setMinTrackingConfidence(Float value); public abstract Builder setMinTrackingConfidence(Float value);
/** /**
* Sets the minimum confidence score for the gestures to be considered successful. If < 0, the * Sets the optional {@link ClassifierOptions} controlling the canned gestures classifier, such
* gesture confidence threshold=0.5 for the model is used. * as score threshold, allow list and deny list of gestures. The categories for canned gesture
* classifiers are: ["None", "Closed_Fist", "Open_Palm", "Pointing_Up", "Thumb_Down",
* "Thumb_Up", "Victory", "ILoveYou"]
* *
* <p>TODO Note this option is subject to change, after scoring merging * <p>TODO Note this option is subject to change, after scoring merging
* calculator is implemented. * calculator is implemented.
*/ */
public abstract Builder setMinGestureConfidence(Float value); public abstract Builder setCannedGesturesClassifierOptions(
ClassifierOptions classifierOptions);
/**
* Sets the optional {@link ClassifierOptions} controlling the custom gestures classifier, such
* as score threshold, allow list and deny list of gestures.
*
* <p>TODO Note this option is subject to change, after scoring merging
* calculator is implemented.
*/
public abstract Builder setCustomGesturesClassifierOptions(
ClassifierOptions classifierOptions);
/** /**
* Sets the result listener to receive the detection results asynchronously when the gesture * Sets the result listener to receive the detection results asynchronously when the gesture
@ -454,8 +467,9 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
abstract Optional<Float> minTrackingConfidence(); abstract Optional<Float> minTrackingConfidence();
// TODO update gesture confidence options after score merging calculator is ready. abstract Optional<ClassifierOptions> cannedGesturesClassifierOptions();
abstract Optional<Float> minGestureConfidence();
abstract Optional<ClassifierOptions> customGesturesClassifierOptions();
abstract Optional<ResultListener<GestureRecognitionResult, MPImage>> resultListener(); abstract Optional<ResultListener<GestureRecognitionResult, MPImage>> resultListener();
@ -467,8 +481,7 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
.setNumHands(1) .setNumHands(1)
.setMinHandDetectionConfidence(0.5f) .setMinHandDetectionConfidence(0.5f)
.setMinHandPresenceConfidence(0.5f) .setMinHandPresenceConfidence(0.5f)
.setMinTrackingConfidence(0.5f) .setMinTrackingConfidence(0.5f);
.setMinGestureConfidence(-1f);
} }
/** /**
@ -511,13 +524,22 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
HandGestureRecognizerGraphOptionsProto.HandGestureRecognizerGraphOptions.Builder HandGestureRecognizerGraphOptionsProto.HandGestureRecognizerGraphOptions.Builder
handGestureRecognizerGraphOptionsBuilder = handGestureRecognizerGraphOptionsBuilder =
HandGestureRecognizerGraphOptionsProto.HandGestureRecognizerGraphOptions.newBuilder(); HandGestureRecognizerGraphOptionsProto.HandGestureRecognizerGraphOptions.newBuilder();
ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = cannedGesturesClassifierOptions()
ClassifierOptionsProto.ClassifierOptions.newBuilder(); .ifPresent(
minGestureConfidence().ifPresent(classifierOptionsBuilder::setScoreThreshold); classifierOptions -> {
handGestureRecognizerGraphOptionsBuilder.setCannedGestureClassifierGraphOptions( handGestureRecognizerGraphOptionsBuilder.setCannedGestureClassifierGraphOptions(
GestureClassifierGraphOptionsProto.GestureClassifierGraphOptions.newBuilder() GestureClassifierGraphOptionsProto.GestureClassifierGraphOptions.newBuilder()
.setClassifierOptions(classifierOptionsBuilder.build())); .setClassifierOptions(classifierOptions.convertToProto())
.build());
});
customGesturesClassifierOptions()
.ifPresent(
classifierOptions -> {
handGestureRecognizerGraphOptionsBuilder.setCustomGestureClassifierGraphOptions(
GestureClassifierGraphOptionsProto.GestureClassifierGraphOptions.newBuilder()
.setClassifierOptions(classifierOptions.convertToProto())
.build());
});
taskOptionsBuilder taskOptionsBuilder
.setHandLandmarkerGraphOptions(handLandmarkerGraphOptionsBuilder.build()) .setHandLandmarkerGraphOptions(handLandmarkerGraphOptionsBuilder.build())
.setHandGestureRecognizerGraphOptions(handGestureRecognizerGraphOptionsBuilder.build()); .setHandGestureRecognizerGraphOptions(handGestureRecognizerGraphOptionsBuilder.build());

View File

@ -30,6 +30,7 @@ import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.components.containers.Category; import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.Landmark; import com.google.mediapipe.tasks.components.containers.Landmark;
import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult; import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.core.RunningMode;
@ -106,14 +107,15 @@ public class GestureRecognizerTest {
} }
@Test @Test
public void recognize_successWithMinGestureConfidence() throws Exception { public void recognize_successWithScoreThreshold() throws Exception {
GestureRecognizerOptions options = GestureRecognizerOptions options =
GestureRecognizerOptions.builder() GestureRecognizerOptions.builder()
.setBaseOptions( .setBaseOptions(
BaseOptions.builder() BaseOptions.builder()
.setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
.build()) .build())
.setMinGestureConfidence(0.5f) .setCannedGesturesClassifierOptions(
ClassifierOptions.builder().setScoreThreshold(0.5f).build())
.build(); .build();
GestureRecognizer gestureRecognizer = GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -204,6 +206,113 @@ public class GestureRecognizerTest {
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
} }
@Test
public void recognize_successWithAllowGestureFist() throws Exception {
GestureRecognizerOptions options =
GestureRecognizerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
.build())
.setNumHands(1)
.setCannedGesturesClassifierOptions(
ClassifierOptions.builder()
.setScoreThreshold(0.5f)
.setCategoryAllowlist(Arrays.asList("Closed_Fist"))
.build())
.build();
GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
GestureRecognitionResult actualResult =
gestureRecognizer.recognize(getImageFromAsset(FIST_IMAGE));
GestureRecognitionResult expectedResult =
getExpectedGestureRecognitionResult(FIST_LANDMARKS, FIST_LABEL);
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
}
@Test
public void recognize_successWithDenyGestureFist() throws Exception {
GestureRecognizerOptions options =
GestureRecognizerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
.build())
.setNumHands(1)
.setCannedGesturesClassifierOptions(
ClassifierOptions.builder()
.setScoreThreshold(0.5f)
.setCategoryDenylist(Arrays.asList("Closed_Fist"))
.build())
.build();
GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
GestureRecognitionResult actualResult =
gestureRecognizer.recognize(getImageFromAsset(FIST_IMAGE));
assertThat(actualResult.landmarks()).isEmpty();
assertThat(actualResult.worldLandmarks()).isEmpty();
assertThat(actualResult.handednesses()).isEmpty();
assertThat(actualResult.gestures()).isEmpty();
}
@Test
public void recognize_successWithAllowAllGestureExceptFist() throws Exception {
GestureRecognizerOptions options =
GestureRecognizerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
.build())
.setNumHands(1)
.setCannedGesturesClassifierOptions(
ClassifierOptions.builder()
.setScoreThreshold(0.5f)
.setCategoryAllowlist(
Arrays.asList(
"None",
"Open_Palm",
"Pointing_Up",
"Thumb_Down",
"Thumb_Up",
"Victory",
"ILoveYou"))
.build())
.build();
GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
GestureRecognitionResult actualResult =
gestureRecognizer.recognize(getImageFromAsset(FIST_IMAGE));
assertThat(actualResult.landmarks()).isEmpty();
assertThat(actualResult.worldLandmarks()).isEmpty();
assertThat(actualResult.handednesses()).isEmpty();
assertThat(actualResult.gestures()).isEmpty();
}
@Test
public void recognize_successWithPreferAllowListThanDenyList() throws Exception {
GestureRecognizerOptions options =
GestureRecognizerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
.build())
.setNumHands(1)
.setCannedGesturesClassifierOptions(
ClassifierOptions.builder()
.setScoreThreshold(0.5f)
.setCategoryAllowlist(Arrays.asList("Closed_Fist"))
.setCategoryDenylist(Arrays.asList("Closed_Fist"))
.build())
.build();
GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
GestureRecognitionResult actualResult =
gestureRecognizer.recognize(getImageFromAsset(FIST_IMAGE));
GestureRecognitionResult expectedResult =
getExpectedGestureRecognitionResult(FIST_LANDMARKS, FIST_LABEL);
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
}
@Test @Test
public void recognize_failsWithRegionOfInterest() throws Exception { public void recognize_failsWithRegionOfInterest() throws Exception {
GestureRecognizerOptions options = GestureRecognizerOptions options =