diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD
index 6296017d4..a88d0d72c 100644
--- a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD
@@ -62,14 +62,18 @@ cc_library(
         "//mediapipe/tasks/cc/components:image_preprocessing",
         "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
         "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
+        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
         "//mediapipe/tasks/cc/core:model_asset_bundle_resources",
         "//mediapipe/tasks/cc/core:model_resources",
         "//mediapipe/tasks/cc/core:model_resources_cache",
         "//mediapipe/tasks/cc/core:model_task_graph",
         "//mediapipe/tasks/cc/core:utils",
+        "//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
         "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
         "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
         "//mediapipe/tasks/cc/metadata/utils:zip_utils",
+        "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:combined_prediction_calculator",
+        "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:combined_prediction_calculator_cc_proto",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:handedness_to_matrix_calculator",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator_cc_proto",
@@ -77,8 +81,6 @@ cc_library(
         "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_cc_proto",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
         "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph",
-        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
-        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
         "//mediapipe/tasks/metadata:metadata_schema_cc",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.cc
index c7147ea6e..cb95091d4 100644
--- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.cc
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.cc
@@ -153,13 +153,12 @@ class CombinedPredictionCalculator : public Node {
     // After loop, if have winning prediction return. Otherwise empty packet.
     std::unique_ptr<ClassificationList> first_winning_prediction = nullptr;
     auto collection = kClassificationListIn(cc);
-    for (int idx = 0; idx < collection.Count(); ++idx) {
-      const auto& packet = collection[idx];
-      if (packet.IsEmpty()) {
+    for (const auto& input : collection) {
+      if (input.IsEmpty() || input.Get().classification_size() == 0) {
         continue;
       }
       auto prediction = GetWinningPrediction(
-          packet.Get(), classwise_thresholds_, options_.background_label(),
+          input.Get(), classwise_thresholds_, options_.background_label(),
           options_.default_global_threshold());
       if (prediction->classification(0).label() !=
           options_.background_label()) {
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc
index d4ab16ac8..fa1fc69ce 100644
--- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc
@@ -146,6 +146,10 @@ ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) {
         ->mutable_canned_gesture_classifier_graph_options()
         ->mutable_classifier_options()
         ->set_score_threshold(options->min_gesture_confidence);
+    hand_gesture_recognizer_graph_options
+        ->mutable_custom_gesture_classifier_graph_options()
+        ->mutable_classifier_options()
+        ->set_score_threshold(options->min_gesture_confidence);
   }
   return options_proto;
 }
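Note on the merge rule implemented above: CombinedPredictionCalculator visits its inputs in wiring order, and the first input whose winning prediction is not the background label short-circuits the search; if every input resolves to the background, the background prediction from the first non-empty input is emitted. A minimal standalone sketch of that rule, with plain structs standing in for the ClassificationList proto and a simplified stand-in for GetWinningPrediction (all names below are illustrative, not the calculator's actual helpers):

#include <optional>
#include <string>
#include <vector>

// Illustrative stand-ins for the proto types the calculator consumes.
struct Classification {
  std::string label;
  float score;
};
using Predictions = std::vector<Classification>;

// Simplified stand-in for GetWinningPrediction(): the top-scoring entry if it
// clears the threshold, otherwise a synthetic background prediction.
Classification Winning(const Predictions& preds, const std::string& background,
                       float threshold) {
  const Classification* best = nullptr;
  for (const auto& c : preds) {
    if (best == nullptr || c.score > best->score) best = &c;
  }
  if (best != nullptr && best->score >= threshold) return *best;
  return {background, 1.0f};
}

// Mirrors the Process() loop: inputs are checked in wiring order, so the
// classifier connected to the lowest input index effectively has priority.
std::optional<Classification> Combine(const std::vector<Predictions>& inputs,
                                      const std::string& background,
                                      float threshold) {
  std::optional<Classification> first_winning_prediction;
  for (const auto& input : inputs) {
    if (input.empty()) continue;  // mirrors IsEmpty()/classification_size()==0
    Classification prediction = Winning(input, background, threshold);
    if (prediction.label != background) return prediction;  // short-circuit
    if (!first_winning_prediction.has_value()) {
      first_winning_prediction = prediction;
    }
  }
  return first_winning_prediction;  // background, or nullopt if all empty
}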
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.pb.h" @@ -58,6 +61,7 @@ using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::components::processors:: ConfigureTensorsToClassificationCalculator; using ::mediapipe::tasks::core::ModelAssetBundleResources; +using ::mediapipe::tasks::core::proto::BaseOptions; using ::mediapipe::tasks::metadata::SetExternalFile; using ::mediapipe::tasks::vision::gesture_recognizer::proto:: HandGestureRecognizerGraphOptions; @@ -78,13 +82,20 @@ constexpr char kVectorTag[] = "VECTOR"; constexpr char kIndexTag[] = "INDEX"; constexpr char kIterableTag[] = "ITERABLE"; constexpr char kBatchEndTag[] = "BATCH_END"; +constexpr char kPredictionTag[] = "PREDICTION"; +constexpr char kBackgroundLabel[] = "None"; constexpr char kGestureEmbedderTFLiteName[] = "gesture_embedder.tflite"; constexpr char kCannedGestureClassifierTFLiteName[] = "canned_gesture_classifier.tflite"; +constexpr char kCustomGestureClassifierTFLiteName[] = + "custom_gesture_classifier.tflite"; struct SubTaskModelResources { - const core::ModelResources* gesture_embedder_model_resource; - const core::ModelResources* canned_gesture_classifier_model_resource; + const core::ModelResources* gesture_embedder_model_resource = nullptr; + const core::ModelResources* canned_gesture_classifier_model_resource = + nullptr; + const core::ModelResources* custom_gesture_classifier_model_resource = + nullptr; }; Source> ConvertMatrixToTensor(Source matrix, @@ -94,41 +105,21 @@ Source> ConvertMatrixToTensor(Source matrix, return node[Output>{"TENSORS"}]; } -// Sets the base options in the sub tasks. 
-absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
-                                   HandGestureRecognizerGraphOptions* options,
-                                   bool is_copy) {
-  ASSIGN_OR_RETURN(const auto gesture_embedder_file,
-                   resources.GetModelFile(kGestureEmbedderTFLiteName));
-  auto* gesture_embedder_graph_options =
-      options->mutable_gesture_embedder_graph_options();
-  SetExternalFile(gesture_embedder_file,
-                  gesture_embedder_graph_options->mutable_base_options()
-                      ->mutable_model_asset(),
-                  is_copy);
-  gesture_embedder_graph_options->mutable_base_options()
-      ->mutable_acceleration()
-      ->CopyFrom(options->base_options().acceleration());
-  gesture_embedder_graph_options->mutable_base_options()->set_use_stream_mode(
-      options->base_options().use_stream_mode());
-
-  ASSIGN_OR_RETURN(const auto canned_gesture_classifier_file,
-                   resources.GetModelFile(kCannedGestureClassifierTFLiteName));
-  auto* canned_gesture_classifier_graph_options =
-      options->mutable_canned_gesture_classifier_graph_options();
-  SetExternalFile(
-      canned_gesture_classifier_file,
-      canned_gesture_classifier_graph_options->mutable_base_options()
-          ->mutable_model_asset(),
-      is_copy);
-  canned_gesture_classifier_graph_options->mutable_base_options()
-      ->mutable_acceleration()
-      ->CopyFrom(options->base_options().acceleration());
-  canned_gesture_classifier_graph_options->mutable_base_options()
-      ->set_use_stream_mode(options->base_options().use_stream_mode());
+absl::Status ConfigureCombinedPredictionCalculator(
+    CombinedPredictionCalculatorOptions* options) {
+  options->set_background_label(kBackgroundLabel);
   return absl::OkStatus();
 }
 
+void PopulateAccelerationAndUseStreamMode(
+    const BaseOptions& parent_base_options,
+    BaseOptions* sub_task_base_options) {
+  sub_task_base_options->mutable_acceleration()->CopyFrom(
+      parent_base_options.acceleration());
+  sub_task_base_options->set_use_stream_mode(
+      parent_base_options.use_stream_mode());
+}
+
 }  // namespace
 
 // A
@@ -212,6 +203,56 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
   }
 
  private:
+  // Sets the base options in the sub tasks.
+  absl::Status SetSubTaskBaseOptions(
+      const ModelAssetBundleResources& resources,
+      HandGestureRecognizerGraphOptions* options, bool is_copy) {
+    ASSIGN_OR_RETURN(const auto gesture_embedder_file,
+                     resources.GetModelFile(kGestureEmbedderTFLiteName));
+    auto* gesture_embedder_graph_options =
+        options->mutable_gesture_embedder_graph_options();
+    SetExternalFile(gesture_embedder_file,
+                    gesture_embedder_graph_options->mutable_base_options()
+                        ->mutable_model_asset(),
+                    is_copy);
+    PopulateAccelerationAndUseStreamMode(
+        options->base_options(),
+        gesture_embedder_graph_options->mutable_base_options());
+
+    ASSIGN_OR_RETURN(
+        const auto canned_gesture_classifier_file,
+        resources.GetModelFile(kCannedGestureClassifierTFLiteName));
+    auto* canned_gesture_classifier_graph_options =
+        options->mutable_canned_gesture_classifier_graph_options();
+    SetExternalFile(
+        canned_gesture_classifier_file,
+        canned_gesture_classifier_graph_options->mutable_base_options()
+            ->mutable_model_asset(),
+        is_copy);
+    PopulateAccelerationAndUseStreamMode(
+        options->base_options(),
+        canned_gesture_classifier_graph_options->mutable_base_options());
+
+    const auto custom_gesture_classifier_file =
+        resources.GetModelFile(kCustomGestureClassifierTFLiteName);
+    if (custom_gesture_classifier_file.ok()) {
+      has_custom_gesture_classifier = true;
+      auto* custom_gesture_classifier_graph_options =
+          options->mutable_custom_gesture_classifier_graph_options();
+      SetExternalFile(
+          custom_gesture_classifier_file.value(),
+          custom_gesture_classifier_graph_options->mutable_base_options()
+              ->mutable_model_asset(),
+          is_copy);
+      PopulateAccelerationAndUseStreamMode(
+          options->base_options(),
+          custom_gesture_classifier_graph_options->mutable_base_options());
+    } else {
+      LOG(INFO) << "Custom gesture classifier is not defined.";
+    }
+    return absl::OkStatus();
+  }
+
   absl::StatusOr<SubTaskModelResources> CreateSubTaskModelResources(
       SubgraphContext* sc) {
     auto* options = sc->MutableOptions<HandGestureRecognizerGraphOptions>();
@@ -237,6 +278,19 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
             std::make_unique<core::proto::ExternalFile>(
                 std::move(canned_gesture_classifier_model_asset)),
             "_canned_gesture_classifier"));
+    if (has_custom_gesture_classifier) {
+      auto& custom_gesture_classifier_model_asset =
+          *options->mutable_custom_gesture_classifier_graph_options()
+               ->mutable_base_options()
+               ->mutable_model_asset();
+      ASSIGN_OR_RETURN(
+          sub_task_model_resources.custom_gesture_classifier_model_resource,
+          CreateModelResources(
+              sc,
+              std::make_unique<core::proto::ExternalFile>(
+                  std::move(custom_gesture_classifier_model_asset)),
+              "_custom_gesture_classifier"));
+    }
     return sub_task_model_resources;
   }
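Note: custom_gesture_classifier.tflite is the only optional entry in the .task bundle; required assets are unpacked with ASSIGN_OR_RETURN, while the optional one just branches on the returned status. The probe-and-fallback pattern in isolation (a sketch assuming `resources` is an already-built ModelAssetBundleResources; the helper name is hypothetical):

#include <string>

#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"

// Returns true when `asset_name` is packed into the bundle. GetModelFile
// reports a non-OK status for missing entries, which is what the graph uses
// to decide whether to wire up the custom gesture classifier at all.
bool BundleHasAsset(
    const mediapipe::tasks::core::ModelAssetBundleResources& resources,
    const std::string& asset_name) {
  return resources.GetModelFile(asset_name).ok();
}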
@@ -302,7 +356,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
     hand_world_landmarks_tensor >> concatenate_tensor_vector.In(2);
     auto concatenated_tensors = concatenate_tensor_vector.Out("");
 
-    // Inference for static hand gesture recognition.
+    // Inference for gesture embedder.
     auto& gesture_embedder_inference =
         AddInference(*sub_task_model_resources.gesture_embedder_model_resource,
                      graph_options.gesture_embedder_graph_options()
@@ -310,34 +364,64 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
                          .base_options()
                          .acceleration(),
                      graph);
     concatenated_tensors >> gesture_embedder_inference.In(kTensorsTag);
-    auto embedding_tensors = gesture_embedder_inference.Out(kTensorsTag);
+    auto embedding_tensors =
+        gesture_embedder_inference.Out(kTensorsTag).Cast<std::vector<Tensor>>();
 
-    auto& canned_gesture_classifier_inference = AddInference(
-        *sub_task_model_resources.canned_gesture_classifier_model_resource,
-        graph_options.canned_gesture_classifier_graph_options()
-            .base_options()
-            .acceleration(),
-        graph);
-    embedding_tensors >> canned_gesture_classifier_inference.In(kTensorsTag);
-    auto inference_output_tensors =
-        canned_gesture_classifier_inference.Out(kTensorsTag);
+    auto& combine_predictions = graph.AddNode("CombinedPredictionCalculator");
+    MP_RETURN_IF_ERROR(ConfigureCombinedPredictionCalculator(
+        &combine_predictions
+             .GetOptions<CombinedPredictionCalculatorOptions>()));
+
+    int classifier_nums = 0;
+    // Inference for custom gesture classifier if it exists.
+    if (has_custom_gesture_classifier) {
+      ASSIGN_OR_RETURN(
+          auto gesture_classification_list,
+          GetGestureClassificationList(
+              sub_task_model_resources.custom_gesture_classifier_model_resource,
+              graph_options.custom_gesture_classifier_graph_options(),
+              embedding_tensors, graph));
+      gesture_classification_list >> combine_predictions.In(classifier_nums++);
+    }
+
+    // Inference for canned gesture classifier.
+    ASSIGN_OR_RETURN(
+        auto gesture_classification_list,
+        GetGestureClassificationList(
+            sub_task_model_resources.canned_gesture_classifier_model_resource,
+            graph_options.canned_gesture_classifier_graph_options(),
+            embedding_tensors, graph));
+    gesture_classification_list >> combine_predictions.In(classifier_nums++);
+
+    auto combined_classification_list =
+        combine_predictions.Out(kPredictionTag).Cast<ClassificationList>();
+
+    return combined_classification_list;
+  }
+
+  absl::StatusOr<Source<ClassificationList>> GetGestureClassificationList(
+      const core::ModelResources* model_resources,
+      const proto::GestureClassifierGraphOptions& options,
+      Source<std::vector<Tensor>>& embedding_tensors, Graph& graph) {
+    auto& gesture_classifier_inference = AddInference(
+        *model_resources, options.base_options().acceleration(), graph);
+    embedding_tensors >> gesture_classifier_inference.In(kTensorsTag);
+    auto gesture_classifier_out_tensors =
+        gesture_classifier_inference.Out(kTensorsTag);
     auto& tensors_to_classification =
         graph.AddNode("TensorsToClassificationCalculator");
     MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator(
-        graph_options.canned_gesture_classifier_graph_options()
-            .classifier_options(),
-        *sub_task_model_resources.canned_gesture_classifier_model_resource
-            ->GetMetadataExtractor(),
+        options.classifier_options(), *model_resources->GetMetadataExtractor(),
         0, &tensors_to_classification.GetOptions<
                mediapipe::TensorsToClassificationCalculatorOptions>()));
-    inference_output_tensors >> tensors_to_classification.In(kTensorsTag);
-    auto classification_list =
-        tensors_to_classification[Output<ClassificationList>(
-            "CLASSIFICATIONS")];
-    return classification_list;
+    gesture_classifier_out_tensors >> tensors_to_classification.In(kTensorsTag);
+    return tensors_to_classification.Out("CLASSIFICATIONS")
+        .Cast<ClassificationList>();
   }
+
+  bool has_custom_gesture_classifier = false;
 };
 
 // clang-format off
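Note on the wiring above: classifier priority is purely positional. Whichever classifier stream is connected to the lowest CombinedPredictionCalculator input index is consulted first, and the custom classifier (when present) claims index 0 ahead of the canned one. A condensed sketch of just that wiring, assuming the api2 builder types used in this file and that the options proto lives in the mediapipe package, as the unqualified use above suggests (the function and parameter names are illustrative):

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.pb.h"

using ::mediapipe::ClassificationList;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;

// Wires zero-or-one custom classifier plus the canned classifier into the
// combiner. The index handed to In() encodes priority: custom before canned.
Source<ClassificationList> CombineGesturePredictions(
    Source<ClassificationList>* custom,  // nullptr when no custom model
    Source<ClassificationList> canned, Graph& graph) {
  auto& combine = graph.AddNode("CombinedPredictionCalculator");
  combine.GetOptions<mediapipe::CombinedPredictionCalculatorOptions>()
      .set_background_label("None");
  int next_input = 0;
  if (custom != nullptr) {
    *custom >> combine.In(next_input++);  // priority slot 0
  }
  canned >> combine.In(next_input++);  // fallback slot
  return combine.Out("PREDICTION").Cast<ClassificationList>();
}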
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognitionResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognitionResult.java
index fd764cb18..e9e1ebe8a 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognitionResult.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognitionResult.java
@@ -31,6 +31,8 @@ import java.util.List;
 @AutoValue
 public abstract class GestureRecognitionResult implements TaskResult {
 
+  private static final int kGestureDefaultIndex = -1;
+
   /**
    * Creates a {@link GestureRecognitionResult} instance from the lists of landmarks, handedness,
    * and gestures protobuf messages.
@@ -97,7 +99,9 @@ public abstract class GestureRecognitionResult implements TaskResult {
         gestures.add(
             Category.create(
                 classification.getScore(),
-                classification.getIndex(),
+                // Gesture index is not used, because the final gesture result comes from multiple
+                // classifiers.
+                kGestureDefaultIndex,
                 classification.getLabel(),
                 classification.getDisplayName()));
       }
@@ -123,6 +127,10 @@ public abstract class GestureRecognitionResult implements TaskResult {
   /** Handedness of detected hands. */
   public abstract List<List<Category>> handednesses();
 
-  /** Recognized hand gestures of detected hands */
+  /**
+   * Recognized hand gestures of detected hands. Note that the index of the gesture is always -1,
+   * because the raw indices from multiple gesture classifiers cannot be consolidated into a
+   * meaningful index.
+   */
   public abstract List<List<Category>> gestures();
 }
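Note: since the merged result can mix categories from different classifiers, only the label is meaningful downstream; the index is pinned to the -1 sentinel. Consumers should key on the category name, as in this small proto-level sketch (the helper name is illustrative):

#include <string>

#include "mediapipe/framework/formats/classification.pb.h"

// Returns the top gesture's label for one hand. Never dispatch on index(),
// which is always the -1 sentinel once results come from multiple classifiers.
std::string TopGestureName(const mediapipe::ClassificationList& gestures) {
  if (gestures.classification_size() == 0) return "None";
  return gestures.classification(0).label();
}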
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java
index eca5d35c2..2d4b3a50d 100644
--- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java
+++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java
@@ -46,19 +46,24 @@ import org.junit.runners.Suite.SuiteClasses;
 @SuiteClasses({GestureRecognizerTest.General.class, GestureRecognizerTest.RunningModeTest.class})
 public class GestureRecognizerTest {
   private static final String GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE = "gesture_recognizer.task";
+  private static final String GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE =
+      "gesture_recognizer_with_custom_classifier.task";
   private static final String TWO_HANDS_IMAGE = "right_hands.jpg";
   private static final String THUMB_UP_IMAGE = "thumb_up.jpg";
   private static final String POINTING_UP_ROTATED_IMAGE = "pointing_up_rotated.jpg";
   private static final String NO_HANDS_IMAGE = "cats_and_dogs.jpg";
+  private static final String FIST_IMAGE = "fist.jpg";
   private static final String THUMB_UP_LANDMARKS = "thumb_up_landmarks.pb";
+  private static final String FIST_LANDMARKS = "fist_landmarks.pb";
   private static final String TAG = "Gesture Recognizer Test";
   private static final String THUMB_UP_LABEL = "Thumb_Up";
-  private static final int THUMB_UP_INDEX = 5;
   private static final String POINTING_UP_LABEL = "Pointing_Up";
-  private static final int POINTING_UP_INDEX = 3;
+  private static final String FIST_LABEL = "Closed_Fist";
+  private static final String ROCK_LABEL = "Rock";
   private static final float LANDMARKS_ERROR_TOLERANCE = 0.03f;
   private static final int IMAGE_WIDTH = 382;
   private static final int IMAGE_HEIGHT = 406;
+  private static final int GESTURE_EXPECTED_INDEX = -1;
 
   @RunWith(AndroidJUnit4.class)
   public static final class General extends GestureRecognizerTest {
@@ -77,7 +82,7 @@ public class GestureRecognizerTest {
       GestureRecognitionResult actualResult =
           gestureRecognizer.recognize(getImageFromAsset(THUMB_UP_IMAGE));
       GestureRecognitionResult expectedResult =
-          getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX);
+          getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL);
       assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
     }
 
@@ -108,16 +113,14 @@ public class GestureRecognizerTest {
               BaseOptions.builder()
                   .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
                   .build())
-              // TODO update the confidence to be in range [0,1] after embedding model
-              // and scoring calculator is integrated.
-              .setMinGestureConfidence(2.0f)
+              .setMinGestureConfidence(0.5f)
               .build();
       GestureRecognizer gestureRecognizer =
           GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
       GestureRecognitionResult actualResult =
          gestureRecognizer.recognize(getImageFromAsset(THUMB_UP_IMAGE));
       GestureRecognitionResult expectedResult =
-          getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX);
+          getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL);
       // Only contains one top scoring gesture.
       assertThat(actualResult.gestures().get(0)).hasSize(1);
       assertActualGestureEqualExpectedGesture(
@@ -159,10 +162,48 @@ public class GestureRecognizerTest {
           gestureRecognizer.recognize(
               getImageFromAsset(POINTING_UP_ROTATED_IMAGE), imageProcessingOptions);
       assertThat(actualResult.gestures()).hasSize(1);
-      assertThat(actualResult.gestures().get(0).get(0).index()).isEqualTo(POINTING_UP_INDEX);
       assertThat(actualResult.gestures().get(0).get(0).categoryName()).isEqualTo(POINTING_UP_LABEL);
     }
 
+    @Test
+    public void recognize_successWithCannedGestureFist() throws Exception {
+      GestureRecognizerOptions options =
+          GestureRecognizerOptions.builder()
+              .setBaseOptions(
+                  BaseOptions.builder()
+                      .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
+                      .build())
+              .setNumHands(1)
+              .build();
+      GestureRecognizer gestureRecognizer =
+          GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+      GestureRecognitionResult actualResult =
+          gestureRecognizer.recognize(getImageFromAsset(FIST_IMAGE));
+      GestureRecognitionResult expectedResult =
+          getExpectedGestureRecognitionResult(FIST_LANDMARKS, FIST_LABEL);
+      assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
+    }
+
+    @Test
+    public void recognize_successWithCustomGestureRock() throws Exception {
+      GestureRecognizerOptions options =
+          GestureRecognizerOptions.builder()
+              .setBaseOptions(
+                  BaseOptions.builder()
+                      .setModelAssetPath(
+                          GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE)
+                      .build())
+              .setNumHands(1)
+              .build();
+      GestureRecognizer gestureRecognizer =
+          GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+      GestureRecognitionResult actualResult =
+          gestureRecognizer.recognize(getImageFromAsset(FIST_IMAGE));
+      GestureRecognitionResult expectedResult =
+          getExpectedGestureRecognitionResult(FIST_LANDMARKS, ROCK_LABEL);
+      assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
+    }
+
     @Test
     public void recognize_failsWithRegionOfInterest() throws Exception {
       GestureRecognizerOptions options =
@@ -331,7 +372,7 @@ public class GestureRecognizerTest {
       GestureRecognitionResult actualResult =
           gestureRecognizer.recognize(getImageFromAsset(THUMB_UP_IMAGE));
       GestureRecognitionResult expectedResult =
-          getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX);
+          getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL);
       assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
     }
 
@@ -348,7 +389,7 @@ public class GestureRecognizerTest {
       GestureRecognizer gestureRecognizer =
           GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
       GestureRecognitionResult expectedResult =
-          getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX);
+          getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL);
       for (int i = 0; i < 3; i++) {
         GestureRecognitionResult actualResult =
             gestureRecognizer.recognizeForVideo(
@@ -361,7 +402,7 @@ public class GestureRecognizerTest {
     public void recognize_failsWithOutOfOrderInputTimestamps() throws Exception {
       MPImage image = getImageFromAsset(THUMB_UP_IMAGE);
       GestureRecognitionResult expectedResult =
-          getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX);
+          getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL);
       GestureRecognizerOptions options =
           GestureRecognizerOptions.builder()
               .setBaseOptions(
@@ -393,7 +434,7 @@ public class GestureRecognizerTest {
     public void recognize_successWithLiveSteamMode() throws Exception {
       MPImage image = getImageFromAsset(THUMB_UP_IMAGE);
       GestureRecognitionResult expectedResult =
-          getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX);
+          getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL);
       GestureRecognizerOptions options =
           GestureRecognizerOptions.builder()
               .setBaseOptions(
@@ -423,7 +464,7 @@ public class GestureRecognizerTest {
   }
 
   private static GestureRecognitionResult getExpectedGestureRecognitionResult(
-      String filePath, String gestureLabel, int gestureIndex) throws Exception {
+      String filePath, String gestureLabel) throws Exception {
     AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
     InputStream istr = assetManager.open(filePath);
     LandmarksDetectionResult landmarksDetectionResultProto =
@@ -431,9 +472,7 @@ public class GestureRecognizerTest {
     ClassificationProto.ClassificationList gesturesProto =
         ClassificationProto.ClassificationList.newBuilder()
             .addClassification(
-                ClassificationProto.Classification.newBuilder()
-                    .setLabel(gestureLabel)
-                    .setIndex(gestureIndex))
+                ClassificationProto.Classification.newBuilder().setLabel(gestureLabel))
             .build();
     return GestureRecognitionResult.create(
         Arrays.asList(landmarksDetectionResultProto.getLandmarks()),
@@ -479,8 +518,8 @@ public class GestureRecognizerTest {
 
   private static void assertActualGestureEqualExpectedGesture(
       Category actualGesture, Category expectedGesture) {
-    assertThat(actualGesture.index()).isEqualTo(actualGesture.index());
-    assertThat(expectedGesture.categoryName()).isEqualTo(expectedGesture.categoryName());
+    assertThat(actualGesture.categoryName()).isEqualTo(expectedGesture.categoryName());
+    assertThat(actualGesture.index()).isEqualTo(GESTURE_EXPECTED_INDEX);
   }
 
   private static void assertImageSizeIsExpected(MPImage inputImage) {
diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD
index c45cc6e69..ad8072b87 100644
--- a/mediapipe/tasks/testdata/vision/BUILD
+++ b/mediapipe/tasks/testdata/vision/BUILD
@@ -37,6 +37,7 @@ mediapipe_files(srcs = [
    "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
    "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
    "deeplabv3.tflite",
+   "fist.jpg",
    "hand_landmark_full.tflite",
    "hand_landmark_lite.tflite",
    "left_hands.jpg",
@@ -64,6 +65,7 @@ mediapipe_files(srcs = [
    "selfie_segm_144_256_3.tflite",
    "selfie_segm_144_256_3_expected_mask.jpg",
    "thumb_up.jpg",
+   "victory.jpg",
 ])
 
 exports_files(
@@ -91,6 +93,7 @@ filegroup(
         "cats_and_dogs.jpg",
         "cats_and_dogs_no_resizing.jpg",
         "cats_and_dogs_rotated.jpg",
+        "fist.jpg",
         "hand_landmark_full.tflite",
         "hand_landmark_lite.tflite",
         "left_hands.jpg",
@@ -107,6 +110,7 @@ filegroup(
         "selfie_segm_128_128_3_expected_mask.jpg",
         "selfie_segm_144_256_3_expected_mask.jpg",
         "thumb_up.jpg",
+        "victory.jpg",
     ],
     visibility = [
         "//mediapipe/python:__subpackages__",
@@ -148,6 +152,7 @@ filegroup(
         "expected_left_up_hand_rotated_landmarks.prototxt",
         "expected_right_down_hand_landmarks.prototxt",
         "expected_right_up_hand_landmarks.prototxt",
+        "fist_landmarks.pbtxt",
         "hand_detector_result_one_hand.pbtxt",
         "hand_detector_result_one_hand_rotated.pbtxt",
         "hand_detector_result_two_hands.pbtxt",
@@ -155,5 +160,6 @@ filegroup(
         "pointing_up_rotated_landmarks.pbtxt",
         "thumb_up_landmarks.pbtxt",
         "thumb_up_rotated_landmarks.pbtxt",
+        "victory_landmarks.pbtxt",
     ],
 )
diff --git a/mediapipe/tasks/testdata/vision/fist_landmarks.pbtxt b/mediapipe/tasks/testdata/vision/fist_landmarks.pbtxt
new file mode 100644
index 000000000..a24358c3c
--- /dev/null
+++ b/mediapipe/tasks/testdata/vision/fist_landmarks.pbtxt
@@ -0,0 +1,223 @@
+classifications {
+  classification {
+    score: 1.0
+    label: "Left"
+    display_name: "Left"
+  }
+}
+
+landmarks {
+  landmark {
+    x: 0.47709703
+    y: 0.66129065
+    z: -3.3540672e-07
+  }
+  landmark {
+    x: 0.6125982
+    y: 0.5578249
+    z: -0.041392017
+  }
+  landmark {
+    x: 0.71123487
+    y: 0.4316616
+    z: -0.064544134
+  }
+  landmark {
+    x: 0.6836403
+    y: 0.3199585
+    z: -0.08752567
+  }
+  landmark {
+    x: 0.5593274
+    y: 0.3206453
+    z: -0.09880819
+  }
+  landmark {
+    x: 0.60828537
+    y: 0.3068749
+    z: -0.014799656
+  }
+  landmark {
+    x: 0.62940764
+    y: 0.21414441
+    z: -0.06007311
+  }
+  landmark {
+    x: 0.6244353
+    y: 0.32872596
+    z: -0.08326768
+  }
+  landmark {
+    x: 0.60784453
+    y: 0.3684796
+    z: -0.09658983
+  }
+  landmark {
+    x: 0.5156504
+    y: 0.32194698
+    z: -0.021699267
+  }
+  landmark {
+    x: 0.52931
+    y: 0.24767634
+    z: -0.062571
+  }
+  landmark {
+    x: 0.5484773
+    y: 0.3805329
+    z: -0.07028895
+  }
+  landmark {
+    x: 0.54428184
+    y: 0.3881125
+    z: -0.07458326
+  }
+  landmark {
+    x: 0.43159598
+    y: 0.34918433
+    z: -0.037482508
+  }
+  landmark {
+    x: 0.4486106
+    y: 0.27649382
+    z: -0.08174769
+  }
+  landmark {
+    x: 0.47723144
+    y: 0.3964985
+    z: -0.06496752
+  }
+  landmark {
+    x: 0.46794242
+    y: 0.4082967
+    z: -0.04897496
+  }
+  landmark {
+    x: 0.34826216
+    y: 0.37813392
+    z: -0.057438444
+  }
+  landmark {
+    x: 0.3861837
+    y: 0.32820183
+    z: -0.07282783
+  }
+  landmark {
+    x: 0.41143674
+    y: 0.39734486
+    z: -0.047633167
+  }
+  landmark {
+    x: 0.39401984
+    y: 0.41149133
+    z: -0.029640475
+  }
+}
+
+world_landmarks {
+  landmark {
+    x: -0.008604452
+    y: 0.08165767
+    z: 0.0061365655
+  }
+  landmark {
+    x: 0.027301773
+    y: 0.061905317
+    z: -0.00872007
+  }
+  landmark {
+    x: 0.049898714
+    y: 0.035359327
+    z: -0.016682662
+  }
+  landmark {
+    x: 0.050297678
+    y: 0.005200807
+    z: -0.028928496
+  }
+  landmark {
+    x: 0.015639625
+    y: -0.0063155442
+    z: -0.03174634
+  }
+  landmark {
+    x: 0.029161729
+    y: -0.0024596984
+    z: 0.0011553494
+  }
+  landmark {
+    x: 0.034491
+    y: -0.017581237
+    z: -0.020781275
+  }
+  landmark {
+    x: 0.034020264
+    y: -0.0059247985
+    z: -0.02573838
+  }
+  landmark {
+    x: 0.02867364
+    y: 0.011137734
+    z: -0.009430941
+  }
+  landmark {
+    x: 0.0015385814
+    y: -0.004778851
+    z: 0.0056454404
+  }
+  landmark {
+    x: 0.010490709
+    y: -0.019680617
+    z: -0.027034117
+  }
+  landmark {
+    x: 0.0132071925
+    y: 0.0071370844
+    z: -0.034802448
+  }
+  landmark {
+    x: 0.0139978565
+    y: 0.011672501
+    z: -0.0040006908
+  }
+  landmark {
+    x: -0.019919239
+    y: -0.0006897822
+    z: -0.0003317799
+  }
+  landmark {
+    x: -0.01088193
+    y: -0.008502296
+    z: -0.02873486
+  }
+  landmark {
+    x: -0.005327127
+    y: 0.012745364
+    z: -0.034153957
+  }
+  landmark {
+    x: -0.0027040644
+    y: 0.02167169
+    z: -0.011669062
+  }
+  landmark {
+    x: -0.038813893
+    y: 0.011925209
+    z: -0.0076287366
+  }
+  landmark {
+    x: -0.030842202
+    y: 0.0010964936
+    z: -0.022697516
+  }
+  landmark {
+    x: -0.01829514
+    y: 0.013929318
+    z: -0.032819964
+  }
+  landmark {
+    x: -0.024175374
+    y: 0.022456694
+    z: -0.02357186
+  }
+}
diff --git a/mediapipe/tasks/testdata/vision/gesture_recognizer_with_custom_classifier.task b/mediapipe/tasks/testdata/vision/gesture_recognizer_with_custom_classifier.task
new file mode 100644
index 000000000..3c1da7b3d
Binary files /dev/null and b/mediapipe/tasks/testdata/vision/gesture_recognizer_with_custom_classifier.task differ
diff --git a/mediapipe/tasks/testdata/vision/hand_gesture_recognizer_with_custom_classifier.task b/mediapipe/tasks/testdata/vision/hand_gesture_recognizer_with_custom_classifier.task
new file mode 100644
index 000000000..1390ca88d
Binary files /dev/null and b/mediapipe/tasks/testdata/vision/hand_gesture_recognizer_with_custom_classifier.task differ
diff --git a/mediapipe/tasks/testdata/vision/victory_landmarks.pbtxt b/mediapipe/tasks/testdata/vision/victory_landmarks.pbtxt
new file mode 100644
index 000000000..7a704ee36
--- /dev/null
+++ b/mediapipe/tasks/testdata/vision/victory_landmarks.pbtxt
@@ -0,0 +1,223 @@
+classifications {
+  classification {
+    score: 1.0
+    label: "Left"
+    display_name: "Left"
+  }
+}
+
+landmarks {
+  landmark {
+    x: 0.5164316
+    y: 0.804093
+    z: 8.7653416e-07
+  }
+  landmark {
+    x: 0.6063608
+    y: 0.7111354
+    z: -0.044089418
+  }
+  landmark {
+    x: 0.6280186
+    y: 0.588498
+    z: -0.062358405
+  }
+  landmark {
+    x: 0.5265348
+    y: 0.52083343
+    z: -0.08526791
+  }
+  landmark {
+    x: 0.4243384
+    y: 0.4993468
+    z: -0.1077741
+  }
+  landmark {
+    x: 0.5605667
+    y: 0.4489705
+    z: -0.016151091
+  }
+  landmark {
+    x: 0.5766643
+    y: 0.32260323
+    z: -0.049342215
+  }
+  landmark {
+    x: 0.5795845
+    y: 0.24180722
+    z: -0.07323826
+  }
+  landmark {
+    x: 0.5827511
+    y: 0.16940045
+    z: -0.09069163
+  }
+  landmark {
+    x: 0.4696163
+    y: 0.4599558
+    z: -0.032168437
+  }
+  landmark {
+    x: 0.44361597
+    y: 0.31689578
+    z: -0.075698614
+  }
+  landmark {
+    x: 0.42695498
+    y: 0.22273324
+    z: -0.10819675
+  }
+  landmark {
+    x: 0.40697217
+    y: 0.14279765
+    z: -0.12666894
+  }
+  landmark {
+    x: 0.39543492
+    y: 0.50612336
+    z: -0.055138163
+  }
+  landmark {
+    x: 0.3618012
+    y: 0.4388296
+    z: -0.1298119
+  }
+  landmark {
+    x: 0.4154368
+    y: 0.52674913
+    z: -0.1463017
+  }
+  landmark {
+    x: 0.44916254
+    y: 0.59442246
+    z: -0.13470782
+  }
+  landmark {
+    x: 0.33178204
+    y: 0.5731769
+    z: -0.08103096
+  }
+  landmark {
+    x: 0.3092102
+    y: 0.5040002
+    z: -0.13258384
+  }
+  landmark {
+    x: 0.35576707
+    y: 0.5576498
+    z: -0.12714732
+  }
+  landmark {
+    x: 0.393444
+    y: 0.6118667
+    z: -0.11102459
+  }
+}
+
+world_landmarks {
+  landmark {
+    x: 0.01299962
+    y: 0.09162361
+    z: 0.011185312
+  }
+  landmark {
+    x: 0.03726317
+    y: 0.0638103
+    z: -0.010005756
+  }
+  landmark {
+    x: 0.03975261
+    y: 0.03712649
+    z: -0.02906275
+  }
+  landmark {
+    x: 0.018798776
+    y: 0.012429599
+    z: -0.048737116
+  }
+  landmark {
+    x: -0.0128555335
+    y: 0.001022811
+    z: -0.044505004
+  }
+  landmark {
+    x: 0.025658218
+    y: -0.008031519
+    z: -0.0058278795
+  }
+  landmark {
+    x: 0.028017294
+    y: -0.038120236
+    z: -0.010376478
+  }
+  landmark {
+    x: 0.030067094
+    y: -0.059907563
+    z: -0.014568218
+  }
+  landmark {
+    x: 0.027284538
+    y: -0.07803874
+    z: -0.032692235
+  }
+  landmark {
+    x: 0.0013260426
+    y: -0.005039873
+    z: 0.005567288
+  }
+  landmark {
+    x: -0.002380834
+    y: -0.044605374
+    z: -0.0038231965
+  }
+  landmark {
+    x: -0.009240147
+    y: -0.066279344
+    z: -0.02161214
+  }
+  landmark {
+    x: -0.0092535615
+    y: -0.08933755
+    z: -0.037401434
+  }
+  landmark {
+    x: -0.01751284
+    y: 0.0037118336
+    z: 0.0047480655
+  }
+  landmark {
+    x: -0.02195602
+    y: -0.010006189
+    z: -0.02371484
+  }
+  landmark {
+    x: -0.012851426
+    y: 0.008346066
+    z: -0.037721373
+  }
+  landmark {
+    x: -0.00018795021
+    y: 0.026816685
+    z: -0.03732748
+  }
+  landmark {
+    x: -0.034864448
+    y: 0.022316
+    z: -0.0002774651
+  }
+  landmark {
+    x: -0.035896845
+    y: 0.01066218
+    z: -0.017325373
+  }
+  landmark {
+    x: -0.02358637
+    y: 0.018667895
+    z: -0.028403495
+  }
+  landmark {
+    x: -0.013704676
+    y: 0.033456434
+    z: -0.02595728
+  }
+}
diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl
index 4b7309eef..4dcbc3bd9 100644
--- a/third_party/external_files.bzl
+++ b/third_party/external_files.bzl
@@ -244,6 +244,18 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/feature_tensor_meta.json?generation=1665422818797346"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_fist_jpg",
+        sha256 = "43fa1cabf3f90d574accc9a56986e2ee48638ce59fc65af1846487f73bb2ef24",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/fist.jpg?generation=1666999359066679"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_fist_landmarks_pbtxt",
+        sha256 = "76d6489e6163211ce5e9080e51983165bb9b24ff50146cc7487bd629f011c598",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/fist_landmarks.pbtxt?generation=1666999360561864"],
+    )
+
     http_file(
         name = "com_google_mediapipe_general_meta_json",
         sha256 = "b95363e4bae89b9c2af484498312aaad4efc7ff57c7eadcc4e5e7adca641445f",
@@ -838,6 +850,18 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/universal_sentence_encoder_qa_with_metadata.tflite?generation=1665445919252005"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_victory_jpg",
+        sha256 = "84cb8853e3df614e0cb5c93a25e3e2f38ea5e4f92fd428ee7d867ed3479d5764",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/victory.jpg?generation=1666999364225126"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_victory_landmarks_pbtxt",
+        sha256 = "b25ab4f222674489f543afb6454396ecbc1437a7ae6213dbf0553029ae939ab0",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/victory_landmarks.pbtxt?generation=1666999366036622"],
+    )
+
     http_file(
         name = "com_google_mediapipe_vocab_for_regex_tokenizer_txt",
         sha256 = "b1134b10927a53ce4224bbc30ccf075c9969c94ebf40c368966d1dcf445ca923",