From f8af41b1eb49ff4bdad756ff19d1d36f486be614 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 28 Sep 2022 20:35:30 +0000 Subject: [PATCH] Internal change PiperOrigin-RevId: 477538515 --- WORKSPACE | 7 + docs/solutions/pose.md | 4 +- docs/solutions/selfie_segmentation.md | 2 +- mediapipe/calculators/core/BUILD | 28 +- .../calculators/core/end_loop_calculator.cc | 3 + .../core/get_vector_item_calculator.cc | 4 + .../core/get_vector_item_calculator.h | 7 +- .../core/get_vector_item_calculator_test.cc | 11 + .../core/vector_indices_calculator.cc | 33 + .../core/vector_indices_calculator.h | 65 + .../core/vector_indices_calculator_test.cc | 87 ++ mediapipe/calculators/tensor/BUILD | 68 + .../tensor/audio_to_tensor_calculator.cc | 228 +++- .../tensor/audio_to_tensor_calculator.proto | 24 + .../tensor/audio_to_tensor_calculator_test.cc | 215 +++- .../tensor/feedback_tensors_calculator.cc | 165 +++ .../tensor/feedback_tensors_calculator.proto | 47 + .../feedback_tensors_calculator_test.cc | 389 ++++++ ...ession_from_saved_model_calculator_test.cc | 2 +- ...session_from_saved_model_generator_test.cc | 2 +- mediapipe/calculators/tflite/BUILD | 1 + .../tflite/ssd_anchors_calculator.cc | 157 ++- .../tflite/ssd_anchors_calculator.proto | 26 + .../tflite/ssd_anchors_calculator_test.cc | 41 +- .../tflite/testdata/anchor_golden_file_2.txt | 1125 +++++++++++++++++ .../tflite/tflite_model_calculator.cc | 40 +- .../AppIcon.appiconset/83.5_c_Ipad_2x.png | Bin 0 -> 9566 bytes .../AppIcon.appiconset/Contents.json | 1 + mediapipe/framework/api2/BUILD | 2 + mediapipe/framework/api2/builder.h | 36 +- mediapipe/framework/api2/builder_test.cc | 93 +- mediapipe/framework/api2/packet.h | 4 +- mediapipe/framework/api2/port.h | 37 +- .../framework/formats/image_frame_opencv.cc | 12 +- .../framework/formats/image_frame_opencv.h | 2 +- .../formats/image_frame_opencv_test.cc | 38 +- mediapipe/framework/formats/image_opencv.cc | 2 +- mediapipe/framework/formats/image_opencv.h | 2 +- mediapipe/framework/formats/tensor.cc | 47 +- mediapipe/framework/formats/tensor.h | 6 +- mediapipe/framework/mediapipe_cc_test.bzl | 2 +- mediapipe/framework/profiler/BUILD | 1 + .../framework/profiler/graph_profiler.cc | 12 +- .../framework/testdata/perfetto_minimal.pbtxt | 28 +- mediapipe/framework/tool/BUILD | 13 + .../framework/tool/switch_demux_calculator.cc | 21 +- .../tool/switch_demux_calculator_test.cc | 135 ++ mediapipe/gpu/BUILD | 34 + mediapipe/gpu/gl_context.cc | 15 +- mediapipe/gpu/gl_context.h | 4 + mediapipe/gpu/gl_texture_buffer.cc | 18 +- mediapipe/gpu/gpu_buffer.cc | 24 +- mediapipe/gpu/gpu_buffer.h | 2 + .../graphs/pose_tracking/subgraphs/BUILD | 22 +- .../pose_landmarks_to_render_data.pbtxt | 236 ++++ .../subgraphs/pose_renderer_cpu.pbtxt | 219 +--- .../subgraphs/pose_renderer_gpu.pbtxt | 219 +--- .../components/ExternalTextureConverter.java | 16 +- .../framework/AndroidPacketCreator.java | 48 + .../java/com/google/mediapipe/framework/BUILD | 2 + .../framework/image/AndroidManifest.xml | 6 + .../google/mediapipe/framework/image/BUILD | 32 + .../framework/image/BitmapExtractor.java | 49 + .../framework/image/BitmapImageBuilder.java | 72 ++ .../framework/image/BitmapImageContainer.java | 60 + .../framework/image/ByteBufferExtractor.java | 254 ++++ .../image/ByteBufferImageBuilder.java | 71 ++ .../image/ByteBufferImageContainer.java | 58 + .../mediapipe/framework/image/Image.java | 241 ++++ .../framework/image/ImageConsumer.java | 27 + .../framework/image/ImageContainer.java | 25 + 
.../framework/image/ImageProducer.java | 22 + .../framework/image/ImageProperties.java | 80 ++ .../framework/image/MediaImageBuilder.java | 62 + .../framework/image/MediaImageContainer.java | 73 ++ .../framework/image/MediaImageExtractor.java | 49 + .../framework/jni/graph_texture_frame_jni.cc | 11 +- .../com/google/mediapipe/mediapipe_aar.bzl | 1 + .../audio_classifier/audio_classifier.cc | 7 +- .../audio_classifier_graph.cc | 15 +- .../audio_classifier/audio_classifier_test.cc | 41 +- .../cc/audio/audio_classifier/proto/BUILD | 2 +- .../proto/audio_classifier_options.proto | 4 +- mediapipe/tasks/cc/components/BUILD | 56 +- .../tasks/cc/components/calculators/BUILD | 63 + .../calculators/end_loop_calculator.cc | 29 + .../cc/components/calculators/tensor/BUILD | 4 +- .../tensors_to_segmentation_calculator.cc | 13 +- .../tensors_to_segmentation_calculator.proto | 4 +- ...tensors_to_segmentation_calculator_test.cc | 14 +- .../tensors_to_embeddings_calculator.cc | 158 +++ .../tensors_to_embeddings_calculator.proto | 35 + .../tensors_to_embeddings_calculator_test.cc | 249 ++++ .../classification_postprocessing.cc | 124 +- .../classification_postprocessing.h | 10 +- ...lassification_postprocessing_options.proto | 7 +- .../classification_postprocessing_test.cc | 143 ++- .../tasks/cc/components/classifier_options.cc | 6 +- .../tasks/cc/components/classifier_options.h | 4 +- .../cc/components/containers/proto/BUILD | 5 + .../containers/proto/embeddings.proto | 56 + .../tasks/cc/components/embedder_options.cc | 34 + .../tasks/cc/components/embedder_options.h | 47 + .../embedding_postprocessing_graph.cc | 232 ++++ .../embedding_postprocessing_graph.h | 61 + .../embedding_postprocessing_graph_test.cc | 136 ++ .../cc/components/image_preprocessing.cc | 40 +- .../tasks/cc/components/image_preprocessing.h | 7 +- .../image_preprocessing_options.proto | 11 +- mediapipe/tasks/cc/components/proto/BUILD | 53 + .../{ => proto}/classifier_options.proto | 2 +- .../components/proto/embedder_options.proto | 33 + ...bedding_postprocessing_graph_options.proto | 38 + .../{ => proto}/segmenter_options.proto | 2 +- .../text_preprocessing_graph_options.proto | 40 + .../cc/components/text_preprocessing_graph.cc | 266 ++++ .../cc/components/text_preprocessing_graph.h | 58 + mediapipe/tasks/cc/components/utils/BUILD | 44 + .../cc/components/utils/cosine_similarity.cc | 112 ++ .../cc/components/utils/cosine_similarity.h | 42 + .../utils/cosine_similarity_test.cc | 111 ++ .../components/utils/source_or_node_output.h | 66 + mediapipe/tasks/cc/core/BUILD | 2 + mediapipe/tasks/cc/core/base_options.cc | 39 +- mediapipe/tasks/cc/core/base_options.h | 19 +- mediapipe/tasks/cc/core/model_task_graph.cc | 12 +- mediapipe/tasks/cc/core/model_task_graph.h | 7 +- .../tasks/cc/core/proto/acceleration.proto | 2 +- .../tasks/cc/core/proto/base_options.proto | 4 +- .../tasks/cc/metadata/metadata_extractor.cc | 2 +- .../tasks/cc/metadata/metadata_extractor.h | 2 +- .../cc/metadata/metadata_parser.h.template | 2 +- .../tasks/cc/metadata/metadata_populator.cc | 2 +- .../tasks/cc/metadata/metadata_populator.h | 2 +- .../tasks/cc/metadata/metadata_version.cc | 2 +- .../tasks/cc/metadata/metadata_version.h | 2 +- .../cc/metadata/python/metadata_version.cc | 2 +- .../metadata/tests/metadata_extractor_test.cc | 2 +- .../cc/metadata/tests/metadata_parser_test.cc | 2 +- .../metadata/tests/metadata_version_test.cc | 2 +- .../metadata/utils/zip_readonly_mem_file.cc | 2 +- .../cc/metadata/utils/zip_readonly_mem_file.h | 2 +- 
.../metadata/utils/zip_writable_mem_file.cc | 2 +- .../cc/metadata/utils/zip_writable_mem_file.h | 2 +- .../cc/{components => text}/tokenizers/BUILD | 0 .../tokenizers/bert_tokenizer.cc | 8 +- .../tokenizers/bert_tokenizer.h | 18 +- .../tokenizers/bert_tokenizer_test.cc | 8 +- .../tokenizers/regex_tokenizer.cc | 8 +- .../tokenizers/regex_tokenizer.h | 14 +- .../tokenizers/regex_tokenizer_test.cc | 8 +- .../tokenizers/sentencepiece_tokenizer.h | 14 +- .../sentencepiece_tokenizer_test.cc | 8 +- .../tokenizers/tokenizer.h | 12 +- .../tokenizers/tokenizer_utils.cc | 12 +- .../tokenizers/tokenizer_utils.h | 16 +- .../tokenizers/tokenizer_utils_test.cc | 17 +- mediapipe/tasks/cc/vision/hand_detector/BUILD | 72 ++ .../hand_detector/hand_detector_graph.cc | 320 +++++ .../hand_detector/hand_detector_graph_test.cc | 205 +++ .../hand_detector_op_resolver.cc | 35 + .../hand_detector/hand_detector_op_resolver.h | 34 + .../tasks/cc/vision/hand_detector/proto/BUILD | 40 + .../proto/hand_detector_options.proto | 44 + .../proto/hand_detector_result.proto | 30 + .../cc/vision/hand_gesture_recognizer/BUILD | 5 +- .../hand_gesture_recognizer/calculators/BUILD | 14 +- ...and_landmarks_to_matrix_calculator_test.cc | 163 --- ...r.cc => landmarks_to_matrix_calculator.cc} | 129 +- .../landmarks_to_matrix_calculator_test.cc | 207 +++ .../hand_gesture_recognizer_subgraph.cc | 214 +++- .../hand_gesture_recognizer/proto/BUILD | 17 +- ..._gesture_recognizer_subgraph_options.proto | 4 +- .../landmarks_to_matrix_calculator.proto | 39 + .../{hand_landmark => hand_landmarker}/BUILD | 18 +- .../hand_landmarker_subgraph.cc} | 239 +++- .../hand_landmarker_subgraph_test.cc | 467 +++++++ .../cc/vision/hand_landmarker/proto/BUILD | 43 + .../proto/hand_landmarker_options.proto | 40 + .../hand_landmarker_subgraph_options.proto} | 6 +- .../image_classification/image_classifier.cc | 89 -- .../image_classification/image_classifier.h | 89 -- .../image_classifier_test.cc | 411 ------ .../BUILD | 34 +- .../image_classifier/image_classifier.cc | 219 ++++ .../image_classifier/image_classifier.h | 168 +++ .../image_classifier_graph.cc | 72 +- .../image_classifier/image_classifier_test.cc | 819 ++++++++++++ .../cc/vision/image_classifier/proto/BUILD | 30 + .../proto}/image_classifier_options.proto | 6 +- .../tasks/cc/vision/image_embedder/BUILD | 68 + .../vision/image_embedder/image_embedder.cc | 218 ++++ .../cc/vision/image_embedder/image_embedder.h | 161 +++ .../image_embedder/image_embedder_graph.cc | 172 +++ .../image_embedder/image_embedder_test.cc | 557 ++++++++ .../cc/vision/image_embedder/proto/BUILD | 30 + .../proto/image_embedder_graph_options.proto | 35 + .../tasks/cc/vision/image_segmenter/BUILD | 5 +- .../vision/image_segmenter/image_segmenter.cc | 56 +- .../vision/image_segmenter/image_segmenter.h | 54 +- .../image_segmenter/image_segmenter_graph.cc | 21 +- .../image_segmenter/image_segmenter_test.cc | 202 ++- .../cc/vision/image_segmenter/proto/BUILD | 2 +- .../proto/image_segmenter_options.proto | 4 +- .../tasks/cc/vision/object_detector/BUILD | 9 + .../vision/object_detector/object_detector.cc | 2 +- .../vision/object_detector/object_detector.h | 8 +- .../object_detector/object_detector_graph.cc | 136 +- .../object_detector/object_detector_test.cc | 75 +- mediapipe/tasks/java/BUILD | 1 + .../tasks/components/containers/BUILD | 13 + .../com/google/mediapipe/tasks/core/BUILD | 13 + .../google/mediapipe/tasks/vision/core/BUILD | 13 + .../vision/objectdetector/AndroidManifest.xml | 8 + .../tasks/vision/objectdetector/BUILD | 13 
+ mediapipe/tasks/javatests/BUILD | 1 + .../com/google/mediapipe/tasks/core/BUILD | 13 + .../vision/objectdetector/AndroidManifest.xml | 24 + .../tasks/vision/objectdetector/BUILD | 15 + mediapipe/tasks/metadata/metadata_schema.fbs | 2 +- .../components/containers/bounding_box.py | 2 +- .../python/components/containers/category.py | 2 +- .../components/containers/detections.py | 2 +- mediapipe/tasks/python/core/base_options.py | 26 +- mediapipe/tasks/python/test/test_util.py | 2 +- mediapipe/tasks/python/test/vision/BUILD | 19 +- .../test/vision/object_detector_test.py | 132 +- .../tasks/python/vision/object_detector.py | 56 +- mediapipe/tasks/testdata/text/BUILD | 12 +- mediapipe/tasks/testdata/vision/BUILD | 16 +- .../hand_detector_result_one_hand.pbtxt | 33 + .../hand_detector_result_two_hands.pbtxt | 40 + mediapipe/util/sequence/media_sequence.h | 7 + mediapipe/util/sequence/media_sequence.py | 12 + third_party/external_files.bzl | 86 +- third_party/pffft.BUILD | 19 + 236 files changed, 12865 insertions(+), 1923 deletions(-) create mode 100644 mediapipe/calculators/core/vector_indices_calculator.cc create mode 100644 mediapipe/calculators/core/vector_indices_calculator.h create mode 100644 mediapipe/calculators/core/vector_indices_calculator_test.cc create mode 100644 mediapipe/calculators/tensor/feedback_tensors_calculator.cc create mode 100644 mediapipe/calculators/tensor/feedback_tensors_calculator.proto create mode 100644 mediapipe/calculators/tensor/feedback_tensors_calculator_test.cc create mode 100644 mediapipe/calculators/tflite/testdata/anchor_golden_file_2.txt create mode 100644 mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/83.5_c_Ipad_2x.png create mode 100644 mediapipe/framework/tool/switch_demux_calculator_test.cc create mode 100644 mediapipe/graphs/pose_tracking/subgraphs/pose_landmarks_to_render_data.pbtxt create mode 100644 mediapipe/java/com/google/mediapipe/framework/image/AndroidManifest.xml create mode 100644 mediapipe/java/com/google/mediapipe/framework/image/BUILD create mode 100644 mediapipe/java/com/google/mediapipe/framework/image/BitmapExtractor.java create mode 100644 mediapipe/java/com/google/mediapipe/framework/image/BitmapImageBuilder.java create mode 100644 mediapipe/java/com/google/mediapipe/framework/image/BitmapImageContainer.java create mode 100644 mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java create mode 100644 mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageBuilder.java create mode 100644 mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageContainer.java create mode 100644 mediapipe/java/com/google/mediapipe/framework/image/Image.java create mode 100644 mediapipe/java/com/google/mediapipe/framework/image/ImageConsumer.java create mode 100644 mediapipe/java/com/google/mediapipe/framework/image/ImageContainer.java create mode 100644 mediapipe/java/com/google/mediapipe/framework/image/ImageProducer.java create mode 100644 mediapipe/java/com/google/mediapipe/framework/image/ImageProperties.java create mode 100644 mediapipe/java/com/google/mediapipe/framework/image/MediaImageBuilder.java create mode 100644 mediapipe/java/com/google/mediapipe/framework/image/MediaImageContainer.java create mode 100644 mediapipe/java/com/google/mediapipe/framework/image/MediaImageExtractor.java create mode 100644 mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc create mode 100644 mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.cc create mode 100644 
mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto create mode 100644 mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator_test.cc create mode 100644 mediapipe/tasks/cc/components/containers/proto/embeddings.proto create mode 100644 mediapipe/tasks/cc/components/embedder_options.cc create mode 100644 mediapipe/tasks/cc/components/embedder_options.h create mode 100644 mediapipe/tasks/cc/components/embedding_postprocessing_graph.cc create mode 100644 mediapipe/tasks/cc/components/embedding_postprocessing_graph.h create mode 100644 mediapipe/tasks/cc/components/embedding_postprocessing_graph_test.cc create mode 100644 mediapipe/tasks/cc/components/proto/BUILD rename mediapipe/tasks/cc/components/{ => proto}/classifier_options.proto (97%) create mode 100644 mediapipe/tasks/cc/components/proto/embedder_options.proto create mode 100644 mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.proto rename mediapipe/tasks/cc/components/{ => proto}/segmenter_options.proto (97%) create mode 100644 mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto create mode 100644 mediapipe/tasks/cc/components/text_preprocessing_graph.cc create mode 100644 mediapipe/tasks/cc/components/text_preprocessing_graph.h create mode 100644 mediapipe/tasks/cc/components/utils/BUILD create mode 100644 mediapipe/tasks/cc/components/utils/cosine_similarity.cc create mode 100644 mediapipe/tasks/cc/components/utils/cosine_similarity.h create mode 100644 mediapipe/tasks/cc/components/utils/cosine_similarity_test.cc create mode 100644 mediapipe/tasks/cc/components/utils/source_or_node_output.h rename mediapipe/tasks/cc/{components => text}/tokenizers/BUILD (100%) rename mediapipe/tasks/cc/{components => text}/tokenizers/bert_tokenizer.cc (96%) rename mediapipe/tasks/cc/{components => text}/tokenizers/bert_tokenizer.h (92%) rename mediapipe/tasks/cc/{components => text}/tokenizers/bert_tokenizer_test.cc (97%) rename mediapipe/tasks/cc/{components => text}/tokenizers/regex_tokenizer.cc (96%) rename mediapipe/tasks/cc/{components => text}/tokenizers/regex_tokenizer.h (84%) rename mediapipe/tasks/cc/{components => text}/tokenizers/regex_tokenizer_test.cc (96%) rename mediapipe/tasks/cc/{components => text}/tokenizers/sentencepiece_tokenizer.h (85%) rename mediapipe/tasks/cc/{components => text}/tokenizers/sentencepiece_tokenizer_test.cc (94%) rename mediapipe/tasks/cc/{components => text}/tokenizers/tokenizer.h (84%) rename mediapipe/tasks/cc/{components => text}/tokenizers/tokenizer_utils.cc (95%) rename mediapipe/tasks/cc/{components => text}/tokenizers/tokenizer_utils.h (78%) rename mediapipe/tasks/cc/{components => text}/tokenizers/tokenizer_utils_test.cc (91%) create mode 100644 mediapipe/tasks/cc/vision/hand_detector/BUILD create mode 100644 mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc create mode 100644 mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc create mode 100644 mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.cc create mode 100644 mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h create mode 100644 mediapipe/tasks/cc/vision/hand_detector/proto/BUILD create mode 100644 mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.proto create mode 100644 mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result.proto delete mode 100644 mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/hand_landmarks_to_matrix_calculator_test.cc 
rename mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/{hand_landmarks_to_matrix_calculator.cc => landmarks_to_matrix_calculator.cc} (57%) create mode 100644 mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc create mode 100644 mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.proto rename mediapipe/tasks/cc/vision/{hand_landmark => hand_landmarker}/BUILD (85%) rename mediapipe/tasks/cc/vision/{hand_landmark/hand_landmark_detector_graph.cc => hand_landmarker/hand_landmarker_subgraph.cc} (63%) create mode 100644 mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph_test.cc create mode 100644 mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD create mode 100644 mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_options.proto rename mediapipe/tasks/cc/vision/{hand_landmark/hand_landmark_detector_options.proto => hand_landmarker/proto/hand_landmarker_subgraph_options.proto} (90%) delete mode 100644 mediapipe/tasks/cc/vision/image_classification/image_classifier.cc delete mode 100644 mediapipe/tasks/cc/vision/image_classification/image_classifier.h delete mode 100644 mediapipe/tasks/cc/vision/image_classification/image_classifier_test.cc rename mediapipe/tasks/cc/vision/{image_classification => image_classifier}/BUILD (72%) create mode 100644 mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc create mode 100644 mediapipe/tasks/cc/vision/image_classifier/image_classifier.h rename mediapipe/tasks/cc/vision/{image_classification => image_classifier}/image_classifier_graph.cc (66%) create mode 100644 mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc create mode 100644 mediapipe/tasks/cc/vision/image_classifier/proto/BUILD rename mediapipe/tasks/cc/vision/{image_classification => image_classifier/proto}/image_classifier_options.proto (86%) create mode 100644 mediapipe/tasks/cc/vision/image_embedder/BUILD create mode 100644 mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc create mode 100644 mediapipe/tasks/cc/vision/image_embedder/image_embedder.h create mode 100644 mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc create mode 100644 mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc create mode 100644 mediapipe/tasks/cc/vision/image_embedder/proto/BUILD create mode 100644 mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto create mode 100644 mediapipe/tasks/java/BUILD create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/AndroidManifest.xml create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/BUILD create mode 100644 mediapipe/tasks/javatests/BUILD create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/core/BUILD create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/AndroidManifest.xml create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/BUILD create mode 100644 mediapipe/tasks/testdata/vision/hand_detector_result_one_hand.pbtxt create mode 100644 mediapipe/tasks/testdata/vision/hand_detector_result_two_hands.pbtxt create mode 100644 third_party/pffft.BUILD 
diff --git a/WORKSPACE b/WORKSPACE index d3cc40fbe..146916c5c 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -157,6 +157,13 @@ http_archive( urls = ["https://github.com/google/multichannel-audio-tools/archive/master.zip"], ) +http_archive( + name = "pffft", + strip_prefix = "jpommier-pffft-7c3b5a7dc510", + urls = ["https://bitbucket.org/jpommier/pffft/get/7c3b5a7dc510.zip"], + build_file = "@//third_party:pffft.BUILD", +) + # sentencepiece http_archive( name = "com_google_sentencepiece", diff --git a/docs/solutions/pose.md b/docs/solutions/pose.md index 8c57c033e..905800228 100644 --- a/docs/solutions/pose.md +++ b/docs/solutions/pose.md @@ -217,7 +217,7 @@ A list of pose landmarks. Each landmark consists of the following: *Fig 5. Example of MediaPipe Pose real-world 3D coordinates.* | :-----------------------------------------------------------: | - | + | Another list of pose landmarks in world coordinates. Each landmark consists of the following: @@ -238,7 +238,7 @@ for usage details. *Fig 6. Example of MediaPipe Pose segmentation mask.* | :---------------------------------------------------: | - | + | ### Python Solution API diff --git a/docs/solutions/selfie_segmentation.md b/docs/solutions/selfie_segmentation.md index 2cb155fb3..d8b17487c 100644 --- a/docs/solutions/selfie_segmentation.md +++ b/docs/solutions/selfie_segmentation.md @@ -22,7 +22,7 @@ nav_order: 7 *Fig 1. Example of MediaPipe Selfie Segmentation.* | :------------------------------------------------: | - | + | MediaPipe Selfie Segmentation segments the prominent humans in the scene. It can run in real-time on both smartphones and laptops. The intended use cases include diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index b28a3573a..26ada44bf 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -1294,8 +1294,8 @@ cc_library( deps = [ ":get_vector_item_calculator_cc_proto", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:packet", "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", @@ -1319,6 +1319,32 @@ cc_test( ], ) +cc_library( + name = "vector_indices_calculator", + srcs = ["vector_indices_calculator.cc"], + hdrs = ["vector_indices_calculator.h"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +cc_test( + name = "vector_indices_calculator_test", + srcs = ["vector_indices_calculator_test.cc"], + deps = [ + ":vector_indices_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + ], +) + cc_library( name = "vector_size_calculator", srcs = ["vector_size_calculator.cc"], diff --git a/mediapipe/calculators/core/end_loop_calculator.cc b/mediapipe/calculators/core/end_loop_calculator.cc index d21bc03a4..b321f4275 100644 --- a/mediapipe/calculators/core/end_loop_calculator.cc +++ b/mediapipe/calculators/core/end_loop_calculator.cc @@ -40,6 +40,9 @@ REGISTER_CALCULATOR(EndLoopNormalizedLandmarkListVectorCalculator); typedef EndLoopCalculator> EndLoopBooleanCalculator; REGISTER_CALCULATOR(EndLoopBooleanCalculator); +typedef 
EndLoopCalculator> EndLoopFloatCalculator; +REGISTER_CALCULATOR(EndLoopFloatCalculator); + typedef EndLoopCalculator> EndLoopRenderDataCalculator; REGISTER_CALCULATOR(EndLoopRenderDataCalculator); diff --git a/mediapipe/calculators/core/get_vector_item_calculator.cc b/mediapipe/calculators/core/get_vector_item_calculator.cc index 56a2f3304..51fb46b98 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator.cc +++ b/mediapipe/calculators/core/get_vector_item_calculator.cc @@ -24,6 +24,10 @@ using GetLandmarkListVectorItemCalculator = GetVectorItemCalculator; REGISTER_CALCULATOR(GetLandmarkListVectorItemCalculator); +using GetNormalizedLandmarkListVectorItemCalculator = + GetVectorItemCalculator; +REGISTER_CALCULATOR(GetNormalizedLandmarkListVectorItemCalculator); + using GetClassificationListVectorItemCalculator = GetVectorItemCalculator; REGISTER_CALCULATOR(GetClassificationListVectorItemCalculator); diff --git a/mediapipe/calculators/core/get_vector_item_calculator.h b/mediapipe/calculators/core/get_vector_item_calculator.h index be89aa3a3..dc98ccfe7 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator.h +++ b/mediapipe/calculators/core/get_vector_item_calculator.h @@ -19,6 +19,7 @@ #include "mediapipe/calculators/core/get_vector_item_calculator.pb.h" #include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/ret_check.h" @@ -58,7 +59,7 @@ template class GetVectorItemCalculator : public Node { public: static constexpr Input> kIn{"VECTOR"}; - static constexpr Input::Optional kIdx{"INDEX"}; + static constexpr Input>::Optional kIdx{"INDEX"}; static constexpr Output kOut{"ITEM"}; MEDIAPIPE_NODE_CONTRACT(kIn, kIdx, kOut); @@ -80,7 +81,9 @@ class GetVectorItemCalculator : public Node { int idx = 0; if (kIdx(cc).IsConnected() && !kIdx(cc).IsEmpty()) { - idx = kIdx(cc).Get(); + idx = kIdx(cc).Visit( + [](uint64_t idx_uint64_t) { return static_cast(idx_uint64_t); }, + [](int idx_int) { return idx_int; }); } else if (options.has_item_index()) { idx = options.item_index(); } else { diff --git a/mediapipe/calculators/core/get_vector_item_calculator_test.cc b/mediapipe/calculators/core/get_vector_item_calculator_test.cc index f2f788382..c148aa9d1 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator_test.cc +++ b/mediapipe/calculators/core/get_vector_item_calculator_test.cc @@ -227,4 +227,15 @@ TEST(TestGetIntVectorItemCalculatorTest, IndexOptionsTwoTimestamps) { testing::ElementsAre(TimestampValue(1), TimestampValue(2))); } +TEST(TestGetIntVectorItemCalculatorTest, IndexUint64) { + CalculatorRunner runner = MakeRunnerWithStream(); + const std::vector inputs = {1, 2, 3}; + const uint64_t index = 1; + AddInputVector(runner, inputs, 1); + AddInputIndex(runner, index, 1); + MP_ASSERT_OK(runner.Run()); + const std::vector& outputs = runner.Outputs().Tag("ITEM").packets; + EXPECT_THAT(outputs, testing::ElementsAre(IntPacket(inputs[index]))); +} + } // namespace mediapipe diff --git a/mediapipe/calculators/core/vector_indices_calculator.cc b/mediapipe/calculators/core/vector_indices_calculator.cc new file mode 100644 index 000000000..56baa1376 --- /dev/null +++ b/mediapipe/calculators/core/vector_indices_calculator.cc @@ -0,0 +1,33 @@ +// Copyright 2022 The MediaPipe Authors. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mediapipe/calculators/core/vector_indices_calculator.h"
+
+#include "mediapipe/framework/formats/landmark.pb.h"
+
+namespace mediapipe {
+namespace api2 {
+
+using IntVectorIndicesCalculator = VectorIndicesCalculator<int>;
+REGISTER_CALCULATOR(IntVectorIndicesCalculator);
+
+using Uint64tVectorIndicesCalculator = VectorIndicesCalculator<uint64_t>;
+REGISTER_CALCULATOR(Uint64tVectorIndicesCalculator);
+
+using NormalizedLandmarkListVectorIndicesCalculator =
+    VectorIndicesCalculator<NormalizedLandmarkList>;
+REGISTER_CALCULATOR(NormalizedLandmarkListVectorIndicesCalculator);
+
+}  // namespace api2
+}  // namespace mediapipe
diff --git a/mediapipe/calculators/core/vector_indices_calculator.h b/mediapipe/calculators/core/vector_indices_calculator.h
new file mode 100644
index 000000000..98cdcf33b
--- /dev/null
+++ b/mediapipe/calculators/core/vector_indices_calculator.h
@@ -0,0 +1,65 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MEDIAPIPE_CALCULATORS_CORE_VECTOR_INDICES_CALCULATOR_H_
+#define MEDIAPIPE_CALCULATORS_CORE_VECTOR_INDICES_CALCULATOR_H_
+
+#include <numeric>
+
+#include "mediapipe/framework/api2/node.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/port/status.h"
+
+namespace mediapipe {
+namespace api2 {
+// Calculator that takes a vector and constructs an index range vector based
+// on the size of the input vector.
+//
+// Inputs:
+//   VECTOR - std::vector<T>
+//     Vector whose range of indices to return.
+//
+// Outputs:
+//   INDICES - std::vector<int>
+//     Indices vector of the input vector.
+//
+// Example config:
+// node {
+//   calculator: "{SpecificType}VectorIndicesCalculator"
+//   input_stream: "VECTOR:vector"
+//   output_stream: "INDICES:indices"
+// }
+//
+template <typename T>
+class VectorIndicesCalculator : public Node {
+ public:
+  static constexpr Input<std::vector<T>> kVector{"VECTOR"};
+  static constexpr Output<std::vector<int>> kRange{"INDICES"};
+
+  MEDIAPIPE_NODE_CONTRACT(kVector, kRange);
+
+  absl::Status Process(CalculatorContext* cc) final {
+    // Get the size of the input vector.
+ const int vector_size = kVector(cc).Get().size(); + std::vector out_idxs(vector_size); + std::iota(out_idxs.begin(), out_idxs.end(), 0); + kRange(cc).Send(out_idxs); + return absl::OkStatus(); + } +}; + +} // namespace api2 +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_CORE_VECTOR_INDICES_CALCULATOR_H_ diff --git a/mediapipe/calculators/core/vector_indices_calculator_test.cc b/mediapipe/calculators/core/vector_indices_calculator_test.cc new file mode 100644 index 000000000..ff54f1f4a --- /dev/null +++ b/mediapipe/calculators/core/vector_indices_calculator_test.cc @@ -0,0 +1,87 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/core/vector_indices_calculator.h" + +#include +#include +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +namespace { + +using ::testing::TestParamInfo; +using ::testing::TestWithParam; +using ::testing::Values; + +template +void AddInputVector(CalculatorRunner& runner, const std::vector& inputs, + int timestamp) { + runner.MutableInputs()->Tag("VECTOR").packets.push_back( + MakePacket>(inputs).At(Timestamp(timestamp))); +} + +template +struct TestParams { + const std::string test_name; + const std::vector inputs; + const int timestamp; + const std::vector expected_indices; +}; + +class IntVectorIndicesCalculatorTest + : public testing::TestWithParam> {}; + +TEST_P(IntVectorIndicesCalculatorTest, Succeeds) { + CalculatorRunner runner = CalculatorRunner(R"( + calculator: "IntVectorIndicesCalculator" + input_stream: "VECTOR:vector_stream" + output_stream: "INDICES:indices_stream" + )"); + const std::vector& inputs = GetParam().inputs; + std::vector expected_indices(inputs.size()); + AddInputVector(runner, inputs, GetParam().timestamp); + MP_ASSERT_OK(runner.Run()); + const std::vector& outputs = runner.Outputs().Tag("INDICES").packets; + EXPECT_EQ(1, outputs.size()); + EXPECT_THAT(outputs[0].Get>(), + testing::ElementsAreArray(GetParam().expected_indices)); +} + +INSTANTIATE_TEST_SUITE_P( + IntVectorIndicesCalculatorTest, IntVectorIndicesCalculatorTest, + Values(TestParams{ + /* test_name= */ "IntVectorIndices", + /* inputs= */ {1, 2, 3}, + /* timestamp= */ 1, + /* expected_indices= */ {0, 1, 2}, + }, + TestParams{ + /* test_name= */ "EmptyVector", + /* inputs= */ {}, + /* timestamp= */ 1, + /* expected_indices= */ {}, + }), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index c378df7d0..93f2dbd06 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -55,6 +55,14 @@ mediapipe_proto_library( cc_library( name = "audio_to_tensor_calculator", srcs = ["audio_to_tensor_calculator.cc"], + copts 
= select({ + # b/215212850 + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", + ], + "//conditions:default": [], + }), visibility = [ "//mediapipe/framework:mediapipe_internal", ], @@ -67,13 +75,16 @@ cc_library( "//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework/port:ret_check", "//mediapipe/util:time_series_util", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_audio_tools//audio/dsp:resampler_q", + "@com_google_audio_tools//audio/dsp:window_functions", "@org_tensorflow//tensorflow/lite/c:common", + "@pffft", ], alwayslink = 1, ) @@ -83,6 +94,7 @@ cc_test( srcs = ["audio_to_tensor_calculator_test.cc"], deps = [ ":audio_to_tensor_calculator", + ":audio_to_tensor_calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -97,6 +109,58 @@ cc_test( ], ) +mediapipe_proto_library( + name = "feedback_tensors_calculator_proto", + srcs = ["feedback_tensors_calculator.proto"], + visibility = [ + "//mediapipe/framework:mediapipe_internal", + ], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "feedback_tensors_calculator", + srcs = ["feedback_tensors_calculator.cc"], + copts = select({ + # b/215212850 + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", + ], + "//conditions:default": [], + }), + visibility = [ + "//mediapipe/framework:mediapipe_internal", + ], + deps = [ + ":feedback_tensors_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:tensor", + "@com_google_absl//absl/status", + ], + alwayslink = 1, +) + +cc_test( + name = "feedback_tensors_calculator_test", + srcs = ["feedback_tensors_calculator_test.cc"], + deps = [ + ":feedback_tensors_calculator", + ":feedback_tensors_calculator_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "@org_tensorflow//tensorflow/lite/c:common", + ], +) + mediapipe_proto_library( name = "inference_calculator_proto", srcs = ["inference_calculator.proto"], @@ -346,6 +410,10 @@ cc_library( }), ) +# This target provides the InferenceCalculator and a default set of implementations tailored for the +# current build platforms. More implementations can be added as separate dependencies to a client; +# for clients that want a narrower set of implementations than the default should see the comment on +# inference_calculator_interface. cc_library( name = "inference_calculator", visibility = ["//visibility:public"], diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc b/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc index 474d6cf17..59c129191 100644 --- a/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc @@ -12,9 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include - #include +#include #include #include #include @@ -26,6 +25,7 @@ #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "audio/dsp/resampler_q.h" +#include "audio/dsp/window_functions.h" #include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h" #include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/api2/packet.h" @@ -34,19 +34,60 @@ #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/util/time_series_util.h" +#include "pffft.h" namespace mediapipe { namespace api2 { +namespace { + +using Options = ::mediapipe::AudioToTensorCalculatorOptions; +using FlushMode = Options::FlushMode; + +std::vector HannWindow(int window_size, bool sqrt_hann) { + std::vector hann_window(window_size); + audio_dsp::HannWindow().GetPeriodicSamples(window_size, &hann_window); + if (sqrt_hann) { + absl::c_transform(hann_window, hann_window.begin(), + [](double x) { return std::sqrt(x); }); + } + return hann_window; +} + +// PFFFT only supports transforms for inputs of length N of the form +// N = (2^a)*(3^b)*(5^c) where b >=0 and c >= 0 and a >= 5 for the real FFT. +bool IsValidFftSize(int size) { + if (size <= 0) { + return false; + } + constexpr int kFactors[] = {2, 3, 5}; + int factorization[] = {0, 0, 0}; + int n = static_cast(size); + for (int i = 0; i < 3; ++i) { + while (n % kFactors[i] == 0) { + n = n / kFactors[i]; + ++factorization[i]; + } + } + return factorization[0] >= 5 && n == 1; +} + +} // namespace // Converts audio buffers into tensors, possibly with resampling, buffering // and framing, according to specified inputs and options. All input audio // buffers will be first resampled from the input sample rate to the target // sample rate if they are not equal. The resampled audio data (with the // buffered samples from the previous runs in the streaming mode) will be broken -// into fixed-sized, possibly overlapping frames. Finally, all frames will be -// converted to and outputted as MediaPipe Tensors. The last output tensor will -// be zero-padding if the remaining samples are insufficient. +// into fixed-sized, possibly overlapping frames. If the calculator is not asked +// to perform fft (the fft_size is not set in the calculator options), all +// frames will be converted to and outputted as MediaPipe Tensors. The last +// output tensor will be zero-padding if the remaining samples are insufficient. +// Otherwise, when the fft_size is set and valid, the calculator will perform +// fft on the fixed-sized audio frames, the complex DFT results will be +// converted to and outputted as 2D MediaPipe float Tensors where the first +// rows are the DFT real parts and the second rows are the DFT imagery parts. // // This calculator assumes that the input timestamps refer to the first // sample in each Matrix. The output timestamps follow this same convention. @@ -86,11 +127,15 @@ namespace api2 { // Outputs: // TENSORS - std::vector // Vector containing a single Tensor that represents a fix-sized audio -// frame. +// frame or the complex DFT results. // TIMESTAMPS - std::vector @Optional // Vector containing the output timestamps emitted by the current Process() // invocation. In the non-streaming mode, the vector contains all of the // output timestamps for an input audio buffer. +// DC_AND_NYQUIST - std::pair @Optional. 
+// A pair of dc component and nyquest component. Only can be connected when +// the calculator performs fft (the fft_size is set in the calculator +// options). // // Example: // node { @@ -116,12 +161,14 @@ class AudioToTensorCalculator : public Node { // such as sample rate. static constexpr Input::Optional kAudioSampleRateIn{"SAMPLE_RATE"}; static constexpr Output> kTensorsOut{"TENSORS"}; + static constexpr Output>::Optional kDcAndNyquistOut{ + "DC_AND_NYQUIST"}; // A vector of the output timestamps emitted by the current Process() // invocation. The packet timestamp is the last emitted timestamp. static constexpr Output>::Optional kTimestampsOut{ "TIMESTAMPS"}; MEDIAPIPE_NODE_CONTRACT(kAudioIn, kAudioSampleRateIn, kTensorsOut, - kTimestampsOut); + kDcAndNyquistOut, kTimestampsOut); static absl::Status UpdateContract(CalculatorContract* cc); absl::Status Open(CalculatorContext* cc); @@ -138,6 +185,9 @@ class AudioToTensorCalculator : public Node { int frame_step_; bool stream_mode_; bool check_inconsistent_timestamps_; + int padding_samples_before_; + int padding_samples_after_; + FlushMode flush_mode_; Timestamp initial_timestamp_ = Timestamp::Unstarted(); int64 cumulative_input_samples_ = 0; Timestamp next_output_timestamp_ = Timestamp::Unstarted(); @@ -151,22 +201,33 @@ class AudioToTensorCalculator : public Node { Matrix sample_buffer_; int processed_buffer_cols_ = 0; + // The internal state of the FFT library. + PFFFT_Setup* fft_state_ = nullptr; + int fft_size_ = 0; + std::vector fft_window_; + std::vector> fft_input_buffer_; + // pffft requires memory to work with to avoid using the stack. + std::vector> fft_workplace_; + std::vector> fft_output_; + absl::Status ProcessStreamingData(CalculatorContext* cc, const Matrix& input); absl::Status ProcessNonStreamingData(CalculatorContext* cc, const Matrix& input); absl::Status SetupStreamingResampler(double input_sample_rate_); void AppendToSampleBuffer(Matrix buffer_to_append); + void AppendZerosToSampleBuffer(int num_samples); absl::StatusOr> ConvertToTensor( - const Matrix& frame_to_convert); - absl::Status OutputTensors(const Matrix& buffer, bool should_flush, + const Matrix& block, std::vector tensor_dims); + absl::Status OutputTensor(const Matrix& block, Timestamp timestamp, + CalculatorContext* cc); + absl::Status ProcessBuffer(const Matrix& buffer, bool should_flush, CalculatorContext* cc); }; absl::Status AudioToTensorCalculator::UpdateContract(CalculatorContract* cc) { - const auto& options = - cc->Options(); + const auto& options = cc->Options(); if (!options.has_num_channels() || !options.has_num_samples() || !options.has_target_sample_rate()) { return absl::InvalidArgumentError( @@ -174,13 +235,21 @@ absl::Status AudioToTensorCalculator::UpdateContract(CalculatorContract* cc) { "`num_channels`, `num_samples`, and `target_sample_rate`."); } if (options.stream_mode()) { - // Explicitly disables tiemstamp offset to disallow the timestamp bound + // Explicitly disables timestamp offset to disallow the timestamp bound // from the input streams to be propagated to the output streams. // In the streaming mode, the output timestamp bound is based on // next_output_timestamp_, which can be smaller than the current input // timestamps. 
cc->SetTimestampOffset(TimestampDiff::Unset()); } + if (options.padding_samples_before() < 0 || + options.padding_samples_after() < 0) { + return absl::InvalidArgumentError("Negative zero padding unsupported"); + } + if (options.flush_mode() != Options::ENTIRE_TAIL_AT_TIMESTAMP_MAX && + options.flush_mode() != Options::PROCEED_AS_USUAL) { + return absl::InvalidArgumentError("Unsupported flush mode"); + } return absl::OkStatus(); } @@ -202,6 +271,9 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) { check_inconsistent_timestamps_ = options.check_inconsistent_timestamps(); sample_buffer_.resize(num_channels_, Eigen::NoChange); } + padding_samples_before_ = options.padding_samples_before(); + padding_samples_after_ = options.padding_samples_after(); + flush_mode_ = options.flush_mode(); RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^ !kAudioIn(cc).Header().IsEmpty()) @@ -217,6 +289,25 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) { source_sample_rate_ = input_header.sample_rate(); } } + AppendZerosToSampleBuffer(padding_samples_before_); + if (options.has_fft_size()) { + RET_CHECK(IsValidFftSize(options.fft_size())) + << "FFT size must be of the form fft_size = (2^a)*(3^b)*(5^c) where b " + ">=0 and c >= 0 and a >= 5, the requested fft size is " + << options.fft_size(); + RET_CHECK_EQ(1, num_channels_) + << "Currently only support applying FFT on mono channel."; + fft_size_ = options.fft_size(); + fft_state_ = pffft_new_setup(fft_size_, PFFFT_REAL); + fft_window_ = HannWindow(fft_size_, /* sqrt_hann = */ false); + fft_input_buffer_.resize(fft_size_); + fft_workplace_.resize(fft_size_); + fft_output_.resize(fft_size_); + } else { + RET_CHECK(!kDcAndNyquistOut(cc).IsConnected()) + << "The DC_AND_NYQUIST output stream can only be connected when the " + "calculator outputs fft tensors"; + } return absl::OkStatus(); } @@ -262,7 +353,12 @@ absl::Status AudioToTensorCalculator::Close(CalculatorContext* cc) { resampler_->Flush(&resampled_buffer); AppendToSampleBuffer(std::move(resampled_buffer)); } - return OutputTensors(sample_buffer_, /*should_flush=*/true, cc); + AppendZerosToSampleBuffer(padding_samples_after_); + MP_RETURN_IF_ERROR(ProcessBuffer(sample_buffer_, /*should_flush=*/true, cc)); + if (fft_state_) { + pffft_destroy_setup(fft_state_); + } + return absl::OkStatus(); } absl::Status AudioToTensorCalculator::ProcessStreamingData( @@ -303,7 +399,7 @@ absl::Status AudioToTensorCalculator::ProcessStreamingData( } } - MP_RETURN_IF_ERROR(OutputTensors(sample_buffer_, /*should_flush=*/false, cc)); + MP_RETURN_IF_ERROR(ProcessBuffer(sample_buffer_, /*should_flush=*/false, cc)); // Removes the processed samples from the global sample buffer. 
sample_buffer_ = Matrix(sample_buffer_.rightCols(sample_buffer_.cols() - processed_buffer_cols_ - 1)); @@ -323,9 +419,9 @@ absl::Status AudioToTensorCalculator::ProcessNonStreamingData( input_frame); Eigen::Map matrix_mapping(resampled.data(), num_channels_, resampled.size() / num_channels_); - return OutputTensors(matrix_mapping, /*should_flush=*/true, cc); + return ProcessBuffer(matrix_mapping, /*should_flush=*/true, cc); } - return OutputTensors(input_frame, /*should_flush=*/true, cc); + return ProcessBuffer(input_frame, /*should_flush=*/true, cc); } absl::Status AudioToTensorCalculator::SetupStreamingResampler( @@ -344,6 +440,16 @@ absl::Status AudioToTensorCalculator::SetupStreamingResampler( return absl::OkStatus(); } +void AudioToTensorCalculator::AppendZerosToSampleBuffer(int num_samples) { + CHECK_GE(num_samples, 0); // Ensured by `UpdateContract`. + if (num_samples == 0) { + return; + } + sample_buffer_.conservativeResize(Eigen::NoChange, + sample_buffer_.cols() + num_samples); + sample_buffer_.rightCols(num_samples).setZero(); +} + void AudioToTensorCalculator::AppendToSampleBuffer(Matrix buffer_to_append) { sample_buffer_.conservativeResize( Eigen::NoChange, sample_buffer_.cols() + buffer_to_append.cols()); @@ -351,49 +457,89 @@ void AudioToTensorCalculator::AppendToSampleBuffer(Matrix buffer_to_append) { } absl::StatusOr> AudioToTensorCalculator::ConvertToTensor( - const Matrix& frame_to_convert) { - Tensor tensor(Tensor::ElementType::kFloat32, - Tensor::Shape({num_channels_, num_samples_})); + const Matrix& block, std::vector tensor_dims) { + Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape(tensor_dims)); auto buffer_view = tensor.GetCpuWriteView(); - if (frame_to_convert.size() < num_channels_ * num_samples_) { + int total_size = 1; + for (int dim : tensor_dims) { + total_size *= dim; + } + if (block.size() < total_size) { std::memset(buffer_view.buffer(), 0, tensor.bytes()); } - std::memcpy(buffer_view.buffer(), frame_to_convert.data(), - frame_to_convert.size() * sizeof(float)); + std::memcpy(buffer_view.buffer(), block.data(), + block.size() * sizeof(float)); std::vector tensor_vector; tensor_vector.push_back(std::move(tensor)); return tensor_vector; } -absl::Status AudioToTensorCalculator::OutputTensors(const Matrix& buffer, +absl::Status AudioToTensorCalculator::OutputTensor(const Matrix& block, + Timestamp timestamp, + CalculatorContext* cc) { + std::vector output_tensor; + if (fft_state_) { + Eigen::VectorXf time_series_data = + Eigen::VectorXf::Map(block.data(), block.size()); + // Window on input audio prior to FFT. + std::transform(time_series_data.begin(), time_series_data.end(), + fft_window_.begin(), fft_input_buffer_.begin(), + std::multiplies()); + pffft_transform_ordered(fft_state_, fft_input_buffer_.data(), + fft_output_.data(), fft_workplace_.data(), + PFFFT_FORWARD); + if (kDcAndNyquistOut(cc).IsConnected()) { + kDcAndNyquistOut(cc).Send(std::make_pair(fft_output_[0], fft_output_[1]), + timestamp); + } + Matrix fft_output_matrix = + Eigen::Map(fft_output_.data() + 2, 1, fft_size_ - 2); + fft_output_matrix.conservativeResize(Eigen::NoChange, fft_size_); + // The last two elements are the DFT Nyquist values. 
+ fft_output_matrix(fft_size_ - 2) = fft_output_[1]; // Nyquist real part + fft_output_matrix(fft_size_ - 1) = 0.0f; // Nyquist imagery part + ASSIGN_OR_RETURN(output_tensor, + ConvertToTensor(fft_output_matrix, {2, fft_size_ / 2})); + } else { + ASSIGN_OR_RETURN(output_tensor, + ConvertToTensor(block, {num_channels_, num_samples_})); + } + kTensorsOut(cc).Send(std::move(output_tensor), timestamp); + return absl::OkStatus(); +} + +absl::Status AudioToTensorCalculator::ProcessBuffer(const Matrix& buffer, bool should_flush, CalculatorContext* cc) { + const bool should_flush_at_timestamp_max = + stream_mode_ && should_flush && + flush_mode_ == Options::ENTIRE_TAIL_AT_TIMESTAMP_MAX; int next_frame_first_col = 0; std::vector timestamps; - while ((!stream_mode_ || !should_flush) && - next_frame_first_col + num_samples_ <= buffer.cols()) { - ASSIGN_OR_RETURN(auto output_tensor, ConvertToTensor(buffer.block( - 0, next_frame_first_col, - num_channels_, num_samples_))); - kTensorsOut(cc).Send(std::move(output_tensor), next_output_timestamp_); - timestamps.push_back(next_output_timestamp_); - next_output_timestamp_ += round(frame_step_ / target_sample_rate_ * - Timestamp::kTimestampUnitsPerSecond); - next_frame_first_col += frame_step_; + if (!should_flush_at_timestamp_max) { + while (next_frame_first_col + num_samples_ <= buffer.cols()) { + MP_RETURN_IF_ERROR(OutputTensor( + buffer.block(0, next_frame_first_col, num_channels_, num_samples_), + next_output_timestamp_, cc)); + timestamps.push_back(next_output_timestamp_); + next_output_timestamp_ += round(frame_step_ / target_sample_rate_ * + Timestamp::kTimestampUnitsPerSecond); + next_frame_first_col += frame_step_; + } } if (should_flush && next_frame_first_col < buffer.cols()) { - ASSIGN_OR_RETURN(auto output_tensor, - ConvertToTensor(buffer.block( - 0, next_frame_first_col, num_channels_, - std::min(num_samples_, - (int)buffer.cols() - next_frame_first_col)))); // In the streaming mode, the flush happens in Close() and a packet at // Timestamp::Max() will be emitted. In the non-streaming mode, each // Process() invocation will process the entire buffer completely. - Timestamp timestamp = - stream_mode_ ? Timestamp::Max() : next_output_timestamp_; + Timestamp timestamp = should_flush_at_timestamp_max + ? Timestamp::Max() + : next_output_timestamp_; + MP_RETURN_IF_ERROR(OutputTensor( + buffer.block( + 0, next_frame_first_col, num_channels_, + std::min(num_samples_, (int)buffer.cols() - next_frame_first_col)), + timestamp, cc)); timestamps.push_back(timestamp); - kTensorsOut(cc).Send(std::move(output_tensor), timestamp); } if (kTimestampsOut(cc).IsConnected()) { Timestamp timestamp = timestamps.back(); diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto b/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto index 2090fbb81..cff6b2878 100644 --- a/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto +++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto @@ -44,4 +44,28 @@ message AudioToTensorCalculatorOptions { // Set to false to disable checks for jitter in timestamp values. Useful with // live audio input. optional bool check_inconsistent_timestamps = 6 [default = true]; + + // Size of the fft in number of bins. If set, the calculator outputs fft + // tensors. + optional int64 fft_size = 7; + + // The amount of padding samples to add before the audio after resampling. + // Note that the timestamps shift. Currently, only zero padding is supported. 
+ optional int64 padding_samples_before = 8; + + // The amount of padding samples to add after the audio after resampling. + // Currently, only zero padding is supported. + optional int64 padding_samples_after = 9; + + // Determines the "flushing" behavior in stream mode. + enum FlushMode { + // Unspecified (causes an error). Won't be used because of the default. + NONE = 0; + // Emit a packet with the entire remainder at `Timestamp::Max`. + ENTIRE_TAIL_AT_TIMESTAMP_MAX = 1; + // Continue emitting framed packets with relevant timestamps. + PROCEED_AS_USUAL = 2; + } + + optional FlushMode flush_mode = 10 [default = ENTIRE_TAIL_AT_TIMESTAMP_MAX]; } diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator_test.cc b/mediapipe/calculators/tensor/audio_to_tensor_calculator_test.cc index c2062134d..60fcfcd82 100644 --- a/mediapipe/calculators/tensor/audio_to_tensor_calculator_test.cc +++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator_test.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include #include #include #include "absl/strings/substitute.h" #include "audio/dsp/resampler_q.h" +#include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h" #include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" @@ -32,6 +32,14 @@ namespace mediapipe { namespace { +using ::testing::Not; +using Options = ::mediapipe::AudioToTensorCalculatorOptions; +using FlushMode = Options::FlushMode; + +int DivideRoundedUp(int dividend, int divisor) { + return (dividend + divisor - 1) / divisor; +} + std::unique_ptr CreateTestMatrix(int num_channels, int num_samples, int timestamp) { auto matrix = std::make_unique(num_channels, num_samples); @@ -292,16 +300,17 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test { num_iterations_ = num_iterations; } - int GetExpectedNumOfSamples() { - Matrix* expected_matrix = - resampled_buffer_ ? resampled_buffer_.get() : sample_buffer_.get(); - return expected_matrix->cols(); - } + int GetExpectedNumOfSamples() { return output_sample_buffer_->cols(); } void Run(int num_samples, int num_overlapping_samples, - double resampling_factor) { + double resampling_factor, int padding_before = 0, + int padding_after = 0, bool expect_init_error = false) { double input_sample_rate = 10000; double target_sample_rate = input_sample_rate * resampling_factor; + FlushMode flush_mode = (padding_before != 0 || padding_after != 0) + ? Options::PROCEED_AS_USUAL + : Options::ENTIRE_TAIL_AT_TIMESTAMP_MAX; + auto graph_config = ParseTextProtoOrDie( absl::Substitute(R"( input_stream: "audio" @@ -319,16 +328,25 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test { num_overlapping_samples: $1 target_sample_rate: $2 stream_mode:true + padding_samples_before: $3 + padding_samples_after: $4 + flush_mode: $5 } } } )", /*$0=*/num_samples, /*$1=*/num_overlapping_samples, - /*$2=*/target_sample_rate)); + /*$2=*/target_sample_rate, /*$3=*/padding_before, + /*$4=*/padding_after, /*$5=*/flush_mode)); tool::AddVectorSink("tensors", &graph_config, &tensors_packets_); // Run the graph. 
- MP_ASSERT_OK(graph_.Initialize(graph_config)); + const absl::Status init_status = graph_.Initialize(graph_config); + if (expect_init_error) { + EXPECT_THAT(init_status, Not(IsOk())); + return; + } + MP_ASSERT_OK(init_status); MP_ASSERT_OK(graph_.StartRun({})); for (int i = 0; i < num_iterations_; ++i) { Timestamp input_timestamp(Timestamp::kTimestampUnitsPerSecond * i); @@ -345,8 +363,18 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test { } MP_ASSERT_OK(graph_.CloseAllInputStreams()); MP_ASSERT_OK(graph_.WaitUntilIdle()); - if (resampling_factor != 1) { - resampled_buffer_ = ResampleBuffer(*sample_buffer_, resampling_factor); + if (resampling_factor == 1) { + output_sample_buffer_ = std::make_unique(*sample_buffer_); + } else { + output_sample_buffer_ = + ResampleBuffer(*sample_buffer_, resampling_factor); + } + if (padding_before != 0 || padding_after != 0) { + Matrix padded = Matrix::Zero( + 2, padding_before + output_sample_buffer_->cols() + padding_after); + padded.block(0, padding_before, 2, output_sample_buffer_->cols()) = + *output_sample_buffer_; + output_sample_buffer_->swap(padded); } } @@ -372,15 +400,13 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test { auto buffer = output_tensor.GetCpuReadView().buffer(); int num_values = output_tensor.shape().num_elements(); std::vector output_floats(buffer, buffer + num_values); - Matrix* expected_matrix = - resampled_buffer_ ? resampled_buffer_.get() : sample_buffer_.get(); for (int i = 0; i < num_values; ++i) { - if (i + sample_offset >= expected_matrix->size()) { + if (i + sample_offset >= output_sample_buffer_->size()) { EXPECT_FLOAT_EQ(output_floats[i], 0); } else { EXPECT_NEAR(output_floats[i], - expected_matrix->coeff((i + sample_offset) % 2, - (i + sample_offset) / 2), + output_sample_buffer_->coeff((i + sample_offset) % 2, + (i + sample_offset) / 2), 0.001) << "i=" << i << ", sample_offset=" << sample_offset << ", packet index=" << index; @@ -391,7 +417,8 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test { // Fully close graph at end, otherwise calculator+tensors are destroyed // after calling WaitUntilDone(). 
- void CloseGraph() { MP_EXPECT_OK(graph_.WaitUntilDone()); } + absl::Status TryCloseGraph() { return graph_.WaitUntilDone(); } + void CloseGraph() { MP_EXPECT_OK(TryCloseGraph()); } private: int input_buffer_num_samples_ = 10; @@ -399,7 +426,7 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test { CalculatorGraph graph_; std::vector tensors_packets_; std::unique_ptr sample_buffer_; - std::unique_ptr resampled_buffer_; + std::unique_ptr output_sample_buffer_; }; TEST_F(AudioToTensorCalculatorStreamingModeTest, @@ -408,7 +435,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, /*resampling_factor=*/1.0f); CheckTensorsOutputPackets( /*sample_offset=*/10, - /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 5), + /*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 5), /*timestamp_interval=*/500, /*output_last_at_close=*/false); CloseGraph(); @@ -419,7 +446,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, OutputRemainingInCloseMethod) { /*resampling_factor=*/1.0f); CheckTensorsOutputPackets( /*sample_offset=*/12, - /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 6), + /*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 6), /*timestamp_interval=*/600, /*output_last_at_close=*/true); CloseGraph(); @@ -431,7 +458,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, OutputOverlappingFp32Tensors) { /*resampling_factor=*/1.0f); CheckTensorsOutputPackets( /*sample_offset=*/16, - /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 8), + /*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 8), /*timestamp_interval=*/800, /*output_last_at_close=*/true); CloseGraph(); @@ -443,7 +470,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, Downsampling) { /*resampling_factor=*/0.5f); CheckTensorsOutputPackets( /*sample_offset=*/512, - /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 256), + /*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 256), /*timestamp_interval=*/51200, /*output_last_at_close=*/true); CloseGraph(); @@ -455,7 +482,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, DownsamplingWithOverlapping) { /*resampling_factor=*/0.5f); CheckTensorsOutputPackets( /*sample_offset=*/384, - /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 192), + /*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 192), /*timestamp_interval=*/38400, /*output_last_at_close=*/true); CloseGraph(); @@ -467,7 +494,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, Upsampling) { /*resampling_factor=*/2.0f); CheckTensorsOutputPackets( /*sample_offset=*/512, - /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 256), + /*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 256), /*timestamp_interval=*/12800, /*output_last_at_close=*/true); CloseGraph(); @@ -479,12 +506,33 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, UpsamplingWithOverlapping) { /*resampling_factor=*/2.0f); CheckTensorsOutputPackets( /*sample_offset=*/384, - /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 192), + /*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 192), /*timestamp_interval=*/9600, /*output_last_at_close=*/true); CloseGraph(); } +TEST_F(AudioToTensorCalculatorStreamingModeTest, + UpsamplingWithOverlappingAndPadding) { + SetInputBufferNumSamplesPerChannel(1024); + Run(/*num_samples=*/256, /*num_overlapping_samples=*/64, + /*resampling_factor=*/2.0f, /*padding_before=*/13, /*padding_after=*/999); + CheckTensorsOutputPackets( + /*sample_offset=*/384, + 
/*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 192), + /*timestamp_interval=*/9600, + /*output_last_at_close=*/false); + CloseGraph(); +} + +TEST_F(AudioToTensorCalculatorStreamingModeTest, NegativePaddingUnsupported) { + SetInputBufferNumSamplesPerChannel(1024); + Run(/*num_samples=*/256, /*num_overlapping_samples=*/64, + /*resampling_factor=*/2.0f, /*padding_before=*/13, /*padding_after=*/-3, + /*expect_init_error=*/true); + EXPECT_THAT(TryCloseGraph(), Not(IsOk())); +} + TEST_F(AudioToTensorCalculatorStreamingModeTest, OnlyOutputInCloseIfNoSufficientSamples) { SetNumIterations(1); @@ -498,5 +546,122 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, CloseGraph(); } +class AudioToTensorCalculatorFftTest : public ::testing::Test { + protected: + // Creates an audio matrix containing a single sample of 1.0 at a specified + // offset. + std::unique_ptr CreateImpulseSignalData(int64 num_samples, + int impulse_offset_idx) { + Matrix impulse = Matrix::Zero(1, num_samples); + impulse(0, impulse_offset_idx) = 1.0; + return std::make_unique(std::move(impulse)); + } + + void ConfigGraph(int num_channels, int num_samples, + int num_overlapping_samples, double sample_rate, + int fft_size) { + graph_config_ = ParseTextProtoOrDie( + absl::Substitute(R"( + input_stream: "audio" + input_stream: "sample_rate" + output_stream: "tensors" + output_stream: "dc_and_nyquist" + node { + calculator: "AudioToTensorCalculator" + input_stream: "AUDIO:audio" + input_stream: "SAMPLE_RATE:sample_rate" + output_stream: "TENSORS:tensors" + output_stream: "DC_AND_NYQUIST:dc_and_nyquist" + options { + [mediapipe.AudioToTensorCalculatorOptions.ext] { + num_channels: $0 + num_samples: $1 + num_overlapping_samples: $2 + target_sample_rate: $3 + fft_size: $4 + } + } + } + )", + /*$0=*/num_channels, + /*$1=*/num_samples, + /*$2=*/num_overlapping_samples, + /*$3=*/sample_rate, /*$4=*/fft_size)); + std::vector tensors_packets; + tool::AddVectorSink("tensors", &graph_config_, &tensors_packets_); + std::vector dc_and_nyquist_packets; + tool::AddVectorSink("dc_and_nyquist", &graph_config_, + &dc_and_nyquist_packets_); + } + + void RunGraph(std::unique_ptr input_data, double sample_rate) { + MP_ASSERT_OK(graph_.Initialize(graph_config_)); + MP_ASSERT_OK(graph_.StartRun({})); + MP_ASSERT_OK(graph_.AddPacketToInputStream( + "sample_rate", MakePacket(sample_rate).At(Timestamp(0)))); + MP_ASSERT_OK(graph_.AddPacketToInputStream( + "audio", MakePacket(*input_data).At(Timestamp(0)))); + MP_ASSERT_OK(graph_.CloseAllInputStreams()); + MP_ASSERT_OK(graph_.WaitUntilIdle()); + ASSERT_EQ(tensors_packets_.size(), dc_and_nyquist_packets_.size()); + } + + // Fully close graph at end, otherwise calculator+tensors are destroyed + // after calling WaitUntilDone(). 
+ void CloseGraph() { MP_EXPECT_OK(graph_.WaitUntilDone()); } + + std::vector tensors_packets_; + std::vector dc_and_nyquist_packets_; + CalculatorGraphConfig graph_config_; + CalculatorGraph graph_; +}; + +TEST_F(AudioToTensorCalculatorFftTest, TestInvalidFftSize) { + ConfigGraph(1, 320, 160, 16000, 103); + MP_ASSERT_OK(graph_.Initialize(graph_config_)); + MP_ASSERT_OK(graph_.StartRun({})); + auto status = graph_.WaitUntilIdle(); + EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_THAT(status.message(), + ::testing::HasSubstr("FFT size must be of the form")); +} + +TEST_F(AudioToTensorCalculatorFftTest, TestInvalidNumChannels) { + ConfigGraph(3, 320, 160, 16000, 256); + MP_ASSERT_OK(graph_.Initialize(graph_config_)); + MP_ASSERT_OK(graph_.StartRun({})); + auto status = graph_.WaitUntilIdle(); + EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_THAT( + status.message(), + ::testing::HasSubstr("only support applying FFT on mono channel")); +} + +TEST_F(AudioToTensorCalculatorFftTest, TestImpulseSignal) { + constexpr double sample_rate = 16000; + ConfigGraph(1, 320, 160, sample_rate, 320); + RunGraph(CreateImpulseSignalData(320, 160), sample_rate); + for (int i = 0; i < tensors_packets_.size(); ++i) { + const auto& tensors = tensors_packets_[i].Get>(); + ASSERT_EQ(1, tensors.size()); + const Tensor& output_tensor = + tensors_packets_[0].Get>()[0]; + auto* buffer = output_tensor.GetCpuReadView().buffer(); + int num_values = output_tensor.shape().num_elements(); + const std::vector output_floats(buffer, buffer + num_values); + // Impulse signal should have (approximately) const power across all + // frequency bins. + const auto& pair = + dc_and_nyquist_packets_[i].Get>(); + EXPECT_FLOAT_EQ(pair.first, 1.0f); + EXPECT_FLOAT_EQ(pair.second, 1.0f); + for (int j = 0; j < num_values / 2; ++j) { + std::complex cf(output_floats[j * 2], output_floats[j * 2 + 1]); + EXPECT_FLOAT_EQ(std::norm(cf), 1.0f); + } + } + CloseGraph(); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/feedback_tensors_calculator.cc b/mediapipe/calculators/tensor/feedback_tensors_calculator.cc new file mode 100644 index 000000000..5fa171ef3 --- /dev/null +++ b/mediapipe/calculators/tensor/feedback_tensors_calculator.cc @@ -0,0 +1,165 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "absl/status/status.h" +#include "mediapipe/calculators/tensor/feedback_tensors_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/tensor.h" + +namespace mediapipe { +namespace api2 { + +namespace { +constexpr char kInputTensorsTag[] = "INPUT_TENSORS"; +constexpr char kFeedbackTensorsTag[] = "FEEDBACK_TENSORS"; +constexpr char kOutputTensorsTag[] = "TENSORS"; + +using Tensors = std::vector; +} // namespace + +// FeedbackTensorsCalculator groups the input and the feedback (typically +// recurrent neural network cell state output tensors from the previous run) +// tensor vectors as the input tensor vector for the next recurrent model cell +// inference. On the first step, the feedback tensor is filled with zeros to +// jumpstart the loop. +class FeedbackTensorsCalculator : public Node { + public: + static constexpr Input kFeedbackTensorsIn{kFeedbackTensorsTag}; + static constexpr Input kInputTensorsIn{kInputTensorsTag}; + static constexpr Output kTensorsOut{kOutputTensorsTag}; + + MEDIAPIPE_NODE_CONTRACT(kFeedbackTensorsIn, kInputTensorsIn, kTensorsOut); + + static absl::Status GetContract(CalculatorContract* cc) { + cc->SetProcessTimestampBounds(true); + return absl::OkStatus(); + } + + absl::Status Open(CalculatorContext* cc) override { + const auto& options = + cc->Options(); + + const auto& shape_dims = options.feedback_tensor_shape().dims(); + feedback_tensor_shape_.dims.assign(shape_dims.begin(), shape_dims.end()); + feedback_tensor_size_ = feedback_tensor_shape_.num_elements(); + + num_feedback_tensors_ = options.num_feedback_tensors(); + + feedback_tensors_location_ = options.location(); + + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) override { + if (feedback_tensors_location_ == + mediapipe::FeedbackTensorsCalculatorOptions::NONE) { + kTensorsOut(cc).Send(kInputTensorsIn(cc).packet().As()); + return absl::OkStatus(); + } + + std::vector outputs; + switch (feedback_tensors_location_) { + case mediapipe::FeedbackTensorsCalculatorOptions::PREPENDED: + MP_RETURN_IF_ERROR(AddFeedbackTensors(cc, outputs)); + MP_RETURN_IF_ERROR(AddInputTensors(cc, outputs)); + break; + case mediapipe::FeedbackTensorsCalculatorOptions::APPENDED: + MP_RETURN_IF_ERROR(AddInputTensors(cc, outputs)); + MP_RETURN_IF_ERROR(AddFeedbackTensors(cc, outputs)); + break; + default: + return absl::InvalidArgumentError( + "Unsupported feedback tensors location"); + } + kTensorsOut(cc).Send(std::move(outputs)); + return absl::OkStatus(); + } + + private: + absl::Status AddInputTensors(CalculatorContext* cc, + std::vector& outputs) { + absl::StatusOr>> input_tensors = + cc->Inputs() + .Tag(kInputTensorsTag) + .Value() + .Consume>(); + if (!input_tensors.ok()) { + return absl::InternalError("The input tensors packet is not consumable"); + } + RET_CHECK(*input_tensors); + std::vector& inputs = **input_tensors; + outputs.insert(outputs.end(), std::make_move_iterator(inputs.begin()), + std::make_move_iterator(inputs.end())); + return absl::OkStatus(); + } + + absl::Status AddFeedbackTensors(CalculatorContext* cc, + std::vector& outputs) { + if (first_run_) { + for (int index = 0; index < num_feedback_tensors_; ++index) { + Tensor initial_feedback_tensor(Tensor::ElementType::kFloat32, + feedback_tensor_shape_); + float* data = initial_feedback_tensor.GetCpuWriteView().buffer(); + std::fill_n(data, feedback_tensor_size_, 0.0f); + 
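+        // The zero-filled tensor stands in for the recurrent state that does
+        // not yet exist on the first Process() call, so the first combined
+        // output has a well-defined feedback slot.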
outputs.push_back(std::move(initial_feedback_tensor)); + } + first_run_ = false; + return absl::OkStatus(); + } + + if (num_feedback_tensors_ != kFeedbackTensorsIn(cc)->size()) { + return absl::InvalidArgumentError( + "The number of tensors fed back differs from the configuration"); + } + absl::StatusOr>> feedback_tensors = + cc->Inputs() + .Tag(kFeedbackTensorsTag) + .Value() + .Consume>(); + if (!feedback_tensors.ok()) { + return absl::InternalError( + "The feedback tensors packet is not consumable"); + } + RET_CHECK(*feedback_tensors); + std::vector& feedbacks = **feedback_tensors; + for (const auto& feedback : feedbacks) { + if (feedback.shape().dims != feedback_tensor_shape_.dims) { + return absl::InvalidArgumentError( + "The shape of a tensor fed back differs from the configuration"); + } + } + outputs.insert(outputs.end(), std::make_move_iterator(feedbacks.begin()), + std::make_move_iterator(feedbacks.end())); + + return absl::OkStatus(); + } + + Tensor::Shape feedback_tensor_shape_; + int num_feedback_tensors_ = 0; + mediapipe::FeedbackTensorsCalculatorOptions::FeedbackTensorsLocation + feedback_tensors_location_; + + int feedback_tensor_size_ = 0; + bool first_run_ = true; +}; + +MEDIAPIPE_REGISTER_NODE(FeedbackTensorsCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/feedback_tensors_calculator.proto b/mediapipe/calculators/tensor/feedback_tensors_calculator.proto new file mode 100644 index 000000000..ac36b6780 --- /dev/null +++ b/mediapipe/calculators/tensor/feedback_tensors_calculator.proto @@ -0,0 +1,47 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message FeedbackTensorsCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional FeedbackTensorsCalculatorOptions ext = 474496252; + } + + // Represents the dimensions of a tensor starting from the outermost size. + message TensorShape { + repeated int32 dims = 1 [packed = true]; + } + + // The shape of the feedback tensors to add. + optional TensorShape feedback_tensor_shape = 1; + // The number of the feedback tensors to add. + optional int32 num_feedback_tensors = 2 [default = 1]; + + enum FeedbackTensorsLocation { + // The feedback tensors will not be added. + NONE = 0; + // The feedback tensors will be added before the input tensors. + PREPENDED = 1; + // The feedback tensors will be added after the input tensors. + APPENDED = 2; + } + + // Determines the location of the feedback tensor(s) in the output vector. + optional FeedbackTensorsLocation location = 3 [default = APPENDED]; +} diff --git a/mediapipe/calculators/tensor/feedback_tensors_calculator_test.cc b/mediapipe/calculators/tensor/feedback_tensors_calculator_test.cc new file mode 100644 index 000000000..5797cc31c --- /dev/null +++ b/mediapipe/calculators/tensor/feedback_tensors_calculator_test.cc @@ -0,0 +1,389 @@ +// Copyright 2022 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "mediapipe/calculators/tensor/feedback_tensors_calculator.pb.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/timestamp.h" + +namespace mediapipe { +namespace { + +using ::mediapipe::CalculatorGraphConfig; +using ::testing::ElementsAreArray; +using ::testing::Not; +using Tensors = std::vector; + +template +struct TensorElementType { + static constexpr Tensor::ElementType value = Tensor::ElementType::kNone; +}; + +template <> +struct TensorElementType { + static constexpr Tensor::ElementType value = Tensor::ElementType::kFloat32; +}; + +template <> +struct TensorElementType { + static constexpr Tensor::ElementType value = Tensor::ElementType::kInt8; +}; + +template <> +struct TensorElementType { + static constexpr Tensor::ElementType value = Tensor::ElementType::kUInt8; +}; + +template <> +struct TensorElementType { + static constexpr Tensor::ElementType value = Tensor::ElementType::kInt32; +}; + +template +Tensor MakeTensor(std::initializer_list shape, + std::initializer_list values) { + Tensor tensor(TensorElementType::value, shape); + CHECK_EQ(values.size(), tensor.shape().num_elements()) + << "The size of `values` is incompatible with `shape`"; + absl::c_copy(values, tensor.GetCpuWriteView().buffer()); + return tensor; +} + +template +void ValidateTensor(const Tensor& tensor, + const std::vector& expected_shape, + const std::vector& expected_values) { + ASSERT_EQ(tensor.element_type(), TensorElementType::value); + EXPECT_EQ(tensor.shape().dims, expected_shape); + EXPECT_EQ(tensor.shape().num_elements(), expected_values.size()); + + auto* tensor_buffer = tensor.GetCpuReadView().buffer(); + const std::vector tensor_values( + tensor_buffer, tensor_buffer + tensor.shape().num_elements()); + EXPECT_THAT(tensor_values, ElementsAreArray(expected_values)); +} + +TEST(FeedbackTensorsCalculatorTest, AppendsFeedback) { + auto graph_config = ParseTextProtoOrDie(R"pb( + input_stream: "input" + input_stream: "feedback" + node { + calculator: "FeedbackTensorsCalculator" + input_stream: "INPUT_TENSORS:input" + input_stream: "FEEDBACK_TENSORS:feedback" + output_stream: "TENSORS:output" + options: { + [mediapipe.FeedbackTensorsCalculatorOptions.ext] { + feedback_tensor_shape: { dims: 2 dims: 3 } + location: APPENDED + } + } + } + )pb"); + std::vector output_packets; + tool::AddVectorSink("output", &graph_config, &output_packets); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.StartRun({})); + + auto initial_input_tensors = std::make_unique(); + initial_input_tensors->push_back( + MakeTensor({2, 4}, {1, 2, 
3, 4, 5, 6, 7, 8})); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input", Adopt(initial_input_tensors.release()).At(Timestamp(1)))); + // At the beginning, the loopback packet with the model feedback is missing. + // The calculator has to assume it's all-zero with the shape from the options. + + auto later_input_tensors = std::make_unique(); + later_input_tensors->push_back( + MakeTensor({2, 4}, {8, 7, 6, 5, 4, 3, 2, 1})); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input", Adopt(later_input_tensors.release()).At(Timestamp(2)))); + auto later_feedback_tensors = std::make_unique(); + later_feedback_tensors->push_back( + MakeTensor({2, 3}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f})); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "feedback", Adopt(later_feedback_tensors.release()).At(Timestamp(2)))); + + MP_ASSERT_OK(graph.CloseAllInputStreams()) + << "Couldn't close the graph inputs"; + MP_ASSERT_OK(graph.WaitUntilDone()) << "Couldn't finalize the graph run"; + + ASSERT_EQ(output_packets.size(), 2); + + const Tensors& initial_combined_tensors = output_packets[0].Get(); + ASSERT_EQ(initial_combined_tensors.size(), 2); + ValidateTensor(initial_combined_tensors[0], + /*expected_shape=*/{2, 4}, + /*expected_values=*/{1, 2, 3, 4, 5, 6, 7, 8}); + // The initial feedback is zero. + ValidateTensor(initial_combined_tensors[1], /*expected_shape=*/{2, 3}, + /*expected_values=*/{0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + + const Tensors& later_combined_tensors = output_packets[1].Get(); + ASSERT_EQ(later_combined_tensors.size(), 2); + ValidateTensor(later_combined_tensors[0], + /*expected_shape=*/{2, 4}, + /*expected_values=*/{8, 7, 6, 5, 4, 3, 2, 1}); + // Afterwards, the provided feedback is passed through. + ValidateTensor( + later_combined_tensors[1], /*expected_shape=*/{2, 3}, + /*expected_values=*/{-1.f, -2.f, -3.f, -4.f, -5.f, -6.f}); +} + +TEST(FeedbackTensorsCalculatorTest, PrependsFeedback) { + auto graph_config = ParseTextProtoOrDie(R"pb( + input_stream: "input" + input_stream: "feedback" + node { + calculator: "FeedbackTensorsCalculator" + input_stream: "INPUT_TENSORS:input" + input_stream: "FEEDBACK_TENSORS:feedback" + output_stream: "TENSORS:output" + options: { + [mediapipe.FeedbackTensorsCalculatorOptions.ext] { + feedback_tensor_shape: { dims: 3 dims: 2 } + location: PREPENDED + } + } + } + )pb"); + std::vector output_packets; + tool::AddVectorSink("output", &graph_config, &output_packets); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.StartRun({})); + + auto initial_input_tensors = std::make_unique(); + initial_input_tensors->push_back( + MakeTensor({2, 4}, {1, 2, 3, 4, 5, 6, 7, 8})); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input", Adopt(initial_input_tensors.release()).At(Timestamp(1)))); + // At the beginning, the loopback packet with the model feedback is missing. + // The calculator has to assume it's all-zero with the shape from the options. 
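+  // In a production graph (hypothetical wiring sketch, not exercised by this
+  // test; stream and tag names are illustrative), the feedback would arrive
+  // over a back edge from the model's state output at the previous timestamp,
+  // which is why no feedback packet can exist for the very first input:
+  //   node {
+  //     calculator: "FeedbackTensorsCalculator"
+  //     input_stream: "INPUT_TENSORS:input_tensors"
+  //     input_stream: "FEEDBACK_TENSORS:state_tensors"
+  //     input_stream_info: { tag_index: "FEEDBACK_TENSORS" back_edge: true }
+  //     output_stream: "TENSORS:model_input_tensors"
+  //   }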
+ + auto later_input_tensors = std::make_unique(); + later_input_tensors->push_back( + MakeTensor({2, 4}, {8, 7, 6, 5, 4, 3, 2, 1})); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input", Adopt(later_input_tensors.release()).At(Timestamp(2)))); + auto later_feedback_tensors = std::make_unique(); + later_feedback_tensors->push_back( + MakeTensor({3, 2}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f})); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "feedback", Adopt(later_feedback_tensors.release()).At(Timestamp(2)))); + + MP_ASSERT_OK(graph.CloseAllInputStreams()) + << "Couldn't close the graph inputs"; + MP_ASSERT_OK(graph.WaitUntilDone()) << "Couldn't finalize the graph run"; + + ASSERT_EQ(output_packets.size(), 2); + + const Tensors& initial_combined_tensors = output_packets[0].Get(); + ASSERT_EQ(initial_combined_tensors.size(), 2); + // The initial feedback is zero. + ValidateTensor(initial_combined_tensors[0], /*expected_shape=*/{3, 2}, + /*expected_values=*/{0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + ValidateTensor(initial_combined_tensors[1], + /*expected_shape=*/{2, 4}, + /*expected_values=*/{1, 2, 3, 4, 5, 6, 7, 8}); + + const Tensors& later_combined_tensors = output_packets[1].Get(); + ASSERT_EQ(later_combined_tensors.size(), 2); + // Afterwards, the provided feedback is passed through. + ValidateTensor( + later_combined_tensors[0], /*expected_shape=*/{3, 2}, + /*expected_values=*/{-1.f, -2.f, -3.f, -4.f, -5.f, -6.f}); + ValidateTensor(later_combined_tensors[1], + /*expected_shape=*/{2, 4}, + /*expected_values=*/{8, 7, 6, 5, 4, 3, 2, 1}); +} + +TEST(FeedbackTensorsCalculatorTest, NoFeedback) { + auto graph_config = ParseTextProtoOrDie(R"pb( + input_stream: "input" + input_stream: "feedback" + node { + calculator: "FeedbackTensorsCalculator" + input_stream: "INPUT_TENSORS:input" + input_stream: "FEEDBACK_TENSORS:feedback" + output_stream: "TENSORS:output" + options: { + [mediapipe.FeedbackTensorsCalculatorOptions.ext] { + feedback_tensor_shape: { dims: 3 dims: 4 } + location: NONE + } + } + } + )pb"); + std::vector output_packets; + tool::AddVectorSink("output", &graph_config, &output_packets); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.StartRun({})); + + auto initial_input_tensors = std::make_unique(); + initial_input_tensors->push_back( + MakeTensor({2, 4}, {1, 2, 3, 4, 5, 6, 7, 8})); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input", Adopt(initial_input_tensors.release()).At(Timestamp(1)))); + // At the beginning, the loopback packet with the model feedback is missing. + + auto later_input_tensors = std::make_unique(); + later_input_tensors->push_back( + MakeTensor({2, 4}, {8, 7, 6, 5, 4, 3, 2, 1})); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input", Adopt(later_input_tensors.release()).At(Timestamp(2)))); + // This feedback should be ignored due to `location: NONE`. 
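+  // (With `location: NONE` the calculator forwards the INPUT_TENSORS packet
+  // unchanged, so the feedback sent below never reaches the output.)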
+ auto later_feedback_tensors = std::make_unique(); + later_feedback_tensors->push_back( + MakeTensor({2, 3}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f})); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "feedback", Adopt(later_feedback_tensors.release()).At(Timestamp(2)))); + + MP_ASSERT_OK(graph.CloseAllInputStreams()) + << "Couldn't close the graph inputs"; + MP_ASSERT_OK(graph.WaitUntilDone()) << "Couldn't finalize the graph run"; + + ASSERT_EQ(output_packets.size(), 2); + + const Tensors& initial_combined_tensors = output_packets[0].Get(); + ASSERT_EQ(initial_combined_tensors.size(), 1); + ValidateTensor(initial_combined_tensors[0], + /*expected_shape=*/{2, 4}, + /*expected_values=*/{1, 2, 3, 4, 5, 6, 7, 8}); + // No feedback due to `location: NONE`. + + const Tensors& later_combined_tensors = output_packets[1].Get(); + ASSERT_EQ(later_combined_tensors.size(), 1); + ValidateTensor(later_combined_tensors[0], + /*expected_shape=*/{2, 4}, + /*expected_values=*/{8, 7, 6, 5, 4, 3, 2, 1}); +} + +TEST(FeedbackTensorsCalculatorTest, ChecksTensorNumber) { + auto graph_config = ParseTextProtoOrDie(R"pb( + input_stream: "input" + input_stream: "feedback" + node { + calculator: "FeedbackTensorsCalculator" + input_stream: "INPUT_TENSORS:input" + input_stream: "FEEDBACK_TENSORS:feedback" + output_stream: "TENSORS:output" + options: { + [mediapipe.FeedbackTensorsCalculatorOptions.ext] { + num_feedback_tensors: 2 + feedback_tensor_shape: { dims: 2 dims: 3 } + location: PREPENDED + } + } + } + )pb"); + std::vector output_packets; + tool::AddVectorSink("output", &graph_config, &output_packets); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.StartRun({})); + + auto initial_input_tensors = std::make_unique(); + initial_input_tensors->push_back( + MakeTensor({2, 4}, {1, 2, 3, 4, 5, 6, 7, 8})); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input", Adopt(initial_input_tensors.release()).At(Timestamp(1)))); + // At the beginning, the loopback packet with the model feedback is missing. + + auto later_input_tensors = std::make_unique(); + later_input_tensors->push_back( + MakeTensor({2, 4}, {8, 7, 6, 5, 4, 3, 2, 1})); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input", Adopt(later_input_tensors.release()).At(Timestamp(2)))); + // This feedback should be ignored due to `location: NONE`. 
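+  // Only one feedback tensor is sent below while `num_feedback_tensors: 2`
+  // is configured, so the graph run is expected to fail.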
+ auto later_feedback_tensors = std::make_unique(); + later_feedback_tensors->push_back( + MakeTensor({2, 3}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f})); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "feedback", Adopt(later_feedback_tensors.release()).At(Timestamp(2)))); + + MP_ASSERT_OK(graph.CloseAllInputStreams()) + << "Couldn't close the graph inputs"; + EXPECT_THAT(graph.WaitUntilDone(), Not(IsOk())) + << "Tensor number mismatch missed"; +} + +TEST(FeedbackTensorsCalculatorTest, ChecksShape) { + auto graph_config = ParseTextProtoOrDie(R"pb( + input_stream: "input" + input_stream: "feedback" + node { + calculator: "FeedbackTensorsCalculator" + input_stream: "INPUT_TENSORS:input" + input_stream: "FEEDBACK_TENSORS:feedback" + output_stream: "TENSORS:output" + options: { + [mediapipe.FeedbackTensorsCalculatorOptions.ext] { + feedback_tensor_shape: { dims: 3 dims: 4 } + location: APPENDED + } + } + } + )pb"); + std::vector output_packets; + tool::AddVectorSink("output", &graph_config, &output_packets); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.StartRun({})); + + auto initial_input_tensors = std::make_unique(); + initial_input_tensors->push_back( + MakeTensor({2, 4}, {1, 2, 3, 4, 5, 6, 7, 8})); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input", Adopt(initial_input_tensors.release()).At(Timestamp(1)))); + // At the beginning, the loopback packet with the model feedback is missing. + + auto later_input_tensors = std::make_unique(); + later_input_tensors->push_back( + MakeTensor({2, 4}, {8, 7, 6, 5, 4, 3, 2, 1})); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input", Adopt(later_input_tensors.release()).At(Timestamp(2)))); + // This feedback should be ignored due to `location: NONE`. + auto later_feedback_tensors = std::make_unique(); + later_feedback_tensors->push_back( + MakeTensor({2, 3}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f})); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "feedback", Adopt(later_feedback_tensors.release()).At(Timestamp(2)))); + + MP_ASSERT_OK(graph.CloseAllInputStreams()) + << "Couldn't close the graph inputs"; + EXPECT_THAT(graph.WaitUntilDone(), Not(IsOk())) + << "Tensor shape mismatch missed"; +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc index 52cd9e0bb..e1809a017 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc @@ -231,7 +231,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest, // Session must be set. ASSERT_NE(session.session, nullptr); std::vector devices; - ASSERT_EQ(session.session->ListDevices(&devices), tensorflow::Status::OK()); + ASSERT_EQ(session.session->ListDevices(&devices), tensorflow::OkStatus()); EXPECT_THAT(devices.size(), 10); } diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc index 46cbf41cb..5c6de3e86 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc @@ -220,7 +220,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, // Session must be set. 
ASSERT_NE(session.session, nullptr); std::vector devices; - ASSERT_EQ(session.session->ListDevices(&devices), tensorflow::Status::OK()); + ASSERT_EQ(session.session->ListDevices(&devices), tensorflow::OkStatus()); EXPECT_THAT(devices.size(), 10); } diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD index b132db01d..1d2f279aa 100644 --- a/mediapipe/calculators/tflite/BUILD +++ b/mediapipe/calculators/tflite/BUILD @@ -135,6 +135,7 @@ filegroup( srcs = [ "testdata/anchor_golden_file_0.txt", "testdata/anchor_golden_file_1.txt", + "testdata/anchor_golden_file_2.txt", ], ) diff --git a/mediapipe/calculators/tflite/ssd_anchors_calculator.cc b/mediapipe/calculators/tflite/ssd_anchors_calculator.cc index f618b2f6a..1d8f6e3ea 100644 --- a/mediapipe/calculators/tflite/ssd_anchors_calculator.cc +++ b/mediapipe/calculators/tflite/ssd_anchors_calculator.cc @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include #include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h" @@ -24,6 +25,19 @@ namespace mediapipe { namespace { +struct MultiScaleAnchorInfo { + int32 level; + std::vector aspect_ratios; + std::vector scales; + std::pair base_anchor_size; + std::pair anchor_stride; +}; + +struct FeatureMapDim { + int height; + int width; +}; + float CalculateScale(float min_scale, float max_scale, int stride_index, int num_strides) { if (num_strides == 1) { @@ -34,6 +48,71 @@ float CalculateScale(float min_scale, float max_scale, int stride_index, } } +int GetNumLayers(const SsdAnchorsCalculatorOptions& options) { + if (options.multiscale_anchor_generation()) { + return (options.max_level() - options.min_level() + 1); + } + return options.num_layers(); +} + +FeatureMapDim GetFeatureMapDimensions( + const SsdAnchorsCalculatorOptions& options, int index) { + FeatureMapDim feature_map_dims; + if (options.feature_map_height_size()) { + feature_map_dims.height = options.feature_map_height(index); + feature_map_dims.width = options.feature_map_width(index); + } else { + const int stride = options.strides(index); + feature_map_dims.height = + std::ceil(1.0f * options.input_size_height() / stride); + feature_map_dims.width = + std::ceil(1.0f * options.input_size_width() / stride); + } + return feature_map_dims; +} + +// Although we have stride for both x and y, only one value is used for offset +// calculation. 
See +// tensorflow_models/object_detection/anchor_generators/multiscale_grid_anchor_generator.py;l=121 +std::pair GetMultiScaleAnchorOffset( + const SsdAnchorsCalculatorOptions& options, const float stride, + const int level) { + std::pair result(0., 0.); + int denominator = std::pow(2, level); + if (options.input_size_height() % denominator == 0 || + options.input_size_height() == 1) { + result.first = stride / 2.0; + } + if (options.input_size_width() % denominator == 0 || + options.input_size_width() == 1) { + result.second = stride / 2.0; + } + return result; +} + +void NormalizeAnchor(const int input_height, const int input_width, + Anchor* anchor) { + anchor->set_h(anchor->h() / (float)input_height); + anchor->set_w(anchor->w() / (float)input_width); + anchor->set_y_center(anchor->y_center() / (float)input_height); + anchor->set_x_center(anchor->x_center() / (float)input_width); +} + +Anchor CalculateAnchorBox(const int y_center, const int x_center, + const float scale, const float aspect_ratio, + const std::pair base_anchor_size, + // y-height first + const std::pair anchor_stride, + const std::pair anchor_offset) { + Anchor result; + float ratio_sqrt = std::sqrt(aspect_ratio); + result.set_h(scale * base_anchor_size.first / ratio_sqrt); + result.set_w(scale * ratio_sqrt * base_anchor_size.second); + result.set_y_center(y_center * anchor_stride.first + anchor_offset.first); + result.set_x_center(x_center * anchor_stride.second + anchor_offset.second); + return result; +} + } // namespace // Generate anchors for SSD object detection model. @@ -95,9 +174,77 @@ class SsdAnchorsCalculator : public CalculatorBase { private: static absl::Status GenerateAnchors( std::vector* anchors, const SsdAnchorsCalculatorOptions& options); + + static absl::Status GenerateMultiScaleAnchors( + std::vector* anchors, const SsdAnchorsCalculatorOptions& options); }; REGISTER_CALCULATOR(SsdAnchorsCalculator); +// Generates grid anchors on the fly corresponding to multiple CNN layers as +// described in: +// "Focal Loss for Dense Object Detection" (https://arxiv.org/abs/1708.02002) +// T.-Y. Lin, P. Goyal, R. Girshick, K. He, P. 
Dollar +absl::Status SsdAnchorsCalculator::GenerateMultiScaleAnchors( + std::vector* anchors, const SsdAnchorsCalculatorOptions& options) { + std::vector anchor_infos; + for (int i = options.min_level(); i <= options.max_level(); ++i) { + MultiScaleAnchorInfo current_anchor_info; + // level + current_anchor_info.level = i; + // aspect_ratios + for (const float aspect_ratio : options.aspect_ratios()) { + current_anchor_info.aspect_ratios.push_back(aspect_ratio); + } + + // scale + for (int i = 0; i < options.scales_per_octave(); ++i) { + current_anchor_info.scales.push_back( + std::pow(2.0, (double)i / (double)options.scales_per_octave())); + } + + // anchor stride + float anchor_stride = std::pow(2.0, i); + current_anchor_info.anchor_stride = + std::make_pair(anchor_stride, anchor_stride); + + // base_anchor_size + current_anchor_info.base_anchor_size = + std::make_pair(anchor_stride * options.anchor_scale(), + anchor_stride * options.anchor_scale()); + anchor_infos.push_back(current_anchor_info); + } + + for (unsigned int i = 0; i < anchor_infos.size(); ++i) { + FeatureMapDim dimensions = GetFeatureMapDimensions(options, i); + for (int y = 0; y < dimensions.height; ++y) { + for (int x = 0; x < dimensions.width; ++x) { + // loop over combination of scale and aspect ratio + for (unsigned int j = 0; j < anchor_infos[i].aspect_ratios.size(); + ++j) { + for (unsigned int k = 0; k < anchor_infos[i].scales.size(); ++k) { + Anchor anchor = CalculateAnchorBox( + /*y_center=*/y, /*x_center=*/x, anchor_infos[i].scales[k], + anchor_infos[i].aspect_ratios[j], + anchor_infos[i].base_anchor_size, + /*anchor_stride=*/anchor_infos[i].anchor_stride, + /*anchor_offset=*/ + GetMultiScaleAnchorOffset(options, + anchor_infos[i].anchor_stride.first, + anchor_infos[i].level)); + if (options.normalize_coordinates()) { + NormalizeAnchor(options.input_size_height(), + options.input_size_width(), &anchor); + } + anchors->push_back(anchor); + } + } + } + } + } + + return absl::OkStatus(); +} + absl::Status SsdAnchorsCalculator::GenerateAnchors( std::vector* anchors, const SsdAnchorsCalculatorOptions& options) { // Verify the options. @@ -106,15 +253,21 @@ absl::Status SsdAnchorsCalculator::GenerateAnchors( "Both feature map shape and strides are missing. Must provide either " "one."); } + const int kNumLayers = GetNumLayers(options); + if (options.feature_map_height_size()) { if (options.strides_size()) { LOG(ERROR) << "Found feature map shapes. Strides will be ignored."; } - CHECK_EQ(options.feature_map_height_size(), options.num_layers()); + CHECK_EQ(options.feature_map_height_size(), kNumLayers); CHECK_EQ(options.feature_map_height_size(), options.feature_map_width_size()); } else { - CHECK_EQ(options.strides_size(), options.num_layers()); + CHECK_EQ(options.strides_size(), kNumLayers); + } + + if (options.multiscale_anchor_generation()) { + return GenerateMultiScaleAnchors(anchors, options); } int layer_id = 0; diff --git a/mediapipe/calculators/tflite/ssd_anchors_calculator.proto b/mediapipe/calculators/tflite/ssd_anchors_calculator.proto index 911e4ac92..3b1e36700 100644 --- a/mediapipe/calculators/tflite/ssd_anchors_calculator.proto +++ b/mediapipe/calculators/tflite/ssd_anchors_calculator.proto @@ -60,4 +60,30 @@ message SsdAnchorsCalculatorOptions { // This option can be used when the predicted anchor width and height are in // pixels. 
optional bool fixed_anchor_size = 14 [default = false]; + + // Generates grid anchors on the fly corresponding to multiple CNN layers as + // described in: + // "Focal Loss for Dense Object Detection" (https://arxiv.org/abs/1708.02002) + // T.-Y. Lin, P. Goyal, R. Girshick, K. He, P. Dollar + optional bool multiscale_anchor_generation = 15 [default = false]; + + // minimum level in feature pyramid + // for multiscale_anchor_generation only! + optional int32 min_level = 16 [default = 3]; + + // maximum level in feature pyramid + // for multiscale_anchor_generation only! + optional int32 max_level = 17 [default = 7]; + + // Scale of anchor to feature stride + // for multiscale_anchor_generation only! + optional float anchor_scale = 18 [default = 4.0]; + + // Number of intermediate scale each scale octave + // for multiscale_anchor_generation only! + optional int32 scales_per_octave = 19 [default = 2]; + + // Whether to produce anchors in normalized coordinates. + // for multiscale_anchor_generation only! + optional bool normalize_coordinates = 20 [default = true]; } diff --git a/mediapipe/calculators/tflite/ssd_anchors_calculator_test.cc b/mediapipe/calculators/tflite/ssd_anchors_calculator_test.cc index 3b72a287e..595dda8bc 100644 --- a/mediapipe/calculators/tflite/ssd_anchors_calculator_test.cc +++ b/mediapipe/calculators/tflite/ssd_anchors_calculator_test.cc @@ -33,9 +33,6 @@ std::string GetGoldenFilePath(const std::string& filename) { void ParseAnchorsFromText(const std::string& text, std::vector* anchors) { - const std::string line_delimiter = "\n"; - const std::string number_delimiter = ","; - std::istringstream stream(text); std::string line; while (std::getline(stream, line)) { @@ -64,6 +61,8 @@ void CompareAnchors(const std::vector& anchors_0, testing::FloatNear(anchor_1.x_center(), 1e-5)); EXPECT_THAT(anchor_0.y_center(), testing::FloatNear(anchor_1.y_center(), 1e-5)); + EXPECT_THAT(anchor_0.h(), testing::FloatNear(anchor_1.h(), 1e-5)); + EXPECT_THAT(anchor_0.w(), testing::FloatNear(anchor_1.w(), 1e-5)); } } @@ -148,4 +147,40 @@ TEST(SsdAnchorCalculatorTest, MobileSSDConfig) { CompareAnchors(anchors, anchors_golden); } +TEST(SsdAnchorCalculatorTest, RetinaNetSSDConfig) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "SsdAnchorsCalculator" + output_side_packet: "anchors" + options { + [mediapipe.SsdAnchorsCalculatorOptions.ext] { + input_size_height: 640 + input_size_width: 640 + strides: 64 + strides: 128 + aspect_ratios: 1.0 + aspect_ratios: 2.0 + aspect_ratios: 0.5 + multiscale_anchor_generation: true + min_level: 6 + max_level: 7 + anchor_scale: 3.0 + scales_per_octave: 3 + } + } + )pb")); + + MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const auto& anchors = + runner.OutputSidePackets().Index(0).Get>(); + + std::string anchors_string; + MP_EXPECT_OK(mediapipe::file::GetContents( + GetGoldenFilePath("anchor_golden_file_2.txt"), &anchors_string)); + + std::vector anchors_golden; + ParseAnchorsFromText(anchors_string, &anchors_golden); + + CompareAnchors(anchors, anchors_golden); +} + } // namespace mediapipe diff --git a/mediapipe/calculators/tflite/testdata/anchor_golden_file_2.txt b/mediapipe/calculators/tflite/testdata/anchor_golden_file_2.txt new file mode 100644 index 000000000..651d977f5 --- /dev/null +++ b/mediapipe/calculators/tflite/testdata/anchor_golden_file_2.txt @@ -0,0 +1,1125 @@ +0.05 0.05 0.3 0.3 +0.05 0.05 0.377976 0.377976 +0.05 0.05 0.47622 0.47622 +0.05 0.05 0.424264 0.212132 +0.05 0.05 0.534539 0.26727 +0.05 0.05 
0.673477 0.336739 +0.05 0.05 0.212132 0.424264 +0.05 0.05 0.26727 0.534539 +0.05 0.05 0.336739 0.673477 +0.15 0.05 0.3 0.3 +0.15 0.05 0.377976 0.377976 +0.15 0.05 0.47622 0.47622 +0.15 0.05 0.424264 0.212132 +0.15 0.05 0.534539 0.26727 +0.15 0.05 0.673477 0.336739 +0.15 0.05 0.212132 0.424264 +0.15 0.05 0.26727 0.534539 +0.15 0.05 0.336739 0.673477 +0.25 0.05 0.3 0.3 +0.25 0.05 0.377976 0.377976 +0.25 0.05 0.47622 0.47622 +0.25 0.05 0.424264 0.212132 +0.25 0.05 0.534539 0.26727 +0.25 0.05 0.673477 0.336739 +0.25 0.05 0.212132 0.424264 +0.25 0.05 0.26727 0.534539 +0.25 0.05 0.336739 0.673477 +0.35 0.05 0.3 0.3 +0.35 0.05 0.377976 0.377976 +0.35 0.05 0.47622 0.47622 +0.35 0.05 0.424264 0.212132 +0.35 0.05 0.534539 0.26727 +0.35 0.05 0.673477 0.336739 +0.35 0.05 0.212132 0.424264 +0.35 0.05 0.26727 0.534539 +0.35 0.05 0.336739 0.673477 +0.45 0.05 0.3 0.3 +0.45 0.05 0.377976 0.377976 +0.45 0.05 0.47622 0.47622 +0.45 0.05 0.424264 0.212132 +0.45 0.05 0.534539 0.26727 +0.45 0.05 0.673477 0.336739 +0.45 0.05 0.212132 0.424264 +0.45 0.05 0.26727 0.534539 +0.45 0.05 0.336739 0.673477 +0.55 0.05 0.3 0.3 +0.55 0.05 0.377976 0.377976 +0.55 0.05 0.47622 0.47622 +0.55 0.05 0.424264 0.212132 +0.55 0.05 0.534539 0.26727 +0.55 0.05 0.673477 0.336739 +0.55 0.05 0.212132 0.424264 +0.55 0.05 0.26727 0.534539 +0.55 0.05 0.336739 0.673477 +0.65 0.05 0.3 0.3 +0.65 0.05 0.377976 0.377976 +0.65 0.05 0.47622 0.47622 +0.65 0.05 0.424264 0.212132 +0.65 0.05 0.534539 0.26727 +0.65 0.05 0.673477 0.336739 +0.65 0.05 0.212132 0.424264 +0.65 0.05 0.26727 0.534539 +0.65 0.05 0.336739 0.673477 +0.75 0.05 0.3 0.3 +0.75 0.05 0.377976 0.377976 +0.75 0.05 0.47622 0.47622 +0.75 0.05 0.424264 0.212132 +0.75 0.05 0.534539 0.26727 +0.75 0.05 0.673477 0.336739 +0.75 0.05 0.212132 0.424264 +0.75 0.05 0.26727 0.534539 +0.75 0.05 0.336739 0.673477 +0.85 0.05 0.3 0.3 +0.85 0.05 0.377976 0.377976 +0.85 0.05 0.47622 0.47622 +0.85 0.05 0.424264 0.212132 +0.85 0.05 0.534539 0.26727 +0.85 0.05 0.673477 0.336739 +0.85 0.05 0.212132 0.424264 +0.85 0.05 0.26727 0.534539 +0.85 0.05 0.336739 0.673477 +0.95 0.05 0.3 0.3 +0.95 0.05 0.377976 0.377976 +0.95 0.05 0.47622 0.47622 +0.95 0.05 0.424264 0.212132 +0.95 0.05 0.534539 0.26727 +0.95 0.05 0.673477 0.336739 +0.95 0.05 0.212132 0.424264 +0.95 0.05 0.267269 0.534539 +0.95 0.05 0.336739 0.673477 +0.05 0.15 0.3 0.3 +0.05 0.15 0.377976 0.377976 +0.05 0.15 0.47622 0.47622 +0.05 0.15 0.424264 0.212132 +0.05 0.15 0.534539 0.26727 +0.05 0.15 0.673477 0.336739 +0.05 0.15 0.212132 0.424264 +0.05 0.15 0.26727 0.534539 +0.05 0.15 0.336739 0.673477 +0.15 0.15 0.3 0.3 +0.15 0.15 0.377976 0.377976 +0.15 0.15 0.47622 0.47622 +0.15 0.15 0.424264 0.212132 +0.15 0.15 0.534539 0.26727 +0.15 0.15 0.673477 0.336739 +0.15 0.15 0.212132 0.424264 +0.15 0.15 0.26727 0.534539 +0.15 0.15 0.336739 0.673477 +0.25 0.15 0.3 0.3 +0.25 0.15 0.377976 0.377976 +0.25 0.15 0.47622 0.47622 +0.25 0.15 0.424264 0.212132 +0.25 0.15 0.534539 0.26727 +0.25 0.15 0.673477 0.336739 +0.25 0.15 0.212132 0.424264 +0.25 0.15 0.26727 0.534539 +0.25 0.15 0.336739 0.673477 +0.35 0.15 0.3 0.3 +0.35 0.15 0.377976 0.377976 +0.35 0.15 0.47622 0.47622 +0.35 0.15 0.424264 0.212132 +0.35 0.15 0.534539 0.26727 +0.35 0.15 0.673477 0.336739 +0.35 0.15 0.212132 0.424264 +0.35 0.15 0.26727 0.534539 +0.35 0.15 0.336739 0.673477 +0.45 0.15 0.3 0.3 +0.45 0.15 0.377976 0.377976 +0.45 0.15 0.47622 0.47622 +0.45 0.15 0.424264 0.212132 +0.45 0.15 0.534539 0.26727 +0.45 0.15 0.673477 0.336739 +0.45 0.15 0.212132 0.424264 +0.45 0.15 0.26727 0.534539 +0.45 0.15 0.336739 
0.673477 +0.55 0.15 0.3 0.3 +0.55 0.15 0.377976 0.377976 +0.55 0.15 0.47622 0.47622 +0.55 0.15 0.424264 0.212132 +0.55 0.15 0.534539 0.26727 +0.55 0.15 0.673477 0.336739 +0.55 0.15 0.212132 0.424264 +0.55 0.15 0.26727 0.534539 +0.55 0.15 0.336739 0.673477 +0.65 0.15 0.3 0.3 +0.65 0.15 0.377976 0.377976 +0.65 0.15 0.47622 0.47622 +0.65 0.15 0.424264 0.212132 +0.65 0.15 0.534539 0.26727 +0.65 0.15 0.673477 0.336739 +0.65 0.15 0.212132 0.424264 +0.65 0.15 0.26727 0.534539 +0.65 0.15 0.336739 0.673477 +0.75 0.15 0.3 0.3 +0.75 0.15 0.377976 0.377976 +0.75 0.15 0.47622 0.47622 +0.75 0.15 0.424264 0.212132 +0.75 0.15 0.534539 0.26727 +0.75 0.15 0.673477 0.336739 +0.75 0.15 0.212132 0.424264 +0.75 0.15 0.26727 0.534539 +0.75 0.15 0.336739 0.673477 +0.85 0.15 0.3 0.3 +0.85 0.15 0.377976 0.377976 +0.85 0.15 0.47622 0.47622 +0.85 0.15 0.424264 0.212132 +0.85 0.15 0.534539 0.26727 +0.85 0.15 0.673477 0.336739 +0.85 0.15 0.212132 0.424264 +0.85 0.15 0.26727 0.534539 +0.85 0.15 0.336739 0.673477 +0.95 0.15 0.3 0.3 +0.95 0.15 0.377976 0.377976 +0.95 0.15 0.47622 0.47622 +0.95 0.15 0.424264 0.212132 +0.95 0.15 0.534539 0.26727 +0.95 0.15 0.673477 0.336739 +0.95 0.15 0.212132 0.424264 +0.95 0.15 0.267269 0.534539 +0.95 0.15 0.336739 0.673477 +0.05 0.25 0.3 0.3 +0.05 0.25 0.377976 0.377976 +0.05 0.25 0.47622 0.47622 +0.05 0.25 0.424264 0.212132 +0.05 0.25 0.534539 0.26727 +0.05 0.25 0.673477 0.336739 +0.05 0.25 0.212132 0.424264 +0.05 0.25 0.26727 0.534539 +0.05 0.25 0.336739 0.673477 +0.15 0.25 0.3 0.3 +0.15 0.25 0.377976 0.377976 +0.15 0.25 0.47622 0.47622 +0.15 0.25 0.424264 0.212132 +0.15 0.25 0.534539 0.26727 +0.15 0.25 0.673477 0.336739 +0.15 0.25 0.212132 0.424264 +0.15 0.25 0.26727 0.534539 +0.15 0.25 0.336739 0.673477 +0.25 0.25 0.3 0.3 +0.25 0.25 0.377976 0.377976 +0.25 0.25 0.47622 0.47622 +0.25 0.25 0.424264 0.212132 +0.25 0.25 0.534539 0.26727 +0.25 0.25 0.673477 0.336739 +0.25 0.25 0.212132 0.424264 +0.25 0.25 0.26727 0.534539 +0.25 0.25 0.336739 0.673477 +0.35 0.25 0.3 0.3 +0.35 0.25 0.377976 0.377976 +0.35 0.25 0.47622 0.47622 +0.35 0.25 0.424264 0.212132 +0.35 0.25 0.534539 0.26727 +0.35 0.25 0.673477 0.336739 +0.35 0.25 0.212132 0.424264 +0.35 0.25 0.26727 0.534539 +0.35 0.25 0.336739 0.673477 +0.45 0.25 0.3 0.3 +0.45 0.25 0.377976 0.377976 +0.45 0.25 0.47622 0.47622 +0.45 0.25 0.424264 0.212132 +0.45 0.25 0.534539 0.26727 +0.45 0.25 0.673477 0.336739 +0.45 0.25 0.212132 0.424264 +0.45 0.25 0.26727 0.534539 +0.45 0.25 0.336739 0.673477 +0.55 0.25 0.3 0.3 +0.55 0.25 0.377976 0.377976 +0.55 0.25 0.47622 0.47622 +0.55 0.25 0.424264 0.212132 +0.55 0.25 0.534539 0.26727 +0.55 0.25 0.673477 0.336739 +0.55 0.25 0.212132 0.424264 +0.55 0.25 0.26727 0.534539 +0.55 0.25 0.336739 0.673477 +0.65 0.25 0.3 0.3 +0.65 0.25 0.377976 0.377976 +0.65 0.25 0.47622 0.47622 +0.65 0.25 0.424264 0.212132 +0.65 0.25 0.534539 0.26727 +0.65 0.25 0.673477 0.336739 +0.65 0.25 0.212132 0.424264 +0.65 0.25 0.26727 0.534539 +0.65 0.25 0.336739 0.673477 +0.75 0.25 0.3 0.3 +0.75 0.25 0.377976 0.377976 +0.75 0.25 0.47622 0.47622 +0.75 0.25 0.424264 0.212132 +0.75 0.25 0.534539 0.26727 +0.75 0.25 0.673477 0.336739 +0.75 0.25 0.212132 0.424264 +0.75 0.25 0.26727 0.534539 +0.75 0.25 0.336739 0.673477 +0.85 0.25 0.3 0.3 +0.85 0.25 0.377976 0.377976 +0.85 0.25 0.47622 0.47622 +0.85 0.25 0.424264 0.212132 +0.85 0.25 0.534539 0.26727 +0.85 0.25 0.673477 0.336739 +0.85 0.25 0.212132 0.424264 +0.85 0.25 0.26727 0.534539 +0.85 0.25 0.336739 0.673477 +0.95 0.25 0.3 0.3 +0.95 0.25 0.377976 0.377976 +0.95 0.25 0.47622 0.47622 +0.95 0.25 
0.424264 0.212132 +0.95 0.25 0.534539 0.26727 +0.95 0.25 0.673477 0.336739 +0.95 0.25 0.212132 0.424264 +0.95 0.25 0.267269 0.534539 +0.95 0.25 0.336739 0.673477 +0.05 0.35 0.3 0.3 +0.05 0.35 0.377976 0.377976 +0.05 0.35 0.47622 0.47622 +0.05 0.35 0.424264 0.212132 +0.05 0.35 0.534539 0.26727 +0.05 0.35 0.673477 0.336739 +0.05 0.35 0.212132 0.424264 +0.05 0.35 0.26727 0.534539 +0.05 0.35 0.336739 0.673477 +0.15 0.35 0.3 0.3 +0.15 0.35 0.377976 0.377976 +0.15 0.35 0.47622 0.47622 +0.15 0.35 0.424264 0.212132 +0.15 0.35 0.534539 0.26727 +0.15 0.35 0.673477 0.336739 +0.15 0.35 0.212132 0.424264 +0.15 0.35 0.26727 0.534539 +0.15 0.35 0.336739 0.673477 +0.25 0.35 0.3 0.3 +0.25 0.35 0.377976 0.377976 +0.25 0.35 0.47622 0.47622 +0.25 0.35 0.424264 0.212132 +0.25 0.35 0.534539 0.26727 +0.25 0.35 0.673477 0.336739 +0.25 0.35 0.212132 0.424264 +0.25 0.35 0.26727 0.534539 +0.25 0.35 0.336739 0.673477 +0.35 0.35 0.3 0.3 +0.35 0.35 0.377976 0.377976 +0.35 0.35 0.47622 0.47622 +0.35 0.35 0.424264 0.212132 +0.35 0.35 0.534539 0.26727 +0.35 0.35 0.673477 0.336739 +0.35 0.35 0.212132 0.424264 +0.35 0.35 0.26727 0.534539 +0.35 0.35 0.336739 0.673477 +0.45 0.35 0.3 0.3 +0.45 0.35 0.377976 0.377976 +0.45 0.35 0.47622 0.47622 +0.45 0.35 0.424264 0.212132 +0.45 0.35 0.534539 0.26727 +0.45 0.35 0.673477 0.336739 +0.45 0.35 0.212132 0.424264 +0.45 0.35 0.26727 0.534539 +0.45 0.35 0.336739 0.673477 +0.55 0.35 0.3 0.3 +0.55 0.35 0.377976 0.377976 +0.55 0.35 0.47622 0.47622 +0.55 0.35 0.424264 0.212132 +0.55 0.35 0.534539 0.26727 +0.55 0.35 0.673477 0.336739 +0.55 0.35 0.212132 0.424264 +0.55 0.35 0.26727 0.534539 +0.55 0.35 0.336739 0.673477 +0.65 0.35 0.3 0.3 +0.65 0.35 0.377976 0.377976 +0.65 0.35 0.47622 0.47622 +0.65 0.35 0.424264 0.212132 +0.65 0.35 0.534539 0.26727 +0.65 0.35 0.673477 0.336739 +0.65 0.35 0.212132 0.424264 +0.65 0.35 0.26727 0.534539 +0.65 0.35 0.336739 0.673477 +0.75 0.35 0.3 0.3 +0.75 0.35 0.377976 0.377976 +0.75 0.35 0.47622 0.47622 +0.75 0.35 0.424264 0.212132 +0.75 0.35 0.534539 0.26727 +0.75 0.35 0.673477 0.336739 +0.75 0.35 0.212132 0.424264 +0.75 0.35 0.26727 0.534539 +0.75 0.35 0.336739 0.673477 +0.85 0.35 0.3 0.3 +0.85 0.35 0.377976 0.377976 +0.85 0.35 0.47622 0.47622 +0.85 0.35 0.424264 0.212132 +0.85 0.35 0.534539 0.26727 +0.85 0.35 0.673477 0.336739 +0.85 0.35 0.212132 0.424264 +0.85 0.35 0.26727 0.534539 +0.85 0.35 0.336739 0.673477 +0.95 0.35 0.3 0.3 +0.95 0.35 0.377976 0.377976 +0.95 0.35 0.47622 0.47622 +0.95 0.35 0.424264 0.212132 +0.95 0.35 0.534539 0.26727 +0.95 0.35 0.673477 0.336739 +0.95 0.35 0.212132 0.424264 +0.95 0.35 0.267269 0.534539 +0.95 0.35 0.336739 0.673477 +0.05 0.45 0.3 0.3 +0.05 0.45 0.377976 0.377976 +0.05 0.45 0.47622 0.47622 +0.05 0.45 0.424264 0.212132 +0.05 0.45 0.534539 0.26727 +0.05 0.45 0.673477 0.336739 +0.05 0.45 0.212132 0.424264 +0.05 0.45 0.26727 0.534539 +0.05 0.45 0.336739 0.673477 +0.15 0.45 0.3 0.3 +0.15 0.45 0.377976 0.377976 +0.15 0.45 0.47622 0.47622 +0.15 0.45 0.424264 0.212132 +0.15 0.45 0.534539 0.26727 +0.15 0.45 0.673477 0.336739 +0.15 0.45 0.212132 0.424264 +0.15 0.45 0.26727 0.534539 +0.15 0.45 0.336739 0.673477 +0.25 0.45 0.3 0.3 +0.25 0.45 0.377976 0.377976 +0.25 0.45 0.47622 0.47622 +0.25 0.45 0.424264 0.212132 +0.25 0.45 0.534539 0.26727 +0.25 0.45 0.673477 0.336739 +0.25 0.45 0.212132 0.424264 +0.25 0.45 0.26727 0.534539 +0.25 0.45 0.336739 0.673477 +0.35 0.45 0.3 0.3 +0.35 0.45 0.377976 0.377976 +0.35 0.45 0.47622 0.47622 +0.35 0.45 0.424264 0.212132 +0.35 0.45 0.534539 0.26727 +0.35 0.45 0.673477 0.336739 +0.35 0.45 0.212132 
0.424264 +0.35 0.45 0.26727 0.534539 +0.35 0.45 0.336739 0.673477 +0.45 0.45 0.3 0.3 +0.45 0.45 0.377976 0.377976 +0.45 0.45 0.47622 0.47622 +0.45 0.45 0.424264 0.212132 +0.45 0.45 0.534539 0.26727 +0.45 0.45 0.673477 0.336739 +0.45 0.45 0.212132 0.424264 +0.45 0.45 0.26727 0.534539 +0.45 0.45 0.336739 0.673477 +0.55 0.45 0.3 0.3 +0.55 0.45 0.377976 0.377976 +0.55 0.45 0.47622 0.47622 +0.55 0.45 0.424264 0.212132 +0.55 0.45 0.534539 0.26727 +0.55 0.45 0.673477 0.336739 +0.55 0.45 0.212132 0.424264 +0.55 0.45 0.26727 0.534539 +0.55 0.45 0.336739 0.673477 +0.65 0.45 0.3 0.3 +0.65 0.45 0.377976 0.377976 +0.65 0.45 0.47622 0.47622 +0.65 0.45 0.424264 0.212132 +0.65 0.45 0.534539 0.26727 +0.65 0.45 0.673477 0.336739 +0.65 0.45 0.212132 0.424264 +0.65 0.45 0.26727 0.534539 +0.65 0.45 0.336739 0.673477 +0.75 0.45 0.3 0.3 +0.75 0.45 0.377976 0.377976 +0.75 0.45 0.47622 0.47622 +0.75 0.45 0.424264 0.212132 +0.75 0.45 0.534539 0.26727 +0.75 0.45 0.673477 0.336739 +0.75 0.45 0.212132 0.424264 +0.75 0.45 0.26727 0.534539 +0.75 0.45 0.336739 0.673477 +0.85 0.45 0.3 0.3 +0.85 0.45 0.377976 0.377976 +0.85 0.45 0.47622 0.47622 +0.85 0.45 0.424264 0.212132 +0.85 0.45 0.534539 0.26727 +0.85 0.45 0.673477 0.336739 +0.85 0.45 0.212132 0.424264 +0.85 0.45 0.26727 0.534539 +0.85 0.45 0.336739 0.673477 +0.95 0.45 0.3 0.3 +0.95 0.45 0.377976 0.377976 +0.95 0.45 0.47622 0.47622 +0.95 0.45 0.424264 0.212132 +0.95 0.45 0.534539 0.26727 +0.95 0.45 0.673477 0.336739 +0.95 0.45 0.212132 0.424264 +0.95 0.45 0.267269 0.534539 +0.95 0.45 0.336739 0.673477 +0.05 0.55 0.3 0.3 +0.05 0.55 0.377976 0.377976 +0.05 0.55 0.47622 0.47622 +0.05 0.55 0.424264 0.212132 +0.05 0.55 0.534539 0.26727 +0.05 0.55 0.673477 0.336739 +0.05 0.55 0.212132 0.424264 +0.05 0.55 0.26727 0.534539 +0.05 0.55 0.336739 0.673477 +0.15 0.55 0.3 0.3 +0.15 0.55 0.377976 0.377976 +0.15 0.55 0.47622 0.47622 +0.15 0.55 0.424264 0.212132 +0.15 0.55 0.534539 0.26727 +0.15 0.55 0.673477 0.336739 +0.15 0.55 0.212132 0.424264 +0.15 0.55 0.26727 0.534539 +0.15 0.55 0.336739 0.673477 +0.25 0.55 0.3 0.3 +0.25 0.55 0.377976 0.377976 +0.25 0.55 0.47622 0.47622 +0.25 0.55 0.424264 0.212132 +0.25 0.55 0.534539 0.26727 +0.25 0.55 0.673477 0.336739 +0.25 0.55 0.212132 0.424264 +0.25 0.55 0.26727 0.534539 +0.25 0.55 0.336739 0.673477 +0.35 0.55 0.3 0.3 +0.35 0.55 0.377976 0.377976 +0.35 0.55 0.47622 0.47622 +0.35 0.55 0.424264 0.212132 +0.35 0.55 0.534539 0.26727 +0.35 0.55 0.673477 0.336739 +0.35 0.55 0.212132 0.424264 +0.35 0.55 0.26727 0.534539 +0.35 0.55 0.336739 0.673477 +0.45 0.55 0.3 0.3 +0.45 0.55 0.377976 0.377976 +0.45 0.55 0.47622 0.47622 +0.45 0.55 0.424264 0.212132 +0.45 0.55 0.534539 0.26727 +0.45 0.55 0.673477 0.336739 +0.45 0.55 0.212132 0.424264 +0.45 0.55 0.26727 0.534539 +0.45 0.55 0.336739 0.673477 +0.55 0.55 0.3 0.3 +0.55 0.55 0.377976 0.377976 +0.55 0.55 0.47622 0.47622 +0.55 0.55 0.424264 0.212132 +0.55 0.55 0.534539 0.26727 +0.55 0.55 0.673477 0.336739 +0.55 0.55 0.212132 0.424264 +0.55 0.55 0.26727 0.534539 +0.55 0.55 0.336739 0.673477 +0.65 0.55 0.3 0.3 +0.65 0.55 0.377976 0.377976 +0.65 0.55 0.47622 0.47622 +0.65 0.55 0.424264 0.212132 +0.65 0.55 0.534539 0.26727 +0.65 0.55 0.673477 0.336739 +0.65 0.55 0.212132 0.424264 +0.65 0.55 0.26727 0.534539 +0.65 0.55 0.336739 0.673477 +0.75 0.55 0.3 0.3 +0.75 0.55 0.377976 0.377976 +0.75 0.55 0.47622 0.47622 +0.75 0.55 0.424264 0.212132 +0.75 0.55 0.534539 0.26727 +0.75 0.55 0.673477 0.336739 +0.75 0.55 0.212132 0.424264 +0.75 0.55 0.26727 0.534539 +0.75 0.55 0.336739 0.673477 +0.85 0.55 0.3 0.3 +0.85 0.55 
0.377976 0.377976 +0.85 0.55 0.47622 0.47622 +0.85 0.55 0.424264 0.212132 +0.85 0.55 0.534539 0.26727 +0.85 0.55 0.673477 0.336739 +0.85 0.55 0.212132 0.424264 +0.85 0.55 0.26727 0.534539 +0.85 0.55 0.336739 0.673477 +0.95 0.55 0.3 0.3 +0.95 0.55 0.377976 0.377976 +0.95 0.55 0.47622 0.47622 +0.95 0.55 0.424264 0.212132 +0.95 0.55 0.534539 0.26727 +0.95 0.55 0.673477 0.336739 +0.95 0.55 0.212132 0.424264 +0.95 0.55 0.267269 0.534539 +0.95 0.55 0.336739 0.673477 +0.05 0.65 0.3 0.3 +0.05 0.65 0.377976 0.377976 +0.05 0.65 0.47622 0.47622 +0.05 0.65 0.424264 0.212132 +0.05 0.65 0.534539 0.26727 +0.05 0.65 0.673477 0.336739 +0.05 0.65 0.212132 0.424264 +0.05 0.65 0.26727 0.534539 +0.05 0.65 0.336739 0.673477 +0.15 0.65 0.3 0.3 +0.15 0.65 0.377976 0.377976 +0.15 0.65 0.47622 0.47622 +0.15 0.65 0.424264 0.212132 +0.15 0.65 0.534539 0.26727 +0.15 0.65 0.673477 0.336739 +0.15 0.65 0.212132 0.424264 +0.15 0.65 0.26727 0.534539 +0.15 0.65 0.336739 0.673477 +0.25 0.65 0.3 0.3 +0.25 0.65 0.377976 0.377976 +0.25 0.65 0.47622 0.47622 +0.25 0.65 0.424264 0.212132 +0.25 0.65 0.534539 0.26727 +0.25 0.65 0.673477 0.336739 +0.25 0.65 0.212132 0.424264 +0.25 0.65 0.26727 0.534539 +0.25 0.65 0.336739 0.673477 +0.35 0.65 0.3 0.3 +0.35 0.65 0.377976 0.377976 +0.35 0.65 0.47622 0.47622 +0.35 0.65 0.424264 0.212132 +0.35 0.65 0.534539 0.26727 +0.35 0.65 0.673477 0.336739 +0.35 0.65 0.212132 0.424264 +0.35 0.65 0.26727 0.534539 +0.35 0.65 0.336739 0.673477 +0.45 0.65 0.3 0.3 +0.45 0.65 0.377976 0.377976 +0.45 0.65 0.47622 0.47622 +0.45 0.65 0.424264 0.212132 +0.45 0.65 0.534539 0.26727 +0.45 0.65 0.673477 0.336739 +0.45 0.65 0.212132 0.424264 +0.45 0.65 0.26727 0.534539 +0.45 0.65 0.336739 0.673477 +0.55 0.65 0.3 0.3 +0.55 0.65 0.377976 0.377976 +0.55 0.65 0.47622 0.47622 +0.55 0.65 0.424264 0.212132 +0.55 0.65 0.534539 0.26727 +0.55 0.65 0.673477 0.336739 +0.55 0.65 0.212132 0.424264 +0.55 0.65 0.26727 0.534539 +0.55 0.65 0.336739 0.673477 +0.65 0.65 0.3 0.3 +0.65 0.65 0.377976 0.377976 +0.65 0.65 0.47622 0.47622 +0.65 0.65 0.424264 0.212132 +0.65 0.65 0.534539 0.26727 +0.65 0.65 0.673477 0.336739 +0.65 0.65 0.212132 0.424264 +0.65 0.65 0.26727 0.534539 +0.65 0.65 0.336739 0.673477 +0.75 0.65 0.3 0.3 +0.75 0.65 0.377976 0.377976 +0.75 0.65 0.47622 0.47622 +0.75 0.65 0.424264 0.212132 +0.75 0.65 0.534539 0.26727 +0.75 0.65 0.673477 0.336739 +0.75 0.65 0.212132 0.424264 +0.75 0.65 0.26727 0.534539 +0.75 0.65 0.336739 0.673477 +0.85 0.65 0.3 0.3 +0.85 0.65 0.377976 0.377976 +0.85 0.65 0.47622 0.47622 +0.85 0.65 0.424264 0.212132 +0.85 0.65 0.534539 0.26727 +0.85 0.65 0.673477 0.336739 +0.85 0.65 0.212132 0.424264 +0.85 0.65 0.26727 0.534539 +0.85 0.65 0.336739 0.673477 +0.95 0.65 0.3 0.3 +0.95 0.65 0.377976 0.377976 +0.95 0.65 0.47622 0.47622 +0.95 0.65 0.424264 0.212132 +0.95 0.65 0.534539 0.26727 +0.95 0.65 0.673477 0.336739 +0.95 0.65 0.212132 0.424264 +0.95 0.65 0.267269 0.534539 +0.95 0.65 0.336739 0.673477 +0.05 0.75 0.3 0.3 +0.05 0.75 0.377976 0.377976 +0.05 0.75 0.47622 0.47622 +0.05 0.75 0.424264 0.212132 +0.05 0.75 0.534539 0.26727 +0.05 0.75 0.673477 0.336739 +0.05 0.75 0.212132 0.424264 +0.05 0.75 0.26727 0.534539 +0.05 0.75 0.336739 0.673477 +0.15 0.75 0.3 0.3 +0.15 0.75 0.377976 0.377976 +0.15 0.75 0.47622 0.47622 +0.15 0.75 0.424264 0.212132 +0.15 0.75 0.534539 0.26727 +0.15 0.75 0.673477 0.336739 +0.15 0.75 0.212132 0.424264 +0.15 0.75 0.26727 0.534539 +0.15 0.75 0.336739 0.673477 +0.25 0.75 0.3 0.3 +0.25 0.75 0.377976 0.377976 +0.25 0.75 0.47622 0.47622 +0.25 0.75 0.424264 0.212132 +0.25 0.75 0.534539 
0.26727 +0.25 0.75 0.673477 0.336739 +0.25 0.75 0.212132 0.424264 +0.25 0.75 0.26727 0.534539 +0.25 0.75 0.336739 0.673477 +0.35 0.75 0.3 0.3 +0.35 0.75 0.377976 0.377976 +0.35 0.75 0.47622 0.47622 +0.35 0.75 0.424264 0.212132 +0.35 0.75 0.534539 0.26727 +0.35 0.75 0.673477 0.336739 +0.35 0.75 0.212132 0.424264 +0.35 0.75 0.26727 0.534539 +0.35 0.75 0.336739 0.673477 +0.45 0.75 0.3 0.3 +0.45 0.75 0.377976 0.377976 +0.45 0.75 0.47622 0.47622 +0.45 0.75 0.424264 0.212132 +0.45 0.75 0.534539 0.26727 +0.45 0.75 0.673477 0.336739 +0.45 0.75 0.212132 0.424264 +0.45 0.75 0.26727 0.534539 +0.45 0.75 0.336739 0.673477 +0.55 0.75 0.3 0.3 +0.55 0.75 0.377976 0.377976 +0.55 0.75 0.47622 0.47622 +0.55 0.75 0.424264 0.212132 +0.55 0.75 0.534539 0.26727 +0.55 0.75 0.673477 0.336739 +0.55 0.75 0.212132 0.424264 +0.55 0.75 0.26727 0.534539 +0.55 0.75 0.336739 0.673477 +0.65 0.75 0.3 0.3 +0.65 0.75 0.377976 0.377976 +0.65 0.75 0.47622 0.47622 +0.65 0.75 0.424264 0.212132 +0.65 0.75 0.534539 0.26727 +0.65 0.75 0.673477 0.336739 +0.65 0.75 0.212132 0.424264 +0.65 0.75 0.26727 0.534539 +0.65 0.75 0.336739 0.673477 +0.75 0.75 0.3 0.3 +0.75 0.75 0.377976 0.377976 +0.75 0.75 0.47622 0.47622 +0.75 0.75 0.424264 0.212132 +0.75 0.75 0.534539 0.26727 +0.75 0.75 0.673477 0.336739 +0.75 0.75 0.212132 0.424264 +0.75 0.75 0.26727 0.534539 +0.75 0.75 0.336739 0.673477 +0.85 0.75 0.3 0.3 +0.85 0.75 0.377976 0.377976 +0.85 0.75 0.47622 0.47622 +0.85 0.75 0.424264 0.212132 +0.85 0.75 0.534539 0.26727 +0.85 0.75 0.673477 0.336739 +0.85 0.75 0.212132 0.424264 +0.85 0.75 0.26727 0.534539 +0.85 0.75 0.336739 0.673477 +0.95 0.75 0.3 0.3 +0.95 0.75 0.377976 0.377976 +0.95 0.75 0.47622 0.47622 +0.95 0.75 0.424264 0.212132 +0.95 0.75 0.534539 0.26727 +0.95 0.75 0.673477 0.336739 +0.95 0.75 0.212132 0.424264 +0.95 0.75 0.267269 0.534539 +0.95 0.75 0.336739 0.673477 +0.05 0.85 0.3 0.3 +0.05 0.85 0.377976 0.377976 +0.05 0.85 0.47622 0.47622 +0.05 0.85 0.424264 0.212132 +0.05 0.85 0.534539 0.26727 +0.05 0.85 0.673477 0.336739 +0.05 0.85 0.212132 0.424264 +0.05 0.85 0.26727 0.534539 +0.05 0.85 0.336739 0.673477 +0.15 0.85 0.3 0.3 +0.15 0.85 0.377976 0.377976 +0.15 0.85 0.47622 0.47622 +0.15 0.85 0.424264 0.212132 +0.15 0.85 0.534539 0.26727 +0.15 0.85 0.673477 0.336739 +0.15 0.85 0.212132 0.424264 +0.15 0.85 0.26727 0.534539 +0.15 0.85 0.336739 0.673477 +0.25 0.85 0.3 0.3 +0.25 0.85 0.377976 0.377976 +0.25 0.85 0.47622 0.47622 +0.25 0.85 0.424264 0.212132 +0.25 0.85 0.534539 0.26727 +0.25 0.85 0.673477 0.336739 +0.25 0.85 0.212132 0.424264 +0.25 0.85 0.26727 0.534539 +0.25 0.85 0.336739 0.673477 +0.35 0.85 0.3 0.3 +0.35 0.85 0.377976 0.377976 +0.35 0.85 0.47622 0.47622 +0.35 0.85 0.424264 0.212132 +0.35 0.85 0.534539 0.26727 +0.35 0.85 0.673477 0.336739 +0.35 0.85 0.212132 0.424264 +0.35 0.85 0.26727 0.534539 +0.35 0.85 0.336739 0.673477 +0.45 0.85 0.3 0.3 +0.45 0.85 0.377976 0.377976 +0.45 0.85 0.47622 0.47622 +0.45 0.85 0.424264 0.212132 +0.45 0.85 0.534539 0.26727 +0.45 0.85 0.673477 0.336739 +0.45 0.85 0.212132 0.424264 +0.45 0.85 0.26727 0.534539 +0.45 0.85 0.336739 0.673477 +0.55 0.85 0.3 0.3 +0.55 0.85 0.377976 0.377976 +0.55 0.85 0.47622 0.47622 +0.55 0.85 0.424264 0.212132 +0.55 0.85 0.534539 0.26727 +0.55 0.85 0.673477 0.336739 +0.55 0.85 0.212132 0.424264 +0.55 0.85 0.26727 0.534539 +0.55 0.85 0.336739 0.673477 +0.65 0.85 0.3 0.3 +0.65 0.85 0.377976 0.377976 +0.65 0.85 0.47622 0.47622 +0.65 0.85 0.424264 0.212132 +0.65 0.85 0.534539 0.26727 +0.65 0.85 0.673477 0.336739 +0.65 0.85 0.212132 0.424264 +0.65 0.85 0.26727 0.534539 
+0.65 0.85 0.336739 0.673477 +0.75 0.85 0.3 0.3 +0.75 0.85 0.377976 0.377976 +0.75 0.85 0.47622 0.47622 +0.75 0.85 0.424264 0.212132 +0.75 0.85 0.534539 0.26727 +0.75 0.85 0.673477 0.336739 +0.75 0.85 0.212132 0.424264 +0.75 0.85 0.26727 0.534539 +0.75 0.85 0.336739 0.673477 +0.85 0.85 0.3 0.3 +0.85 0.85 0.377976 0.377976 +0.85 0.85 0.47622 0.47622 +0.85 0.85 0.424264 0.212132 +0.85 0.85 0.534539 0.26727 +0.85 0.85 0.673477 0.336739 +0.85 0.85 0.212132 0.424264 +0.85 0.85 0.26727 0.534539 +0.85 0.85 0.336739 0.673477 +0.95 0.85 0.3 0.3 +0.95 0.85 0.377976 0.377976 +0.95 0.85 0.47622 0.47622 +0.95 0.85 0.424264 0.212132 +0.95 0.85 0.534539 0.26727 +0.95 0.85 0.673477 0.336739 +0.95 0.85 0.212132 0.424264 +0.95 0.85 0.267269 0.534539 +0.95 0.85 0.336739 0.673477 +0.05 0.95 0.3 0.3 +0.05 0.95 0.377976 0.377976 +0.05 0.95 0.47622 0.47622 +0.05 0.95 0.424264 0.212132 +0.05 0.95 0.534539 0.26727 +0.05 0.95 0.673477 0.336739 +0.05 0.95 0.212132 0.424264 +0.05 0.95 0.26727 0.534539 +0.05 0.95 0.336739 0.673477 +0.15 0.95 0.3 0.3 +0.15 0.95 0.377976 0.377976 +0.15 0.95 0.47622 0.47622 +0.15 0.95 0.424264 0.212132 +0.15 0.95 0.534539 0.26727 +0.15 0.95 0.673477 0.336739 +0.15 0.95 0.212132 0.424264 +0.15 0.95 0.26727 0.534539 +0.15 0.95 0.336739 0.673477 +0.25 0.95 0.3 0.3 +0.25 0.95 0.377976 0.377976 +0.25 0.95 0.47622 0.47622 +0.25 0.95 0.424264 0.212132 +0.25 0.95 0.534539 0.26727 +0.25 0.95 0.673477 0.336739 +0.25 0.95 0.212132 0.424264 +0.25 0.95 0.26727 0.534539 +0.25 0.95 0.336739 0.673477 +0.35 0.95 0.3 0.3 +0.35 0.95 0.377976 0.377976 +0.35 0.95 0.47622 0.47622 +0.35 0.95 0.424264 0.212132 +0.35 0.95 0.534539 0.26727 +0.35 0.95 0.673477 0.336739 +0.35 0.95 0.212132 0.424264 +0.35 0.95 0.26727 0.534539 +0.35 0.95 0.336739 0.673477 +0.45 0.95 0.3 0.3 +0.45 0.95 0.377976 0.377976 +0.45 0.95 0.47622 0.47622 +0.45 0.95 0.424264 0.212132 +0.45 0.95 0.534539 0.26727 +0.45 0.95 0.673477 0.336739 +0.45 0.95 0.212132 0.424264 +0.45 0.95 0.26727 0.534539 +0.45 0.95 0.336739 0.673477 +0.55 0.95 0.3 0.3 +0.55 0.95 0.377976 0.377976 +0.55 0.95 0.47622 0.47622 +0.55 0.95 0.424264 0.212132 +0.55 0.95 0.534539 0.26727 +0.55 0.95 0.673477 0.336739 +0.55 0.95 0.212132 0.424264 +0.55 0.95 0.26727 0.534539 +0.55 0.95 0.336739 0.673477 +0.65 0.95 0.3 0.3 +0.65 0.95 0.377976 0.377976 +0.65 0.95 0.47622 0.47622 +0.65 0.95 0.424264 0.212132 +0.65 0.95 0.534539 0.26727 +0.65 0.95 0.673477 0.336739 +0.65 0.95 0.212132 0.424264 +0.65 0.95 0.26727 0.534539 +0.65 0.95 0.336739 0.673477 +0.75 0.95 0.3 0.3 +0.75 0.95 0.377976 0.377976 +0.75 0.95 0.47622 0.47622 +0.75 0.95 0.424264 0.212132 +0.75 0.95 0.534539 0.26727 +0.75 0.95 0.673477 0.336739 +0.75 0.95 0.212132 0.424264 +0.75 0.95 0.26727 0.534539 +0.75 0.95 0.336739 0.673477 +0.85 0.95 0.3 0.3 +0.85 0.95 0.377976 0.377976 +0.85 0.95 0.47622 0.47622 +0.85 0.95 0.424264 0.212132 +0.85 0.95 0.534539 0.26727 +0.85 0.95 0.673477 0.336739 +0.85 0.95 0.212132 0.424264 +0.85 0.95 0.26727 0.534539 +0.85 0.95 0.336739 0.673477 +0.95 0.95 0.3 0.3 +0.95 0.95 0.377976 0.377976 +0.95 0.95 0.47622 0.47622 +0.95 0.95 0.424264 0.212132 +0.95 0.95 0.534539 0.26727 +0.95 0.95 0.673477 0.336739 +0.95 0.95 0.212132 0.424264 +0.95 0.95 0.267269 0.534539 +0.95 0.95 0.336739 0.673477 +0.1 0.1 0.6 0.6 +0.1 0.1 0.755953 0.755953 +0.1 0.1 0.952441 0.952441 +0.1 0.1 0.848528 0.424264 +0.1 0.1 1.06908 0.534539 +0.1 0.1 1.34695 0.673477 +0.1 0.1 0.424264 0.848528 +0.1 0.1 0.534539 1.06908 +0.1 0.1 0.673477 1.34695 +0.3 0.1 0.6 0.6 +0.3 0.1 0.755953 0.755953 +0.3 0.1 0.952441 0.952441 +0.3 0.1 
0.848528 0.424264 +0.3 0.1 1.06908 0.534539 +0.3 0.1 1.34695 0.673477 +0.3 0.1 0.424264 0.848528 +0.3 0.1 0.534539 1.06908 +0.3 0.1 0.673477 1.34695 +0.5 0.1 0.6 0.6 +0.5 0.1 0.755953 0.755953 +0.5 0.1 0.952441 0.952441 +0.5 0.1 0.848528 0.424264 +0.5 0.1 1.06908 0.534539 +0.5 0.1 1.34695 0.673477 +0.5 0.1 0.424264 0.848528 +0.5 0.1 0.534539 1.06908 +0.5 0.1 0.673477 1.34695 +0.7 0.1 0.6 0.6 +0.7 0.1 0.755953 0.755953 +0.7 0.1 0.952441 0.952441 +0.7 0.1 0.848528 0.424264 +0.7 0.1 1.06908 0.534539 +0.7 0.1 1.34695 0.673477 +0.7 0.1 0.424264 0.848528 +0.7 0.1 0.534539 1.06908 +0.7 0.1 0.673477 1.34695 +0.9 0.1 0.6 0.6 +0.9 0.1 0.755953 0.755953 +0.9 0.1 0.952441 0.952441 +0.9 0.1 0.848528 0.424264 +0.9 0.1 1.06908 0.534539 +0.9 0.1 1.34695 0.673477 +0.9 0.1 0.424264 0.848528 +0.9 0.1 0.534539 1.06908 +0.9 0.1 0.673477 1.34695 +0.1 0.3 0.6 0.6 +0.1 0.3 0.755953 0.755953 +0.1 0.3 0.952441 0.952441 +0.1 0.3 0.848528 0.424264 +0.1 0.3 1.06908 0.534539 +0.1 0.3 1.34695 0.673477 +0.1 0.3 0.424264 0.848528 +0.1 0.3 0.534539 1.06908 +0.1 0.3 0.673477 1.34695 +0.3 0.3 0.6 0.6 +0.3 0.3 0.755953 0.755953 +0.3 0.3 0.952441 0.952441 +0.3 0.3 0.848528 0.424264 +0.3 0.3 1.06908 0.534539 +0.3 0.3 1.34695 0.673477 +0.3 0.3 0.424264 0.848528 +0.3 0.3 0.534539 1.06908 +0.3 0.3 0.673477 1.34695 +0.5 0.3 0.6 0.6 +0.5 0.3 0.755953 0.755953 +0.5 0.3 0.952441 0.952441 +0.5 0.3 0.848528 0.424264 +0.5 0.3 1.06908 0.534539 +0.5 0.3 1.34695 0.673477 +0.5 0.3 0.424264 0.848528 +0.5 0.3 0.534539 1.06908 +0.5 0.3 0.673477 1.34695 +0.7 0.3 0.6 0.6 +0.7 0.3 0.755953 0.755953 +0.7 0.3 0.952441 0.952441 +0.7 0.3 0.848528 0.424264 +0.7 0.3 1.06908 0.534539 +0.7 0.3 1.34695 0.673477 +0.7 0.3 0.424264 0.848528 +0.7 0.3 0.534539 1.06908 +0.7 0.3 0.673477 1.34695 +0.9 0.3 0.6 0.6 +0.9 0.3 0.755953 0.755953 +0.9 0.3 0.952441 0.952441 +0.9 0.3 0.848528 0.424264 +0.9 0.3 1.06908 0.534539 +0.9 0.3 1.34695 0.673477 +0.9 0.3 0.424264 0.848528 +0.9 0.3 0.534539 1.06908 +0.9 0.3 0.673477 1.34695 +0.1 0.5 0.6 0.6 +0.1 0.5 0.755953 0.755953 +0.1 0.5 0.952441 0.952441 +0.1 0.5 0.848528 0.424264 +0.1 0.5 1.06908 0.534539 +0.1 0.5 1.34695 0.673477 +0.1 0.5 0.424264 0.848528 +0.1 0.5 0.534539 1.06908 +0.1 0.5 0.673477 1.34695 +0.3 0.5 0.6 0.6 +0.3 0.5 0.755953 0.755953 +0.3 0.5 0.952441 0.952441 +0.3 0.5 0.848528 0.424264 +0.3 0.5 1.06908 0.534539 +0.3 0.5 1.34695 0.673477 +0.3 0.5 0.424264 0.848528 +0.3 0.5 0.534539 1.06908 +0.3 0.5 0.673477 1.34695 +0.5 0.5 0.6 0.6 +0.5 0.5 0.755953 0.755953 +0.5 0.5 0.952441 0.952441 +0.5 0.5 0.848528 0.424264 +0.5 0.5 1.06908 0.534539 +0.5 0.5 1.34695 0.673477 +0.5 0.5 0.424264 0.848528 +0.5 0.5 0.534539 1.06908 +0.5 0.5 0.673477 1.34695 +0.7 0.5 0.6 0.6 +0.7 0.5 0.755953 0.755953 +0.7 0.5 0.952441 0.952441 +0.7 0.5 0.848528 0.424264 +0.7 0.5 1.06908 0.534539 +0.7 0.5 1.34695 0.673477 +0.7 0.5 0.424264 0.848528 +0.7 0.5 0.534539 1.06908 +0.7 0.5 0.673477 1.34695 +0.9 0.5 0.6 0.6 +0.9 0.5 0.755953 0.755953 +0.9 0.5 0.952441 0.952441 +0.9 0.5 0.848528 0.424264 +0.9 0.5 1.06908 0.534539 +0.9 0.5 1.34695 0.673477 +0.9 0.5 0.424264 0.848528 +0.9 0.5 0.534539 1.06908 +0.9 0.5 0.673477 1.34695 +0.1 0.7 0.6 0.6 +0.1 0.7 0.755953 0.755953 +0.1 0.7 0.952441 0.952441 +0.1 0.7 0.848528 0.424264 +0.1 0.7 1.06908 0.534539 +0.1 0.7 1.34695 0.673477 +0.1 0.7 0.424264 0.848528 +0.1 0.7 0.534539 1.06908 +0.1 0.7 0.673477 1.34695 +0.3 0.7 0.6 0.6 +0.3 0.7 0.755953 0.755953 +0.3 0.7 0.952441 0.952441 +0.3 0.7 0.848528 0.424264 +0.3 0.7 1.06908 0.534539 +0.3 0.7 1.34695 0.673477 +0.3 0.7 0.424264 0.848528 +0.3 0.7 0.534539 
1.06908 +0.3 0.7 0.673477 1.34695 +0.5 0.7 0.6 0.6 +0.5 0.7 0.755953 0.755953 +0.5 0.7 0.952441 0.952441 +0.5 0.7 0.848528 0.424264 +0.5 0.7 1.06908 0.534539 +0.5 0.7 1.34695 0.673477 +0.5 0.7 0.424264 0.848528 +0.5 0.7 0.534539 1.06908 +0.5 0.7 0.673477 1.34695 +0.7 0.7 0.6 0.6 +0.7 0.7 0.755953 0.755953 +0.7 0.7 0.952441 0.952441 +0.7 0.7 0.848528 0.424264 +0.7 0.7 1.06908 0.534539 +0.7 0.7 1.34695 0.673477 +0.7 0.7 0.424264 0.848528 +0.7 0.7 0.534539 1.06908 +0.7 0.7 0.673477 1.34695 +0.9 0.7 0.6 0.6 +0.9 0.7 0.755953 0.755953 +0.9 0.7 0.952441 0.952441 +0.9 0.7 0.848528 0.424264 +0.9 0.7 1.06908 0.534539 +0.9 0.7 1.34695 0.673477 +0.9 0.7 0.424264 0.848528 +0.9 0.7 0.534539 1.06908 +0.9 0.7 0.673477 1.34695 +0.1 0.9 0.6 0.6 +0.1 0.9 0.755953 0.755953 +0.1 0.9 0.952441 0.952441 +0.1 0.9 0.848528 0.424264 +0.1 0.9 1.06908 0.534539 +0.1 0.9 1.34695 0.673477 +0.1 0.9 0.424264 0.848528 +0.1 0.9 0.534539 1.06908 +0.1 0.9 0.673477 1.34695 +0.3 0.9 0.6 0.6 +0.3 0.9 0.755953 0.755953 +0.3 0.9 0.952441 0.952441 +0.3 0.9 0.848528 0.424264 +0.3 0.9 1.06908 0.534539 +0.3 0.9 1.34695 0.673477 +0.3 0.9 0.424264 0.848528 +0.3 0.9 0.534539 1.06908 +0.3 0.9 0.673477 1.34695 +0.5 0.9 0.6 0.6 +0.5 0.9 0.755953 0.755953 +0.5 0.9 0.952441 0.952441 +0.5 0.9 0.848528 0.424264 +0.5 0.9 1.06908 0.534539 +0.5 0.9 1.34695 0.673477 +0.5 0.9 0.424264 0.848528 +0.5 0.9 0.534539 1.06908 +0.5 0.9 0.673477 1.34695 +0.7 0.9 0.6 0.6 +0.7 0.9 0.755953 0.755953 +0.7 0.9 0.952441 0.952441 +0.7 0.9 0.848528 0.424264 +0.7 0.9 1.06908 0.534539 +0.7 0.9 1.34695 0.673477 +0.7 0.9 0.424264 0.848528 +0.7 0.9 0.534539 1.06908 +0.7 0.9 0.673477 1.34695 +0.9 0.9 0.6 0.6 +0.9 0.9 0.755953 0.755953 +0.9 0.9 0.952441 0.952441 +0.9 0.9 0.848528 0.424264 +0.9 0.9 1.06908 0.534539 +0.9 0.9 1.34695 0.673477 +0.9 0.9 0.424264 0.848528 +0.9 0.9 0.534539 1.06908 +0.9 0.9 0.673477 1.34695 diff --git a/mediapipe/calculators/tflite/tflite_model_calculator.cc b/mediapipe/calculators/tflite/tflite_model_calculator.cc index ca28910e5..d118e878c 100644 --- a/mediapipe/calculators/tflite/tflite_model_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_model_calculator.cc @@ -19,6 +19,7 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/port/ret_check.h" +#include "tensorflow/lite/allocation.h" #include "tensorflow/lite/model.h" namespace mediapipe { @@ -32,6 +33,8 @@ namespace mediapipe { // it to the graph as input side packet or you can use some of // calculators like LocalFileContentsCalculator to get model // blob and use it as input here. +// MODEL_FD - Tflite model file descriptor std::tuple +// containing (fd, offset, size). // // Output side packets: // MODEL - TfLite model. 
//     (std::unique_ptr<tflite::FlatBufferModel,
//     std::function<void(tflite::FlatBufferModel*)>>)

   static absl::Status GetContract(CalculatorContract* cc) {
-    cc->InputSidePackets().Tag("MODEL_BLOB").Set<std::string>();
+    if (cc->InputSidePackets().HasTag("MODEL_BLOB")) {
+      cc->InputSidePackets().Tag("MODEL_BLOB").Set<std::string>();
+    }
+
+    if (cc->InputSidePackets().HasTag("MODEL_FD")) {
+      cc->InputSidePackets()
+          .Tag("MODEL_FD")
+          .Set<std::tuple<int, size_t, size_t>>();
+    }
+
     cc->OutputSidePackets().Tag("MODEL").Set<TfLiteModelPtr>();
     return absl::OkStatus();
   }

   absl::Status Open(CalculatorContext* cc) override {
-    const Packet& model_packet = cc->InputSidePackets().Tag("MODEL_BLOB");
-    const std::string& model_blob = model_packet.Get<std::string>();
-    std::unique_ptr<tflite::FlatBufferModel> model =
-        tflite::FlatBufferModel::BuildFromBuffer(model_blob.data(),
-                                                 model_blob.size());
+    Packet model_packet;
+    std::unique_ptr<tflite::FlatBufferModel> model;
+
+    if (cc->InputSidePackets().HasTag("MODEL_BLOB")) {
+      model_packet = cc->InputSidePackets().Tag("MODEL_BLOB");
+      const std::string& model_blob = model_packet.Get<std::string>();
+      model = tflite::FlatBufferModel::BuildFromBuffer(model_blob.data(),
+                                                       model_blob.size());
+    }
+
+    if (cc->InputSidePackets().HasTag("MODEL_FD")) {
+      model_packet = cc->InputSidePackets().Tag("MODEL_FD");
+      const auto& model_fd =
+          model_packet.Get<std::tuple<int, size_t, size_t>>();
+      auto model_allocation = std::make_unique<tflite::MMAPAllocation>(
+          std::get<0>(model_fd), std::get<1>(model_fd), std::get<2>(model_fd),
+          tflite::DefaultErrorReporter());
+      model = tflite::FlatBufferModel::BuildFromAllocation(
+          std::move(model_allocation), tflite::DefaultErrorReporter());
+    }
+
     RET_CHECK(model) << "Failed to load TfLite model from blob.";

     cc->OutputSidePackets().Tag("MODEL").Set(
diff --git a/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/83.5_c_Ipad_2x.png b/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/83.5_c_Ipad_2x.png
new file mode 100644
index 0000000000000000000000000000000000000000..e0a66942bdb65de1cec992e013e16e98762fc94e
Binary files /dev/null and b/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/83.5_c_Ipad_2x.png differ
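A minimal host-side sketch of feeding the new MODEL_FD side packet described above: it passes (file descriptor, offset, size) as a std::tuple<int, size_t, size_t> so the calculator can mmap the model via tflite::MMAPAllocation instead of copying a blob. The input side packet name "model_fd", the file path, and the zero offset are illustrative assumptions, not part of this change.

#include <fcntl.h>
#include <sys/stat.h>

#include <tuple>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"

// Opens a .tflite file and starts the graph with a MODEL_FD-style side packet.
absl::Status StartRunWithModelFd(mediapipe::CalculatorGraph& graph) {
  const int fd = open("/data/local/tmp/model.tflite", O_RDONLY);  // illustrative path
  RET_CHECK_GE(fd, 0) << "failed to open model file";
  struct stat stat_buf;
  RET_CHECK_EQ(fstat(fd, &stat_buf), 0) << "failed to stat model file";
  // Offset 0 and the full file size: the file holds nothing but the model.
  auto model_fd = std::make_tuple(fd, static_cast<size_t>(0),
                                  static_cast<size_t>(stat_buf.st_size));
  return graph.StartRun(
      {{"model_fd",
        mediapipe::MakePacket<std::tuple<int, size_t, size_t>>(model_fd)}});
}

Keeping this descriptor path separate from MODEL_BLOB lets a caller map a model straight from a file descriptor (for example an Android asset fd) without materializing it in a std::string.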
diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h
--- a/mediapipe/framework/api2/builder.h
+++ b/mediapipe/framework/api2/builder.h
@@ ... @@
 #include <string>

+#include "absl/container/btree_map.h"
 #include "absl/container/flat_hash_map.h"
+#include "absl/strings/string_view.h"
 #include "mediapipe/framework/api2/const_str.h"
 #include "mediapipe/framework/api2/contract.h"
 #include "mediapipe/framework/api2/node.h"
@@ -46,7 +48,7 @@ struct TagIndexLocation {
 template <typename T>
 class TagIndexMap {
  public:
-  std::vector<std::unique_ptr<T>>& operator[](const std::string& tag) {
+  std::vector<std::unique_ptr<T>>& operator[](absl::string_view tag) {
     return map_[tag];
   }

@@ -72,7 +74,7 @@ class TagIndexMap {
   // Note: entries are held by a unique_ptr to ensure pointers remain valid.
   // Should use absl::flat_hash_map but ordering keys for now.
-  std::map<std::string, std::vector<std::unique_ptr<T>>> map_;
+  absl::btree_map<std::string, std::vector<std::unique_ptr<T>>> map_;
 };

 class Graph;
@@ -169,6 +171,16 @@ class SourceImpl {
     return AddTarget(dest);
   }

+  template <typename U>
+  struct AllowCast
+      : public std::integral_constant<bool, std::is_same_v<T, AnyType> &&
+                                                !std::is_same_v<T, U>> {};
+
+  template <typename U, std::enable_if_t<AllowCast<U>{}, int> = 0>
+  SourceImpl<IsSide, U> Cast() {
+    return SourceImpl<IsSide, U>(base_);
+  }
+
  private:
   // Never null.
   SourceBase* base_;
@@ -212,19 +224,19 @@ class NodeBase {
   // of its entries by index. However, for nodes without visible contracts we
   // can't know whether a tag is indexable or not, so we would need the
   // multi-port to also be usable as a port directly (representing index 0).
- MultiSource<> Out(const std::string& tag) { + MultiSource<> Out(absl::string_view tag) { return MultiSource<>(&out_streams_[tag]); } - MultiDestination<> In(const std::string& tag) { + MultiDestination<> In(absl::string_view tag) { return MultiDestination<>(&in_streams_[tag]); } - MultiSideSource<> SideOut(const std::string& tag) { + MultiSideSource<> SideOut(absl::string_view tag) { return MultiSideSource<>(&out_sides_[tag]); } - MultiSideDestination<> SideIn(const std::string& tag) { + MultiSideDestination<> SideIn(absl::string_view tag) { return MultiSideDestination<>(&in_sides_[tag]); } @@ -359,11 +371,11 @@ class PacketGenerator { public: PacketGenerator(std::string type) : type_(std::move(type)) {} - MultiSideSource<> SideOut(const std::string& tag) { + MultiSideSource<> SideOut(absl::string_view tag) { return MultiSideSource<>(&out_sides_[tag]); } - MultiSideDestination<> SideIn(const std::string& tag) { + MultiSideDestination<> SideIn(absl::string_view tag) { return MultiSideDestination<>(&in_sides_[tag]); } @@ -452,19 +464,19 @@ class Graph { } // Graph ports, non-typed. - MultiSource<> In(const std::string& graph_input) { + MultiSource<> In(absl::string_view graph_input) { return graph_boundary_.Out(graph_input); } - MultiDestination<> Out(const std::string& graph_output) { + MultiDestination<> Out(absl::string_view graph_output) { return graph_boundary_.In(graph_output); } - MultiSideSource<> SideIn(const std::string& graph_input) { + MultiSideSource<> SideIn(absl::string_view graph_input) { return graph_boundary_.SideOut(graph_input); } - MultiSideDestination<> SideOut(const std::string& graph_output) { + MultiSideDestination<> SideOut(absl::string_view graph_output) { return graph_boundary_.SideIn(graph_output); } diff --git a/mediapipe/framework/api2/builder_test.cc b/mediapipe/framework/api2/builder_test.cc index 28e42da97..3244e092d 100644 --- a/mediapipe/framework/api2/builder_test.cc +++ b/mediapipe/framework/api2/builder_test.cc @@ -2,6 +2,7 @@ #include +#include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/api2/packet.h" @@ -296,6 +297,32 @@ TEST(BuilderTest, EmptyTag) { EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); } +TEST(BuilderTest, StringLikeTags) { + const char kA[] = "A"; + const std::string kB = "B"; + constexpr absl::string_view kC = "C"; + + builder::Graph graph; + auto& foo = graph.AddNode("Foo"); + graph.In(kA).SetName("a") >> foo.In(kA); + graph.In(kB).SetName("b") >> foo.In(kB); + foo.Out(kC).SetName("c") >> graph.Out(kC); + + CalculatorGraphConfig expected = + mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: "A:a" + input_stream: "B:b" + output_stream: "C:c" + node { + calculator: "Foo" + input_stream: "A:a" + input_stream: "B:b" + output_stream: "C:c" + } + )pb"); + EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); +} + TEST(BuilderTest, GraphIndexes) { builder::Graph graph; auto& foo = graph.AddNode("Foo"); @@ -326,57 +353,91 @@ TEST(BuilderTest, GraphIndexes) { class AnyAndSameTypeCalculator : public NodeIntf { public: - static constexpr Input kAnyTypeInput{"INPUT"}; - static constexpr Output kAnyTypeOutput{"ANY_OUTPUT"}; - static constexpr Output> kSameTypeOutput{ + static constexpr Input::Optional kAnyTypeInput{"INPUT"}; + static constexpr Output::Optional kAnyTypeOutput{"ANY_OUTPUT"}; + static constexpr Output>::Optional kSameTypeOutput{ "SAME_OUTPUT"}; + static constexpr Output> kRecursiveSameTypeOutput{ + "RECURSIVE_SAME_OUTPUT"}; - 
static constexpr Input kIntInput{"INT_INPUT"}; + static constexpr Input::Optional kIntInput{"INT_INPUT"}; // `SameType` usage for this output is only for testing purposes. // // `SameType` is designed to work with inputs of `AnyType` and, normally, you // would not use `Output>` in a real calculator. You // should write `Output` instead, since the type is known. - static constexpr Output> kSameIntOutput{ + static constexpr Output>::Optional kSameIntOutput{ "SAME_INT_OUTPUT"}; + static constexpr Output> kRecursiveSameIntOutput{ + "RECURSIVE_SAME_INT_OUTPUT"}; - MEDIAPIPE_NODE_INTERFACE(AnyTypeCalculator, kAnyTypeInput, kAnyTypeOutput, - kSameTypeOutput); + MEDIAPIPE_NODE_INTERFACE(AnyAndSameTypeCalculator, kAnyTypeInput, + kAnyTypeOutput, kSameTypeOutput); }; TEST(BuilderTest, AnyAndSameTypeHandledProperly) { builder::Graph graph; - builder::Source any_input = - graph[Input{"GRAPH_ANY_INPUT"}]; + builder::Source any_input = graph[Input{"GRAPH_ANY_INPUT"}]; builder::Source int_input = graph[Input{"GRAPH_INT_INPUT"}]; auto& node = graph.AddNode("AnyAndSameTypeCalculator"); any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; int_input >> node[AnyAndSameTypeCalculator::kIntInput]; - builder::Source any_type_output = + builder::Source any_type_output = node[AnyAndSameTypeCalculator::kAnyTypeOutput]; any_type_output.SetName("any_type_output"); - builder::Source same_type_output = + builder::Source same_type_output = node[AnyAndSameTypeCalculator::kSameTypeOutput]; same_type_output.SetName("same_type_output"); - builder::Source same_int_output = + builder::Source recursive_same_type_output = + node[AnyAndSameTypeCalculator::kRecursiveSameTypeOutput]; + recursive_same_type_output.SetName("recursive_same_type_output"); + builder::Source same_int_output = node[AnyAndSameTypeCalculator::kSameIntOutput]; same_int_output.SetName("same_int_output"); + builder::Source recursive_same_int_type_output = + node[AnyAndSameTypeCalculator::kRecursiveSameIntOutput]; + recursive_same_int_type_output.SetName("recursive_same_int_type_output"); + + CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie< + CalculatorGraphConfig>(R"pb( + node { + calculator: "AnyAndSameTypeCalculator" + input_stream: "INPUT:__stream_0" + input_stream: "INT_INPUT:__stream_1" + output_stream: "ANY_OUTPUT:any_type_output" + output_stream: "RECURSIVE_SAME_INT_OUTPUT:recursive_same_int_type_output" + output_stream: "RECURSIVE_SAME_OUTPUT:recursive_same_type_output" + output_stream: "SAME_INT_OUTPUT:same_int_output" + output_stream: "SAME_OUTPUT:same_type_output" + } + input_stream: "GRAPH_ANY_INPUT:__stream_0" + input_stream: "GRAPH_INT_INPUT:__stream_1" + )pb"); + EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); +} + +TEST(BuilderTest, AnyTypeCanBeCast) { + builder::Graph graph; + builder::Source any_input = + graph.In("GRAPH_ANY_INPUT").Cast(); + + auto& node = graph.AddNode("AnyAndSameTypeCalculator"); + any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; + builder::Source any_type_output = + node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast(); + any_type_output.SetName("any_type_output"); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( node { calculator: "AnyAndSameTypeCalculator" input_stream: "INPUT:__stream_0" - input_stream: "INT_INPUT:__stream_1" output_stream: "ANY_OUTPUT:any_type_output" - output_stream: "SAME_INT_OUTPUT:same_int_output" - output_stream: "SAME_OUTPUT:same_type_output" } input_stream: "GRAPH_ANY_INPUT:__stream_0" - input_stream: "GRAPH_INT_INPUT:__stream_1" 
)pb"); EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); } diff --git a/mediapipe/framework/api2/packet.h b/mediapipe/framework/api2/packet.h index 4ff726da0..7933575d3 100644 --- a/mediapipe/framework/api2/packet.h +++ b/mediapipe/framework/api2/packet.h @@ -27,9 +27,7 @@ using HolderBase = mediapipe::packet_internal::HolderBase; template class Packet; -struct DynamicType {}; - -struct AnyType : public DynamicType { +struct AnyType { AnyType() = delete; }; diff --git a/mediapipe/framework/api2/port.h b/mediapipe/framework/api2/port.h index a408831bc..e63d3651e 100644 --- a/mediapipe/framework/api2/port.h +++ b/mediapipe/framework/api2/port.h @@ -73,14 +73,12 @@ class SideOutputBase : public PortBase { }; struct NoneType { - private: NoneType() = delete; }; -template -class SameType : public DynamicType { - public: - static constexpr const decltype(P)& kPort = P; +template +struct SameType { + static constexpr const decltype(kP)& kPort = kP; }; class PacketTypeAccess; @@ -137,21 +135,28 @@ struct IsOneOf : std::false_type {}; template struct IsOneOf> : std::true_type {}; -template {} && !IsOneOf{}, - int>::type = 0> +template +struct IsSameType : std::false_type {}; + +template +struct IsSameType> : std::true_type {}; + +template {} && + !IsOneOf{} && !IsSameType{}, + int>::type = 0> inline void SetType(CalculatorContract* cc, PacketType& pt) { pt.Set(); } -template {}, - int>::type = 0> +template {}, int>::type = 0> inline void SetType(CalculatorContract* cc, PacketType& pt) { pt.SetSameAs(&internal::GetCollection(cc, T::kPort).Tag(T::kPort.Tag())); } -template <> -inline void SetType(CalculatorContract* cc, PacketType& pt) { +template {}, int>::type = 0> +inline void SetType(CalculatorContract* cc, PacketType& pt) { pt.SetAny(); } @@ -289,15 +294,15 @@ struct SideBase { }; // TODO: maybe return a PacketBase instead of a Packet? -template +template struct ActualPayloadType { using type = T; }; template -struct ActualPayloadType< - T, std::enable_if_t{}, void>> { - using type = internal::Generic; +struct ActualPayloadType{}, void>> { + using type = typename ActualPayloadType< + typename std::decay_t::value_t>::type; }; } // namespace internal diff --git a/mediapipe/framework/formats/image_frame_opencv.cc b/mediapipe/framework/formats/image_frame_opencv.cc index ada28ae35..940e18263 100644 --- a/mediapipe/framework/formats/image_frame_opencv.cc +++ b/mediapipe/framework/formats/image_frame_opencv.cc @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2022 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -85,14 +85,8 @@ cv::Mat MatView(const ImageFrame* image) { const size_t steps[] = {static_cast(image->WidthStep()), static_cast(image->ByteDepth())}; // Use ImageFrame to initialize in-place. ImageFrame still owns memory. - if (steps[0] == sizes[1] * image->NumberOfChannels() * image->ByteDepth()) { - // Contiguous memory optimization. See b/78570764 - return cv::Mat(dims, sizes, type, const_cast(image->PixelData())); - } else { - // Custom width step. 
- return cv::Mat(dims, sizes, type, const_cast(image->PixelData()), - steps); - } + return cv::Mat(dims, sizes, type, const_cast(image->PixelData()), + steps); } } // namespace formats diff --git a/mediapipe/framework/formats/image_frame_opencv.h b/mediapipe/framework/formats/image_frame_opencv.h index cfeed1c6b..d197d9a55 100644 --- a/mediapipe/framework/formats/image_frame_opencv.h +++ b/mediapipe/framework/formats/image_frame_opencv.h @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2022 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/mediapipe/framework/formats/image_frame_opencv_test.cc b/mediapipe/framework/formats/image_frame_opencv_test.cc index 6a6e0a65c..f75915d06 100644 --- a/mediapipe/framework/formats/image_frame_opencv_test.cc +++ b/mediapipe/framework/formats/image_frame_opencv_test.cc @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2022 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -21,7 +21,6 @@ #include "mediapipe/framework/port/logging.h" namespace mediapipe { - namespace { // Set image_frame to a constant per-channel pix_value. @@ -50,8 +49,8 @@ TEST(ImageFrameOpencvTest, ConvertToMat) { ImageFrame frame2(ImageFormat::GRAY8, i_width, i_height); // Check adding constant images. - const uint8 frame1_val = 12; - const uint8 frame2_val = 34; + const uint8_t frame1_val = 12; + const uint8_t frame2_val = 34; SetToColor(&frame1_val, &frame1); SetToColor(&frame2_val, &frame2); // Get Mat wrapper around ImageFrame memory (zero copy). @@ -77,6 +76,37 @@ TEST(ImageFrameOpencvTest, ConvertToMat) { EXPECT_EQ(max_loc.y, i_height - 6); } +TEST(ImageFrameOpencvTest, ConvertToIpl) { + const int i_width = 123, i_height = 45; + ImageFrame frame1(ImageFormat::GRAY8, i_width, i_height); + ImageFrame frame2(ImageFormat::GRAY8, i_width, i_height); + + // Check adding constant images. + const uint8_t frame1_val = 12; + const uint8_t frame2_val = 34; + SetToColor(&frame1_val, &frame1); + SetToColor(&frame2_val, &frame2); + const cv::Mat frame1_mat = formats::MatView(&frame1); + const cv::Mat frame2_mat = formats::MatView(&frame2); + const cv::Mat frame_sum = frame1_mat + frame2_mat; + const auto frame_avg = static_cast(cv::mean(frame_sum).val[0]); + EXPECT_EQ(frame_avg, frame1_val + frame2_val); + + // Check setting min/max pixels. + uint8* frame1_ptr = frame1.MutablePixelData(); + frame1_ptr[(i_width - 5) + (i_height - 5) * frame1.WidthStep()] = 1; + frame1_ptr[(i_width - 6) + (i_height - 6) * frame1.WidthStep()] = 100; + double min, max; + cv::Point min_loc, max_loc; + cv::minMaxLoc(frame1_mat, &min, &max, &min_loc, &max_loc); + EXPECT_EQ(min, 1); + EXPECT_EQ(min_loc.x, i_width - 5); + EXPECT_EQ(min_loc.y, i_height - 5); + EXPECT_EQ(max, 100); + EXPECT_EQ(max_loc.x, i_width - 6); + EXPECT_EQ(max_loc.y, i_height - 6); +} + TEST(ImageFrameOpencvTest, ImageFormats) { const int i_width = 123, i_height = 45; ImageFrame frame_g8(ImageFormat::GRAY8, i_width, i_height); diff --git a/mediapipe/framework/formats/image_opencv.cc b/mediapipe/framework/formats/image_opencv.cc index 3debbe421..9ccaa632b 100644 --- a/mediapipe/framework/formats/image_opencv.cc +++ b/mediapipe/framework/formats/image_opencv.cc @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2022 The MediaPipe Authors. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/mediapipe/framework/formats/image_opencv.h b/mediapipe/framework/formats/image_opencv.h index b1bc4954d..488b87f43 100644 --- a/mediapipe/framework/formats/image_opencv.h +++ b/mediapipe/framework/formats/image_opencv.h @@ -1,4 +1,4 @@ -// Copyright 2019-2020 The MediaPipe Authors. +// Copyright 2022 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/mediapipe/framework/formats/tensor.cc b/mediapipe/framework/formats/tensor.cc index eb06d14f0..ef0cddea4 100644 --- a/mediapipe/framework/formats/tensor.cc +++ b/mediapipe/framework/formats/tensor.cc @@ -37,26 +37,21 @@ namespace mediapipe { bool IsPowerOfTwo(int v) { return (v & (v - 1)) == 0; } int BhwcBatchFromShape(const Tensor::Shape& shape) { - LOG_IF(FATAL, shape.dims.empty()) - << "Tensor::Shape must be non-empty to retrieve a named dimension"; + if (shape.dims.empty()) { + return 1; + } return shape.dims[0]; } int BhwcHeightFromShape(const Tensor::Shape& shape) { - LOG_IF(FATAL, shape.dims.empty()) - << "Tensor::Shape must be non-empty to retrieve a named dimension"; return shape.dims.size() < 4 ? 1 : shape.dims[shape.dims.size() - 3]; } int BhwcWidthFromShape(const Tensor::Shape& shape) { - LOG_IF(FATAL, shape.dims.empty()) - << "Tensor::Shape must be non-empty to retrieve a named dimension"; return shape.dims.size() < 3 ? 1 : shape.dims[shape.dims.size() - 2]; } int BhwcDepthFromShape(const Tensor::Shape& shape) { - LOG_IF(FATAL, shape.dims.empty()) - << "Tensor::Shape must be non-empty to retrieve a named dimension"; return shape.dims.size() < 2 ? 1 : shape.dims[shape.dims.size() - 1]; } @@ -424,14 +419,36 @@ Tensor::Tensor(ElementType element_type, const Shape& shape, #if MEDIAPIPE_METAL_ENABLED void Tensor::Invalidate() { - absl::MutexLock lock(&view_mutex_); - // If memory is allocated and not owned by the metal buffer. - // TODO: Re-design cpu buffer memory management. - if (cpu_buffer_ && !metal_buffer_) { - DeallocateVirtualMemory(cpu_buffer_, AlignToPageSize(bytes())); +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 + GLuint cleanup_gl_tex = GL_INVALID_INDEX; + GLuint cleanup_gl_fb = GL_INVALID_INDEX; +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 + { + absl::MutexLock lock(&view_mutex_); + // If memory is allocated and not owned by the metal buffer. + // TODO: Re-design cpu buffer memory management. + if (cpu_buffer_ && !metal_buffer_) { + DeallocateVirtualMemory(cpu_buffer_, AlignToPageSize(bytes())); + } + metal_buffer_ = nil; + cpu_buffer_ = nullptr; +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 + // Don't need to wait for the resource to be deleted bacause if will be + // released on last reference deletion inside the OpenGL driver. + std::swap(cleanup_gl_tex, opengl_texture2d_); + std::swap(cleanup_gl_fb, frame_buffer_); +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 } - metal_buffer_ = nil; - cpu_buffer_ = nullptr; +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 + // Do not hold the view mutex while invoking GlContext::RunWithoutWaiting, + // since that method may acquire the context's own lock. 
+ if (cleanup_gl_tex != GL_INVALID_INDEX || cleanup_gl_fb != GL_INVALID_INDEX) { + gl_context_->RunWithoutWaiting([cleanup_gl_tex, cleanup_gl_fb]() { + glDeleteTextures(1, &cleanup_gl_tex); + glDeleteFramebuffers(1, &cleanup_gl_fb); + }); + } +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 } #else diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index b9a7b9fcd..4353947d8 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -95,9 +96,8 @@ class Tensor { Shape(std::initializer_list dimensions) : dims(dimensions) {} Shape(const std::vector& dimensions) : dims(dimensions) {} int num_elements() const { - int res = dims.empty() ? 0 : 1; - std::for_each(dims.begin(), dims.end(), [&res](int i) { res *= i; }); - return res; + return std::accumulate(dims.begin(), dims.end(), 1, + std::multiplies()); } std::vector dims; }; diff --git a/mediapipe/framework/mediapipe_cc_test.bzl b/mediapipe/framework/mediapipe_cc_test.bzl index 6ccbebb0c..fe0d44e0c 100644 --- a/mediapipe/framework/mediapipe_cc_test.bzl +++ b/mediapipe/framework/mediapipe_cc_test.bzl @@ -15,7 +15,7 @@ def mediapipe_cc_test( platforms = ["linux", "android", "ios", "wasm"], exclude_platforms = None, # ios_unit_test arguments - ios_minimum_os_version = "9.0", + ios_minimum_os_version = "11.0", # android_cc_test arguments open_gl_driver = None, emulator_mini_boot = True, diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD index 6446eb3e5..237aa825f 100644 --- a/mediapipe/framework/profiler/BUILD +++ b/mediapipe/framework/profiler/BUILD @@ -108,6 +108,7 @@ cc_library( ":sharded_map", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_profile_cc_proto", + "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:integral_types", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", diff --git a/mediapipe/framework/profiler/graph_profiler.cc b/mediapipe/framework/profiler/graph_profiler.cc index 0503f868f..f14acfc78 100644 --- a/mediapipe/framework/profiler/graph_profiler.cc +++ b/mediapipe/framework/profiler/graph_profiler.cc @@ -22,6 +22,7 @@ #include "absl/time/time.h" #include "mediapipe/framework/port/advanced_proto_lite_inc.h" #include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/proto_ns.h" #include "mediapipe/framework/port/re2.h" @@ -244,7 +245,16 @@ absl::Status GraphProfiler::Start(mediapipe::Executor* executor) { executor != nullptr) { // Inform the user via logging the path to the trace logs. ASSIGN_OR_RETURN(std::string trace_log_path, GetTraceLogPath()); - LOG(INFO) << "trace_log_path: " << trace_log_path; + // Check that we can actually write to it. 
+ auto status = + file::SetContents(absl::StrCat(trace_log_path, "trace_writing_check"), + "can write trace logs to this location"); + if (status.ok()) { + LOG(INFO) << "trace_log_path: " << trace_log_path; + } else { + LOG(ERROR) << "cannot write to trace_log_path: " << trace_log_path << ": " + << status; + } is_running_ = true; executor->Schedule([this] { diff --git a/mediapipe/framework/testdata/perfetto_minimal.pbtxt b/mediapipe/framework/testdata/perfetto_minimal.pbtxt index bdbc3c8e1..b5a563723 100644 --- a/mediapipe/framework/testdata/perfetto_minimal.pbtxt +++ b/mediapipe/framework/testdata/perfetto_minimal.pbtxt @@ -5,7 +5,7 @@ buffers: { size_kb: 150000 - fill_policy: DISCARD + fill_policy: RING_BUFFER } data_sources: { @@ -14,19 +14,21 @@ data_sources: { } } data_sources: { - config { - name: "linux.ftrace" - ftrace_config { - # Scheduling information & process tracking. Useful for: - # - what is happening on each CPU at each moment - ftrace_events: "power/cpu_frequency" - ftrace_events: "power/cpu_idle" - ftrace_events: "sched/sched_switch" - compact_sched { - enabled: true - } - } + config { + name: "linux.ftrace" + ftrace_config { + # Scheduling information & process tracking. Useful for: + # - what is happening on each CPU at each moment + ftrace_events: "power/cpu_frequency" + ftrace_events: "power/cpu_idle" + ftrace_events: "sched/sched_switch" + compact_sched { + enabled: true + } } + } } write_into_file: true file_write_period_ms: 500 +# b/243571696 Added to remove Perfetto timeouts when running benchmarks remotely. +duration_ms: 60000 diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index de35f4fd6..106738a49 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -821,6 +821,19 @@ cc_library( alwayslink = 1, ) +mediapipe_cc_test( + name = "switch_demux_calculator_test", + srcs = ["switch_demux_calculator_test.cc"], + deps = [ + ":container_util", + ":switch_demux_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:logging", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "switch_mux_calculator", srcs = ["switch_mux_calculator.cc"], diff --git a/mediapipe/framework/tool/switch_demux_calculator.cc b/mediapipe/framework/tool/switch_demux_calculator.cc index c4352c871..b9ba2a0fb 100644 --- a/mediapipe/framework/tool/switch_demux_calculator.cc +++ b/mediapipe/framework/tool/switch_demux_calculator.cc @@ -129,12 +129,12 @@ absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) { // Relay side packets to all channels. // Note: This is necessary because Calculator::Open only proceeds when every // anticipated side-packet arrives. 
- int channel_count = tool::ChannelCount(cc->OutputSidePackets().TagMap()); + int side_channel_count = tool::ChannelCount(cc->OutputSidePackets().TagMap()); for (const std::string& tag : ChannelTags(cc->OutputSidePackets().TagMap())) { int num_entries = cc->InputSidePackets().NumEntries(tag); for (int index = 0; index < num_entries; ++index) { Packet input = cc->InputSidePackets().Get(tag, index); - for (int channel = 0; channel < channel_count; ++channel) { + for (int channel = 0; channel < side_channel_count; ++channel) { std::string output_tag = tool::ChannelTag(tag, channel); auto output_id = cc->OutputSidePackets().GetId(output_tag, index); if (output_id.IsValid()) { @@ -143,6 +143,23 @@ absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) { } } } + + // Relay headers to all channels. + int output_channel_count = tool::ChannelCount(cc->Outputs().TagMap()); + for (const std::string& tag : ChannelTags(cc->Outputs().TagMap())) { + int num_entries = cc->Inputs().NumEntries(tag); + for (int index = 0; index < num_entries; ++index) { + auto& input = cc->Inputs().Get(tag, index); + if (input.Header().IsEmpty()) continue; + for (int channel = 0; channel < output_channel_count; ++channel) { + std::string output_tag = tool::ChannelTag(tag, channel); + auto output_id = cc->Outputs().GetId(output_tag, index); + if (output_id.IsValid()) { + cc->Outputs().Get(output_tag, index).SetHeader(input.Header()); + } + } + } + } return absl::OkStatus(); } diff --git a/mediapipe/framework/tool/switch_demux_calculator_test.cc b/mediapipe/framework/tool/switch_demux_calculator_test.cc new file mode 100644 index 000000000..acb1b1702 --- /dev/null +++ b/mediapipe/framework/tool/switch_demux_calculator_test.cc @@ -0,0 +1,135 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/strings/str_cat.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/tool/container_util.h" + +namespace mediapipe { +namespace { + +// Returns a CalculatorGraph to run a single calculator. +CalculatorGraph BuildCalculatorGraph(CalculatorGraphConfig::Node node_config) { + CalculatorGraphConfig config; + *config.add_node() = node_config; + *config.mutable_input_stream() = node_config.input_stream(); + *config.mutable_output_stream() = node_config.output_stream(); + *config.mutable_input_side_packet() = node_config.input_side_packet(); + *config.mutable_output_side_packet() = node_config.output_side_packet(); + return CalculatorGraph(config); +} + +// Creates a string packet. +Packet pack(std::string data, int timestamp) { + return MakePacket(data).At(Timestamp(timestamp)); +} + +// Creates an int packet. 
+Packet pack(int data, int timestamp) { + return MakePacket(data).At(Timestamp(timestamp)); +} + +// Tests showing packet channel synchronization through SwitchDemuxCalculator. +class SwitchDemuxCalculatorTest : public ::testing::Test { + protected: + SwitchDemuxCalculatorTest() {} + ~SwitchDemuxCalculatorTest() override {} + void SetUp() override {} + void TearDown() override {} + + // Defines a SwitchDemuxCalculator CalculatorGraphConfig::Node. + CalculatorGraphConfig::Node BuildNodeConfig() { + CalculatorGraphConfig::Node result; + *result.mutable_calculator() = "SwitchDemuxCalculator"; + *result.add_input_stream() = "SELECT:select"; + for (int c = 0; c < 2; ++c) { + *result.add_output_stream() = + absl::StrCat(tool::ChannelTag("FRAME", c), ":frame_", c); + *result.add_output_stream() = + absl::StrCat(tool::ChannelTag("MASK", c), ":mask_", c); + } + *result.add_input_stream() = "FRAME:frame"; + *result.add_input_stream() = "MASK:mask"; + return result; + } +}; + +// Shows the SwitchMuxCalculator is available. +TEST_F(SwitchDemuxCalculatorTest, IsRegistered) { + EXPECT_TRUE(CalculatorBaseRegistry::IsRegistered("SwitchDemuxCalculator")); +} + +TEST_F(SwitchDemuxCalculatorTest, BasicDataFlow) { + CalculatorGraphConfig::Node node_config = BuildNodeConfig(); + CalculatorGraph graph = BuildCalculatorGraph(node_config); + std::vector output_frames0; + EXPECT_TRUE(graph + .ObserveOutputStream("frame_0", + [&](const Packet& p) { + output_frames0.push_back(p); + return absl::OkStatus(); + }) + .ok()); + std::vector output_frames1; + EXPECT_TRUE(graph + .ObserveOutputStream("frame_1", + [&](const Packet& p) { + output_frames1.push_back(p); + return absl::OkStatus(); + }) + .ok()); + EXPECT_TRUE( + graph.StartRun({}, {{"frame", MakePacket("frame_header")}}) + .ok()); + + // Finalize input for the "mask" input stream. + EXPECT_TRUE(graph.CloseInputStream("mask").ok()); + + // Channel 0 is selected just before corresponding packets arrive. + EXPECT_TRUE(graph.AddPacketToInputStream("select", pack(0, 1)).ok()); + EXPECT_TRUE(graph.AddPacketToInputStream("select", pack(0, 10)).ok()); + EXPECT_TRUE(graph.AddPacketToInputStream("frame", pack("p0_t10", 10)).ok()); + EXPECT_TRUE(graph.WaitUntilIdle().ok()); + EXPECT_EQ(output_frames0.size(), 1); + EXPECT_EQ(output_frames1.size(), 0); + EXPECT_EQ(output_frames0[0].Get(), "p0_t10"); + + // Channel 1 is selected just before corresponding packets arrive. 
+ EXPECT_TRUE(graph.AddPacketToInputStream("select", pack(1, 11)).ok()); + EXPECT_TRUE(graph.AddPacketToInputStream("select", pack(1, 20)).ok()); + EXPECT_TRUE(graph.AddPacketToInputStream("frame", pack("p1_t20", 20)).ok()); + EXPECT_TRUE(graph.WaitUntilIdle().ok()); + EXPECT_EQ(output_frames0.size(), 1); + EXPECT_EQ(output_frames1.size(), 1); + EXPECT_EQ(output_frames1[0].Get(), "p1_t20"); + + EXPECT_EQ( + graph.FindOutputStreamManager("frame_0")->Header().Get(), + "frame_header"); + EXPECT_EQ( + graph.FindOutputStreamManager("frame_1")->Header().Get(), + "frame_header"); + + EXPECT_TRUE(graph.CloseAllPacketSources().ok()); + EXPECT_TRUE(graph.WaitUntilDone().ok()); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index a6dd98985..9b5de0235 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -271,6 +271,7 @@ cc_library( deps = [ ":gpu_buffer_format", ":gpu_buffer_storage", + "@com_google_absl//absl/strings", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/port:logging", ":gpu_buffer_storage_image_frame", @@ -366,6 +367,23 @@ cc_library( ], ) +cc_library( + name = "gpu_buffer_storage_ahwb", + srcs = ["gpu_buffer_storage_ahwb.cc"], + hdrs = ["gpu_buffer_storage_ahwb.h"], + linkopts = select({ + "//conditions:default": [], + "//mediapipe:android": [ + "-landroid", + ], + }), + deps = [ + ":gpu_buffer_format", + ":gpu_buffer_storage", + "@com_google_absl//absl/strings:str_format", + ], +) + mediapipe_proto_library( name = "gpu_origin_proto", srcs = ["gpu_origin.proto"], @@ -1087,3 +1105,19 @@ ios_unit_test( ], deps = [":gl_ios_test_lib"], ) + +mediapipe_cc_test( + name = "gpu_buffer_storage_ahwb_test", + size = "small", + srcs = ["gpu_buffer_storage_ahwb_test.cc"], + exclude_platforms = [ + "ios", + "wasm", + ], + requires_full_emulation = True, + deps = [ + ":gpu_buffer_format", + ":gpu_buffer_storage_ahwb", + "//mediapipe/framework/port:gtest_main", + ], +) diff --git a/mediapipe/gpu/gl_context.cc b/mediapipe/gpu/gl_context.cc index 72e88468e..7f7ba0e23 100644 --- a/mediapipe/gpu/gl_context.cc +++ b/mediapipe/gpu/gl_context.cc @@ -620,7 +620,9 @@ class GlSyncWrapper { #endif GLenum result = glClientWaitSync(sync_, flags, timeout); if (result == GL_ALREADY_SIGNALED || result == GL_CONDITION_SATISFIED) { - Clear(); + // TODO: we could clear at this point so later calls are faster, + // but we need to do so in a thread-safe way. + // Clear(); } // TODO: do something if the wait fails? } @@ -646,7 +648,9 @@ class GlSyncWrapper { #endif GLenum result = glClientWaitSync(sync_, flags, 0); if (result == GL_ALREADY_SIGNALED || result == GL_CONDITION_SATISFIED) { - Clear(); + // TODO: we could clear at this point so later calls are faster, + // but we need to do so in a thread-safe way. 
+ // Clear(); return true; } return false; @@ -822,10 +826,17 @@ std::shared_ptr GlContext::CreateSyncToken() { return token; } +bool GlContext::IsAnyContextCurrent() { + ContextBinding ctx; + GetCurrentContextBinding(&ctx); + return ctx.context != kPlatformGlContextNone; +} + std::shared_ptr GlContext::CreateSyncTokenForCurrentExternalContext( const std::shared_ptr& delegate_graph_context) { CHECK(delegate_graph_context); + if (!IsAnyContextCurrent()) return nullptr; if (delegate_graph_context->ShouldUseFenceSync()) { return std::shared_ptr( new GlExternalFenceSyncPoint(delegate_graph_context)); diff --git a/mediapipe/gpu/gl_context.h b/mediapipe/gpu/gl_context.h index 81cfc2e8b..957cb510f 100644 --- a/mediapipe/gpu/gl_context.h +++ b/mediapipe/gpu/gl_context.h @@ -303,6 +303,10 @@ class GlContext : public std::enable_shared_from_this { return *static_cast(entry.get()); } + // Returns true if any GL context, including external contexts not managed by + // the GlContext class, is current. + static bool IsAnyContextCurrent(); + // Creates a synchronization token for the current, non-GlContext-owned // context. This can be passed to MediaPipe so it can synchronize with the // commands issued in the external context up to this point. diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index d48b35a05..69d2fab7a 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -145,9 +145,13 @@ bool GlTextureBuffer::CreateInternal(const void* data, int alignment) { CHECK_NE(name_, 0); GLuint name_to_delete = name_; context->RunWithoutWaiting([name_to_delete, sync_token]() { - // TODO: maybe we do not actually have to wait for the - // consumer sync here. Check docs. - sync_token->WaitOnGpu(); + if (sync_token) { + // TODO: maybe we do not actually have to wait for the + // consumer sync here. Check docs. + sync_token->WaitOnGpu(); + } else { + LOG_FIRST_N(WARNING, 5) << "unexpected null sync in deletion_callback"; + } DLOG_IF(ERROR, !glIsTexture(name_to_delete)) << "Deleting invalid texture id: " << name_to_delete; glDeleteTextures(1, &name_to_delete); @@ -179,13 +183,19 @@ void GlTextureBuffer::Reuse() { void GlTextureBuffer::Updated(std::shared_ptr prod_token) { CHECK(!producer_sync_) << "Updated existing texture which had not been marked for reuse!"; + CHECK(prod_token); producer_sync_ = std::move(prod_token); producer_context_ = producer_sync_->GetContext(); } void GlTextureBuffer::DidRead(std::shared_ptr cons_token) const { absl::MutexLock lock(&consumer_sync_mutex_); - consumer_multi_sync_->Add(std::move(cons_token)); + if (cons_token) { + consumer_multi_sync_->Add(std::move(cons_token)); + } else { + // TODO: change to a CHECK. 
+ LOG_FIRST_N(WARNING, 5) << "unexpected null sync in DidRead"; + } } GlTextureBuffer::~GlTextureBuffer() { diff --git a/mediapipe/gpu/gpu_buffer.cc b/mediapipe/gpu/gpu_buffer.cc index e899fc85d..e570ce8ba 100644 --- a/mediapipe/gpu/gpu_buffer.cc +++ b/mediapipe/gpu/gpu_buffer.cc @@ -2,6 +2,8 @@ #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "mediapipe/framework/port/logging.h" #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER @@ -10,6 +12,23 @@ namespace mediapipe { +namespace { + +struct StorageTypeFormatter { + void operator()(std::string* out, + const std::shared_ptr& s) const { + absl::StrAppend(out, s->storage_type().name()); + } +}; + +} // namespace + +std::string GpuBuffer::DebugString() const { + return absl::StrCat("GpuBuffer[", + absl::StrJoin(storages_, ", ", StorageTypeFormatter()), + "]"); +} + internal::GpuBufferStorage& GpuBuffer::GetStorageForView( TypeId view_provider_type, bool for_writing) const { const std::shared_ptr* chosen_storage = nullptr; @@ -52,7 +71,10 @@ internal::GpuBufferStorage& GpuBuffer::GetStorageForView( } } - CHECK(chosen_storage) << "no view provider found"; + CHECK(chosen_storage) << "no view provider found for requested view " + << view_provider_type.name() << "; storages available: " + << absl::StrJoin(storages_, ", ", + StorageTypeFormatter()); DCHECK((*chosen_storage)->can_down_cast_to(view_provider_type)); return **chosen_storage; } diff --git a/mediapipe/gpu/gpu_buffer.h b/mediapipe/gpu/gpu_buffer.h index 88bff7e1f..57e077151 100644 --- a/mediapipe/gpu/gpu_buffer.h +++ b/mediapipe/gpu/gpu_buffer.h @@ -129,6 +129,8 @@ class GpuBuffer { return nullptr; } + std::string DebugString() const; + private: class PlaceholderGpuBufferStorage : public internal::GpuBufferStorageImpl { diff --git a/mediapipe/graphs/pose_tracking/subgraphs/BUILD b/mediapipe/graphs/pose_tracking/subgraphs/BUILD index 8831692cb..5f06736eb 100644 --- a/mediapipe/graphs/pose_tracking/subgraphs/BUILD +++ b/mediapipe/graphs/pose_tracking/subgraphs/BUILD @@ -21,18 +21,29 @@ licenses(["notice"]) package(default_visibility = ["//visibility:public"]) +mediapipe_simple_subgraph( + name = "pose_landmarks_to_render_data", + graph = "pose_landmarks_to_render_data.pbtxt", + register_as = "PoseLandmarksToRenderData", + deps = [ + "//mediapipe/calculators/core:concatenate_vector_calculator", + "//mediapipe/calculators/core:split_proto_list_calculator", + "//mediapipe/calculators/util:landmarks_to_render_data_calculator", + "//mediapipe/calculators/util:rect_to_render_scale_calculator", + ], +) + mediapipe_simple_subgraph( name = "pose_renderer_gpu", graph = "pose_renderer_gpu.pbtxt", register_as = "PoseRendererGpu", deps = [ - "//mediapipe/calculators/core:split_proto_list_calculator", + ":pose_landmarks_to_render_data", + "//mediapipe/calculators/image:image_properties_calculator", "//mediapipe/calculators/image:recolor_calculator", "//mediapipe/calculators/util:annotation_overlay_calculator", "//mediapipe/calculators/util:detections_to_render_data_calculator", - "//mediapipe/calculators/util:landmarks_to_render_data_calculator", "//mediapipe/calculators/util:rect_to_render_data_calculator", - "//mediapipe/calculators/util:rect_to_render_scale_calculator", ], ) @@ -41,12 +52,11 @@ mediapipe_simple_subgraph( graph = "pose_renderer_cpu.pbtxt", register_as = "PoseRendererCpu", deps = [ - "//mediapipe/calculators/core:split_proto_list_calculator", + ":pose_landmarks_to_render_data", + "//mediapipe/calculators/image:image_properties_calculator", 
"//mediapipe/calculators/image:recolor_calculator", "//mediapipe/calculators/util:annotation_overlay_calculator", "//mediapipe/calculators/util:detections_to_render_data_calculator", - "//mediapipe/calculators/util:landmarks_to_render_data_calculator", "//mediapipe/calculators/util:rect_to_render_data_calculator", - "//mediapipe/calculators/util:rect_to_render_scale_calculator", ], ) diff --git a/mediapipe/graphs/pose_tracking/subgraphs/pose_landmarks_to_render_data.pbtxt b/mediapipe/graphs/pose_tracking/subgraphs/pose_landmarks_to_render_data.pbtxt new file mode 100644 index 000000000..6c4f70153 --- /dev/null +++ b/mediapipe/graphs/pose_tracking/subgraphs/pose_landmarks_to_render_data.pbtxt @@ -0,0 +1,236 @@ +# MediaPipe pose landmarks to render data subgraph. + +type: "PoseLandmarksToRenderData" + +# Pose landmarks. (NormalizedLandmarkList) +input_stream: "LANDMARKS:pose_landmarks" +# Region of interest calculated based on landmarks. (NormalizedRect) +input_stream: "ROI:roi" +# Image size. (pair) +input_stream: "IMAGE_SIZE:image_size" + +# The resulting render data. (vector) +output_stream: "RENDER_DATA:merged_render_data" + +# Calculates rendering scale based on the pose roi. +node { + calculator: "RectToRenderScaleCalculator" + input_stream: "NORM_RECT:roi" + input_stream: "IMAGE_SIZE:image_size" + output_stream: "RENDER_SCALE:render_scale" + node_options: { + [type.googleapis.com/mediapipe.RectToRenderScaleCalculatorOptions] { + multiplier: 0.0012 + } + } +} + +node { + calculator: "SplitNormalizedLandmarkListCalculator" + input_stream: "pose_landmarks" + output_stream: "visible_pose_landmarks" + node_options: { + [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { + ranges: { begin: 0 end: 25 } + } + } +} + +# Converts landmarks to drawing primitives for annotation overlay. 
+node { + calculator: "LandmarksToRenderDataCalculator" + input_stream: "NORM_LANDMARKS:pose_landmarks" + input_stream: "RENDER_SCALE:render_scale" + output_stream: "RENDER_DATA:landmarks_render_data" + node_options: { + [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { + landmark_connections: 0 + landmark_connections: 1 + landmark_connections: 1 + landmark_connections: 2 + landmark_connections: 2 + landmark_connections: 3 + landmark_connections: 3 + landmark_connections: 7 + landmark_connections: 0 + landmark_connections: 4 + landmark_connections: 4 + landmark_connections: 5 + landmark_connections: 5 + landmark_connections: 6 + landmark_connections: 6 + landmark_connections: 8 + landmark_connections: 9 + landmark_connections: 10 + landmark_connections: 11 + landmark_connections: 12 + landmark_connections: 11 + landmark_connections: 13 + landmark_connections: 13 + landmark_connections: 15 + landmark_connections: 15 + landmark_connections: 17 + landmark_connections: 15 + landmark_connections: 19 + landmark_connections: 15 + landmark_connections: 21 + landmark_connections: 17 + landmark_connections: 19 + landmark_connections: 12 + landmark_connections: 14 + landmark_connections: 14 + landmark_connections: 16 + landmark_connections: 16 + landmark_connections: 18 + landmark_connections: 16 + landmark_connections: 20 + landmark_connections: 16 + landmark_connections: 22 + landmark_connections: 18 + landmark_connections: 20 + landmark_connections: 11 + landmark_connections: 23 + landmark_connections: 12 + landmark_connections: 24 + landmark_connections: 23 + landmark_connections: 24 + landmark_connections: 23 + landmark_connections: 25 + landmark_connections: 24 + landmark_connections: 26 + landmark_connections: 25 + landmark_connections: 27 + landmark_connections: 26 + landmark_connections: 28 + landmark_connections: 27 + landmark_connections: 29 + landmark_connections: 28 + landmark_connections: 30 + landmark_connections: 29 + landmark_connections: 31 + landmark_connections: 30 + landmark_connections: 32 + landmark_connections: 27 + landmark_connections: 31 + landmark_connections: 28 + landmark_connections: 32 + + landmark_color { r: 255 g: 255 b: 255 } + connection_color { r: 255 g: 255 b: 255 } + thickness: 3.0 + visualize_landmark_depth: false + utilize_visibility: true + visibility_threshold: 0.5 + } + } +} + +# Take left pose landmarks. +node { + calculator: "SplitNormalizedLandmarkListCalculator" + input_stream: "pose_landmarks" + output_stream: "landmarks_left_side" + node_options: { + [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { + ranges: { begin: 1 end: 4 } + ranges: { begin: 7 end: 8 } + ranges: { begin: 9 end: 10 } + ranges: { begin: 11 end: 12 } + ranges: { begin: 13 end: 14 } + ranges: { begin: 15 end: 16 } + ranges: { begin: 17 end: 18 } + ranges: { begin: 19 end: 20 } + ranges: { begin: 21 end: 22 } + ranges: { begin: 23 end: 24 } + + combine_outputs: true + } + } +} + +# Take right pose landmarks. 
+node { + calculator: "SplitNormalizedLandmarkListCalculator" + input_stream: "pose_landmarks" + output_stream: "landmarks_right_side" + node_options: { + [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { + ranges: { begin: 4 end: 7 } + ranges: { begin: 8 end: 9 } + ranges: { begin: 10 end: 11 } + ranges: { begin: 12 end: 13 } + ranges: { begin: 14 end: 15 } + ranges: { begin: 16 end: 17 } + ranges: { begin: 18 end: 19 } + ranges: { begin: 20 end: 21 } + ranges: { begin: 22 end: 23 } + ranges: { begin: 24 end: 25 } + + combine_outputs: true + } + } +} + +# Render pose joints as big white circles. +node { + calculator: "LandmarksToRenderDataCalculator" + input_stream: "NORM_LANDMARKS:visible_pose_landmarks" + input_stream: "RENDER_SCALE:render_scale" + output_stream: "RENDER_DATA:landmarks_background_joints_render_data" + node_options: { + [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { + landmark_color { r: 255 g: 255 b: 255 } + connection_color { r: 255 g: 255 b: 255 } + thickness: 5.0 + visualize_landmark_depth: false + utilize_visibility: true + visibility_threshold: 0.5 + } + } +} + +# Render pose left side joints as orange circles (inside white ones). +node { + calculator: "LandmarksToRenderDataCalculator" + input_stream: "NORM_LANDMARKS:landmarks_left_side" + input_stream: "RENDER_SCALE:render_scale" + output_stream: "RENDER_DATA:landmarks_left_joints_render_data" + node_options: { + [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { + landmark_color { r: 255 g: 138 b: 0 } + connection_color { r: 255 g: 138 b: 0 } + thickness: 3.0 + visualize_landmark_depth: false + utilize_visibility: true + visibility_threshold: 0.5 + } + } +} + +# Render pose right side joints as cyan circles (inside white ones). +node { + calculator: "LandmarksToRenderDataCalculator" + input_stream: "NORM_LANDMARKS:landmarks_right_side" + input_stream: "RENDER_SCALE:render_scale" + output_stream: "RENDER_DATA:landmarks_right_joints_render_data" + node_options: { + [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { + landmark_color { r: 0 g: 217 b: 231 } + connection_color { r: 0 g: 217 b: 231 } + thickness: 3.0 + visualize_landmark_depth: false + utilize_visibility: true + visibility_threshold: 0.5 + } + } +} + +# Merges annotations into one result. +node { + calculator: "ConcatenateRenderDataVectorCalculator" + input_stream: "landmarks_render_data" + input_stream: "landmarks_background_joints_render_data" + input_stream: "landmarks_left_joints_render_data" + input_stream: "landmarks_right_joints_render_data" + output_stream: "merged_render_data" +} diff --git a/mediapipe/graphs/pose_tracking/subgraphs/pose_renderer_cpu.pbtxt b/mediapipe/graphs/pose_tracking/subgraphs/pose_renderer_cpu.pbtxt index e176765dd..998c77232 100644 --- a/mediapipe/graphs/pose_tracking/subgraphs/pose_renderer_cpu.pbtxt +++ b/mediapipe/graphs/pose_tracking/subgraphs/pose_renderer_cpu.pbtxt @@ -22,19 +22,6 @@ node { output_stream: "SIZE:image_size" } -# Calculates rendering scale based on the pose roi. -node { - calculator: "RectToRenderScaleCalculator" - input_stream: "NORM_RECT:roi" - input_stream: "IMAGE_SIZE:image_size" - output_stream: "RENDER_SCALE:render_scale" - node_options: { - [type.googleapis.com/mediapipe.RectToRenderScaleCalculatorOptions] { - multiplier: 0.0012 - } - } -} - # Converts detections to drawing primitives for annotation overlay. 
node { calculator: "DetectionsToRenderDataCalculator" @@ -48,204 +35,13 @@ node { } } +# Computes render data for landmarks. node { - calculator: "SplitNormalizedLandmarkListCalculator" - input_stream: "pose_landmarks" - output_stream: "visible_pose_landmarks" - node_options: { - [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { - ranges: { begin: 0 end: 25 } - } - } -} - -# Converts landmarks to drawing primitives for annotation overlay. -node { - calculator: "LandmarksToRenderDataCalculator" - input_stream: "NORM_LANDMARKS:pose_landmarks" - input_stream: "RENDER_SCALE:render_scale" + calculator: "PoseLandmarksToRenderData" + input_stream: "LANDMARKS:pose_landmarks" + input_stream: "ROI:roi" + input_stream: "IMAGE_SIZE:image_size" output_stream: "RENDER_DATA:landmarks_render_data" - node_options: { - [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { - landmark_connections: 0 - landmark_connections: 1 - landmark_connections: 1 - landmark_connections: 2 - landmark_connections: 2 - landmark_connections: 3 - landmark_connections: 3 - landmark_connections: 7 - landmark_connections: 0 - landmark_connections: 4 - landmark_connections: 4 - landmark_connections: 5 - landmark_connections: 5 - landmark_connections: 6 - landmark_connections: 6 - landmark_connections: 8 - landmark_connections: 9 - landmark_connections: 10 - landmark_connections: 11 - landmark_connections: 12 - landmark_connections: 11 - landmark_connections: 13 - landmark_connections: 13 - landmark_connections: 15 - landmark_connections: 15 - landmark_connections: 17 - landmark_connections: 15 - landmark_connections: 19 - landmark_connections: 15 - landmark_connections: 21 - landmark_connections: 17 - landmark_connections: 19 - landmark_connections: 12 - landmark_connections: 14 - landmark_connections: 14 - landmark_connections: 16 - landmark_connections: 16 - landmark_connections: 18 - landmark_connections: 16 - landmark_connections: 20 - landmark_connections: 16 - landmark_connections: 22 - landmark_connections: 18 - landmark_connections: 20 - landmark_connections: 11 - landmark_connections: 23 - landmark_connections: 12 - landmark_connections: 24 - landmark_connections: 23 - landmark_connections: 24 - landmark_connections: 23 - landmark_connections: 25 - landmark_connections: 24 - landmark_connections: 26 - landmark_connections: 25 - landmark_connections: 27 - landmark_connections: 26 - landmark_connections: 28 - landmark_connections: 27 - landmark_connections: 29 - landmark_connections: 28 - landmark_connections: 30 - landmark_connections: 29 - landmark_connections: 31 - landmark_connections: 30 - landmark_connections: 32 - landmark_connections: 27 - landmark_connections: 31 - landmark_connections: 28 - landmark_connections: 32 - - landmark_color { r: 255 g: 255 b: 255 } - connection_color { r: 255 g: 255 b: 255 } - thickness: 3.0 - visualize_landmark_depth: false - utilize_visibility: true - visibility_threshold: 0.5 - } - } -} - -# Take left pose landmarks. 
-node { - calculator: "SplitNormalizedLandmarkListCalculator" - input_stream: "pose_landmarks" - output_stream: "landmarks_left_side" - node_options: { - [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { - ranges: { begin: 1 end: 4 } - ranges: { begin: 7 end: 8 } - ranges: { begin: 9 end: 10 } - ranges: { begin: 11 end: 12 } - ranges: { begin: 13 end: 14 } - ranges: { begin: 15 end: 16 } - ranges: { begin: 17 end: 18 } - ranges: { begin: 19 end: 20 } - ranges: { begin: 21 end: 22 } - ranges: { begin: 23 end: 24 } - - combine_outputs: true - } - } -} - -# Take right pose landmarks. -node { - calculator: "SplitNormalizedLandmarkListCalculator" - input_stream: "pose_landmarks" - output_stream: "landmarks_right_side" - node_options: { - [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { - ranges: { begin: 4 end: 7 } - ranges: { begin: 8 end: 9 } - ranges: { begin: 10 end: 11 } - ranges: { begin: 12 end: 13 } - ranges: { begin: 14 end: 15 } - ranges: { begin: 16 end: 17 } - ranges: { begin: 18 end: 19 } - ranges: { begin: 20 end: 21 } - ranges: { begin: 22 end: 23 } - ranges: { begin: 24 end: 25 } - - combine_outputs: true - } - } -} - -# Render pose joints as big white circles. -node { - calculator: "LandmarksToRenderDataCalculator" - input_stream: "NORM_LANDMARKS:visible_pose_landmarks" - input_stream: "RENDER_SCALE:render_scale" - output_stream: "RENDER_DATA:landmarks_background_joints_render_data" - node_options: { - [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { - landmark_color { r: 255 g: 255 b: 255 } - connection_color { r: 255 g: 255 b: 255 } - thickness: 5.0 - visualize_landmark_depth: false - utilize_visibility: true - visibility_threshold: 0.5 - } - } -} - -# Render pose left side joints as orange circles (inside white ones). -node { - calculator: "LandmarksToRenderDataCalculator" - input_stream: "NORM_LANDMARKS:landmarks_left_side" - input_stream: "RENDER_SCALE:render_scale" - output_stream: "RENDER_DATA:landmarks_left_joints_render_data" - node_options: { - [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { - landmark_color { r: 255 g: 138 b: 0 } - connection_color { r: 255 g: 138 b: 0 } - thickness: 3.0 - visualize_landmark_depth: false - utilize_visibility: true - visibility_threshold: 0.5 - } - } -} - -# Render pose right side joints as cyan circles (inside white ones). -node { - calculator: "LandmarksToRenderDataCalculator" - input_stream: "NORM_LANDMARKS:landmarks_right_side" - input_stream: "RENDER_SCALE:render_scale" - output_stream: "RENDER_DATA:landmarks_right_joints_render_data" - node_options: { - [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { - landmark_color { r: 0 g: 217 b: 231 } - connection_color { r: 0 g: 217 b: 231 } - thickness: 3.0 - visualize_landmark_depth: false - utilize_visibility: true - visibility_threshold: 0.5 - } - } } # Converts normalized rects to drawing primitives for annotation overlay. 
@@ -283,10 +79,7 @@ node { calculator: "AnnotationOverlayCalculator" input_stream: "IMAGE:segmented_image" input_stream: "detection_render_data" - input_stream: "landmarks_render_data" - input_stream: "landmarks_background_joints_render_data" - input_stream: "landmarks_left_joints_render_data" - input_stream: "landmarks_right_joints_render_data" + input_stream: "VECTOR:landmarks_render_data" input_stream: "roi_render_data" output_stream: "IMAGE:output_image" } diff --git a/mediapipe/graphs/pose_tracking/subgraphs/pose_renderer_gpu.pbtxt b/mediapipe/graphs/pose_tracking/subgraphs/pose_renderer_gpu.pbtxt index 4d680c6ca..6285c8afb 100644 --- a/mediapipe/graphs/pose_tracking/subgraphs/pose_renderer_gpu.pbtxt +++ b/mediapipe/graphs/pose_tracking/subgraphs/pose_renderer_gpu.pbtxt @@ -22,19 +22,6 @@ node { output_stream: "SIZE:image_size" } -# Calculates rendering scale based on the pose roi. -node { - calculator: "RectToRenderScaleCalculator" - input_stream: "NORM_RECT:roi" - input_stream: "IMAGE_SIZE:image_size" - output_stream: "RENDER_SCALE:render_scale" - node_options: { - [type.googleapis.com/mediapipe.RectToRenderScaleCalculatorOptions] { - multiplier: 0.0012 - } - } -} - # Converts detections to drawing primitives for annotation overlay. node { calculator: "DetectionsToRenderDataCalculator" @@ -48,204 +35,13 @@ node { } } +# Computes render data for landmarks. node { - calculator: "SplitNormalizedLandmarkListCalculator" - input_stream: "pose_landmarks" - output_stream: "visible_pose_landmarks" - node_options: { - [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { - ranges: { begin: 0 end: 25 } - } - } -} - -# Converts landmarks to drawing primitives for annotation overlay. -node { - calculator: "LandmarksToRenderDataCalculator" - input_stream: "NORM_LANDMARKS:pose_landmarks" - input_stream: "RENDER_SCALE:render_scale" + calculator: "PoseLandmarksToRenderData" + input_stream: "LANDMARKS:pose_landmarks" + input_stream: "ROI:roi" + input_stream: "IMAGE_SIZE:image_size" output_stream: "RENDER_DATA:landmarks_render_data" - node_options: { - [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { - landmark_connections: 0 - landmark_connections: 1 - landmark_connections: 1 - landmark_connections: 2 - landmark_connections: 2 - landmark_connections: 3 - landmark_connections: 3 - landmark_connections: 7 - landmark_connections: 0 - landmark_connections: 4 - landmark_connections: 4 - landmark_connections: 5 - landmark_connections: 5 - landmark_connections: 6 - landmark_connections: 6 - landmark_connections: 8 - landmark_connections: 9 - landmark_connections: 10 - landmark_connections: 11 - landmark_connections: 12 - landmark_connections: 11 - landmark_connections: 13 - landmark_connections: 13 - landmark_connections: 15 - landmark_connections: 15 - landmark_connections: 17 - landmark_connections: 15 - landmark_connections: 19 - landmark_connections: 15 - landmark_connections: 21 - landmark_connections: 17 - landmark_connections: 19 - landmark_connections: 12 - landmark_connections: 14 - landmark_connections: 14 - landmark_connections: 16 - landmark_connections: 16 - landmark_connections: 18 - landmark_connections: 16 - landmark_connections: 20 - landmark_connections: 16 - landmark_connections: 22 - landmark_connections: 18 - landmark_connections: 20 - landmark_connections: 11 - landmark_connections: 23 - landmark_connections: 12 - landmark_connections: 24 - landmark_connections: 23 - landmark_connections: 24 - landmark_connections: 23 - landmark_connections: 25 - 
landmark_connections: 24 - landmark_connections: 26 - landmark_connections: 25 - landmark_connections: 27 - landmark_connections: 26 - landmark_connections: 28 - landmark_connections: 27 - landmark_connections: 29 - landmark_connections: 28 - landmark_connections: 30 - landmark_connections: 29 - landmark_connections: 31 - landmark_connections: 30 - landmark_connections: 32 - landmark_connections: 27 - landmark_connections: 31 - landmark_connections: 28 - landmark_connections: 32 - - landmark_color { r: 255 g: 255 b: 255 } - connection_color { r: 255 g: 255 b: 255 } - thickness: 3.0 - visualize_landmark_depth: false - utilize_visibility: true - visibility_threshold: 0.5 - } - } -} - -# Take left pose landmarks. -node { - calculator: "SplitNormalizedLandmarkListCalculator" - input_stream: "pose_landmarks" - output_stream: "landmarks_left_side" - node_options: { - [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { - ranges: { begin: 1 end: 4 } - ranges: { begin: 7 end: 8 } - ranges: { begin: 9 end: 10 } - ranges: { begin: 11 end: 12 } - ranges: { begin: 13 end: 14 } - ranges: { begin: 15 end: 16 } - ranges: { begin: 17 end: 18 } - ranges: { begin: 19 end: 20 } - ranges: { begin: 21 end: 22 } - ranges: { begin: 23 end: 24 } - - combine_outputs: true - } - } -} - -# Take right pose landmarks. -node { - calculator: "SplitNormalizedLandmarkListCalculator" - input_stream: "pose_landmarks" - output_stream: "landmarks_right_side" - node_options: { - [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { - ranges: { begin: 4 end: 7 } - ranges: { begin: 8 end: 9 } - ranges: { begin: 10 end: 11 } - ranges: { begin: 12 end: 13 } - ranges: { begin: 14 end: 15 } - ranges: { begin: 16 end: 17 } - ranges: { begin: 18 end: 19 } - ranges: { begin: 20 end: 21 } - ranges: { begin: 22 end: 23 } - ranges: { begin: 24 end: 25 } - - combine_outputs: true - } - } -} - -# Render pose joints as big white circles. -node { - calculator: "LandmarksToRenderDataCalculator" - input_stream: "NORM_LANDMARKS:visible_pose_landmarks" - input_stream: "RENDER_SCALE:render_scale" - output_stream: "RENDER_DATA:landmarks_background_joints_render_data" - node_options: { - [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { - landmark_color { r: 255 g: 255 b: 255 } - connection_color { r: 255 g: 255 b: 255 } - thickness: 5.0 - visualize_landmark_depth: false - utilize_visibility: true - visibility_threshold: 0.5 - } - } -} - -# Render pose left side joints as orange circles (inside white ones). -node { - calculator: "LandmarksToRenderDataCalculator" - input_stream: "NORM_LANDMARKS:landmarks_left_side" - input_stream: "RENDER_SCALE:render_scale" - output_stream: "RENDER_DATA:landmarks_left_joints_render_data" - node_options: { - [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { - landmark_color { r: 255 g: 138 b: 0 } - connection_color { r: 255 g: 138 b: 0 } - thickness: 3.0 - visualize_landmark_depth: false - utilize_visibility: true - visibility_threshold: 0.5 - } - } -} - -# Render pose right side joints as cyan circles (inside white ones). 
-node { - calculator: "LandmarksToRenderDataCalculator" - input_stream: "NORM_LANDMARKS:landmarks_right_side" - input_stream: "RENDER_SCALE:render_scale" - output_stream: "RENDER_DATA:landmarks_right_joints_render_data" - node_options: { - [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { - landmark_color { r: 0 g: 217 b: 231 } - connection_color { r: 0 g: 217 b: 231 } - thickness: 3.0 - visualize_landmark_depth: false - utilize_visibility: true - visibility_threshold: 0.5 - } - } } # Converts normalized rects to drawing primitives for annotation overlay. @@ -283,10 +79,7 @@ node { calculator: "AnnotationOverlayCalculator" input_stream: "IMAGE_GPU:segmented_image" input_stream: "detection_render_data" - input_stream: "landmarks_render_data" - input_stream: "landmarks_background_joints_render_data" - input_stream: "landmarks_left_joints_render_data" - input_stream: "landmarks_right_joints_render_data" + input_stream: "VECTOR:landmarks_render_data" input_stream: "roi_render_data" output_stream: "IMAGE_GPU:output_image" } diff --git a/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java b/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java index c375aa61f..88e191d26 100644 --- a/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java +++ b/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java @@ -174,6 +174,14 @@ public class ExternalTextureConverter implements TextureFrameProducer { thread.setRotation(rotation); } + /** + * Sets whether the timestamps of each frame should be adjusted to be always monotonically + * increasing. The default behavior is that this is {@code true}. + */ + public void setShouldAdjustTimestamps(boolean shouldAdjustTimestamps) { + thread.setShouldAdjustTimestamps(shouldAdjustTimestamps); + } + /** * Sets an offset that can be used to adjust the timestamps on the camera frames, for example to * conform to a preferred time-base or to account for a known device latency. The offset is added @@ -298,6 +306,7 @@ public class ExternalTextureConverter implements TextureFrameProducer { private int bufferPoolMaxSize; private ExternalTextureRenderer renderer = null; + private boolean shouldAdjustTimestamps = true; private long nextFrameTimestampOffset = 0; private long timestampOffsetNanos = 0; private long previousTimestamp = 0; @@ -433,6 +442,10 @@ public class ExternalTextureConverter implements TextureFrameProducer { super.releaseGl(); // This releases the EGL context, so must do it after any GL calls. } + public void setShouldAdjustTimestamps(boolean shouldAdjustTimestamps) { + this.shouldAdjustTimestamps = shouldAdjustTimestamps; + } + public void setTimestampOffsetNanos(long offsetInNanos) { timestampOffsetNanos = offsetInNanos; } @@ -565,7 +578,8 @@ public class ExternalTextureConverter implements TextureFrameProducer { // |nextFrameTimestampOffset| to ensure that timestamps increase monotonically.) 
long textureTimestamp = (surfaceTexture.getTimestamp() + timestampOffsetNanos) / NANOS_PER_MICRO; - if (previousTimestampValid + if (shouldAdjustTimestamps + && previousTimestampValid && textureTimestamp + nextFrameTimestampOffset <= previousTimestamp) { nextFrameTimestampOffset = previousTimestamp + 1 - textureTimestamp; } diff --git a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java index 69c0ebeb6..4af9dae78 100644 --- a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java +++ b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java @@ -15,6 +15,10 @@ package com.google.mediapipe.framework; import android.graphics.Bitmap; +import com.google.mediapipe.framework.image.BitmapExtractor; +import com.google.mediapipe.framework.image.ByteBufferExtractor; +import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.framework.image.ImageProperties; import java.nio.ByteBuffer; // TODO: use Preconditions in this file. @@ -55,6 +59,50 @@ public class AndroidPacketCreator extends PacketCreator { return Packet.create(nativeCreateRgbaImage(mediapipeGraph.getNativeHandle(), bitmap)); } + /** + * Creates an Image packet from an {@link Image}. + * + *

The ImageContainerType must be IMAGE_CONTAINER_BYTEBUFFER or IMAGE_CONTAINER_BITMAP. + */ + public Packet createImage(Image image) { + // TODO: Choose the best storage from multiple containers. + ImageProperties properties = image.getContainedImageProperties().get(0); + if (properties.getStorageType() == Image.STORAGE_TYPE_BYTEBUFFER) { + ByteBuffer buffer = ByteBufferExtractor.extract(image); + int numChannels = 0; + switch (properties.getImageFormat()) { + case Image.IMAGE_FORMAT_RGBA: + numChannels = 4; + break; + case Image.IMAGE_FORMAT_RGB: + numChannels = 3; + break; + case Image.IMAGE_FORMAT_ALPHA: + numChannels = 1; + break; + default: // fall out + } + if (numChannels == 0) { + throw new UnsupportedOperationException( + "Unsupported MediaPipe Image image format: " + properties.getImageFormat()); + } + int width = image.getWidth(); + int height = image.getHeight(); + return createImage(buffer, width, height, numChannels); + } + if (properties.getImageFormat() == Image.STORAGE_TYPE_BITMAP) { + Bitmap bitmap = BitmapExtractor.extract(image); + if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) { + throw new UnsupportedOperationException("bitmap must use ARGB_8888 config."); + } + return Packet.create(nativeCreateRgbaImage(mediapipeGraph.getNativeHandle(), bitmap)); + } + + // Unsupported type. + throw new UnsupportedOperationException( + "Unsupported Image container type: " + properties.getImageFormat()); + } + /** * Returns the native handle of a new internal::PacketWithContext object on success. Returns 0 on * failure. diff --git a/mediapipe/java/com/google/mediapipe/framework/BUILD b/mediapipe/java/com/google/mediapipe/framework/BUILD index 7b1a89166..6b7fb1271 100644 --- a/mediapipe/java/com/google/mediapipe/framework/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/BUILD @@ -57,6 +57,7 @@ android_library( ], deps = [ ":android_core", + "//mediapipe/java/com/google/mediapipe/framework/image", "//third_party:androidx_annotation", "//third_party:androidx_legacy_support_v4", "@maven//:com_google_code_findbugs_jsr305", @@ -75,6 +76,7 @@ android_library( srcs = glob( ["**/*.java"], exclude = [ + "image/**", "Android*", "AssetCache.java", "AssetCacheDbHelper.java", diff --git a/mediapipe/java/com/google/mediapipe/framework/image/AndroidManifest.xml b/mediapipe/java/com/google/mediapipe/framework/image/AndroidManifest.xml new file mode 100644 index 000000000..7d5f48f08 --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/framework/image/AndroidManifest.xml @@ -0,0 +1,6 @@ + + + + + diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BUILD b/mediapipe/java/com/google/mediapipe/framework/image/BUILD new file mode 100644 index 000000000..abf82a892 --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/framework/image/BUILD @@ -0,0 +1,32 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + +licenses(["notice"]) + +android_library( + name = "image", + srcs = glob(["*.java"]), + manifest = "AndroidManifest.xml", + visibility = [ + "//mediapipe:__subpackages__", + ], + deps = [ + "//third_party:androidx_legacy_support_v4", + "//third_party:autovalue", + "@maven//:androidx_annotation_annotation", + "@maven//:com_google_guava_guava", + ], +) diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BitmapExtractor.java b/mediapipe/java/com/google/mediapipe/framework/image/BitmapExtractor.java new file mode 100644 index 000000000..4c6cebd4d --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/framework/image/BitmapExtractor.java @@ -0,0 +1,49 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.mediapipe.framework.image; + +import android.graphics.Bitmap; + +/** + * Utility for extracting {@link android.graphics.Bitmap} from {@link Image}. + * + *

Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BITMAP}, otherwise + * {@link IllegalArgumentException} will be thrown. + */ +public final class BitmapExtractor { + + /** + * Extracts a {@link android.graphics.Bitmap} from an {@link Image}. + * + * @param image the image to extract {@link android.graphics.Bitmap} from. + * @return the {@link android.graphics.Bitmap} stored in {@link Image} + * @throws IllegalArgumentException when the extraction requires unsupported format or data type + * conversions. + */ + public static Bitmap extract(Image image) { + ImageContainer imageContainer = image.getContainer(Image.STORAGE_TYPE_BITMAP); + if (imageContainer != null) { + return ((BitmapImageContainer) imageContainer).getBitmap(); + } else { + // TODO: Support ByteBuffer -> Bitmap conversion. + throw new IllegalArgumentException( + "Extracting Bitmap from an Image created by objects other than Bitmap is not" + + " supported"); + } + } + + private BitmapExtractor() {} +} diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageBuilder.java b/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageBuilder.java new file mode 100644 index 000000000..ea2ca6b1f --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageBuilder.java @@ -0,0 +1,72 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.mediapipe.framework.image; + +import android.content.Context; +import android.graphics.Bitmap; +import android.net.Uri; +import android.provider.MediaStore; +import java.io.IOException; + +/** + * Builds {@link Image} from {@link android.graphics.Bitmap}. + * + *

+ * <p>You can pass in either mutable or immutable {@link android.graphics.Bitmap}. However once
+ * {@link android.graphics.Bitmap} is passed in, to keep data integrity you shouldn't modify content
+ * in it.
+ *

Use {@link BitmapExtractor} to get {@link android.graphics.Bitmap} you passed in. + */ +public class BitmapImageBuilder { + + // Mandatory fields. + private final Bitmap bitmap; + + // Optional fields. + private long timestamp; + + /** + * Creates the builder with a mandatory {@link android.graphics.Bitmap}. + * + * @param bitmap image data object. + */ + public BitmapImageBuilder(Bitmap bitmap) { + this.bitmap = bitmap; + timestamp = 0; + } + + /** + * Creates the builder to build {@link Image} from a file. + * + * @param context the application context. + * @param uri the path to the resource file. + */ + public BitmapImageBuilder(Context context, Uri uri) throws IOException { + this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri)); + } + + /** Sets value for {@link Image#getTimestamp()}. */ + BitmapImageBuilder setTimestamp(long timestamp) { + this.timestamp = timestamp; + return this; + } + + /** Builds an {@link Image} instance. */ + public Image build() { + return new Image( + new BitmapImageContainer(bitmap), timestamp, bitmap.getWidth(), bitmap.getHeight()); + } +} diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageContainer.java b/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageContainer.java new file mode 100644 index 000000000..0457e1e9b --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageContainer.java @@ -0,0 +1,60 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.mediapipe.framework.image; + +import android.graphics.Bitmap; +import com.google.mediapipe.framework.image.Image.ImageFormat; + +class BitmapImageContainer implements ImageContainer { + + private final Bitmap bitmap; + private final ImageProperties properties; + + public BitmapImageContainer(Bitmap bitmap) { + this.bitmap = bitmap; + this.properties = + ImageProperties.builder() + .setImageFormat(convertFormatCode(bitmap.getConfig())) + .setStorageType(Image.STORAGE_TYPE_BITMAP) + .build(); + } + + public Bitmap getBitmap() { + return bitmap; + } + + @Override + public ImageProperties getImageProperties() { + return properties; + } + + @Override + public void close() { + bitmap.recycle(); + } + + @ImageFormat + static int convertFormatCode(Bitmap.Config config) { + switch (config) { + case ALPHA_8: + return Image.IMAGE_FORMAT_ALPHA; + case ARGB_8888: + return Image.IMAGE_FORMAT_RGBA; + default: + return Image.IMAGE_FORMAT_UNKNOWN; + } + } +} diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java new file mode 100644 index 000000000..a0e8c3dff --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java @@ -0,0 +1,254 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.mediapipe.framework.image; + +import android.annotation.SuppressLint; +import android.graphics.Bitmap; +import android.graphics.Bitmap.Config; +import android.os.Build.VERSION; +import android.os.Build.VERSION_CODES; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.framework.image.Image.ImageFormat; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Locale; + +/** + * Utility for extracting {@link ByteBuffer} from {@link Image}. + * + *

+ * <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BYTEBUFFER}, otherwise
+ * {@link IllegalArgumentException} will be thrown.
+ */
+public class ByteBufferExtractor {
+
+  /**
+   * Extracts a {@link ByteBuffer} from an {@link Image}.
+   *

The returned {@link ByteBuffer} is a read-only view, with the first available {@link + * ImageProperties} whose storage type is {@code Image.STORAGE_TYPE_BYTEBUFFER}. + * + * @see Image#getContainedImageProperties() + * @return A read-only {@link ByteBuffer}. + * @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage. + */ + @SuppressLint("SwitchIntDef") + public static ByteBuffer extract(Image image) { + ImageContainer container = image.getContainer(); + switch (container.getImageProperties().getStorageType()) { + case Image.STORAGE_TYPE_BYTEBUFFER: + ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; + return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); + default: + throw new IllegalArgumentException( + "Extract ByteBuffer from an Image created by objects other than Bytebuffer is not" + + " supported"); + } + } + + /** + * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link Image}. + * + *

+   * <p>Format conversion spec:
+   *
+   * <ul>
+   *   <li>When extracting RGB images to RGBA format, A channel will always be set to 255.
+   *   <li>When extracting RGBA images to RGB format, A channel will be dropped.
+   * </ul>
+ * + * @param image the image to extract buffer from. + * @param targetFormat the image format of the result bytebuffer. + * @return the readonly {@link ByteBuffer} stored in {@link Image} + * @throws IllegalArgumentException when the extraction requires unsupported format or data type + * conversions. + */ + static ByteBuffer extract(Image image, @ImageFormat int targetFormat) { + ImageContainer container; + ImageProperties byteBufferProperties = + ImageProperties.builder() + .setStorageType(Image.STORAGE_TYPE_BYTEBUFFER) + .setImageFormat(targetFormat) + .build(); + if ((container = image.getContainer(byteBufferProperties)) != null) { + ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; + return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); + } else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) { + ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; + @ImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat(); + return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat) + .asReadOnlyBuffer(); + } else if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) { + BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container; + ByteBuffer byteBuffer = + extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat) + .asReadOnlyBuffer(); + boolean unused = image.addContainer(new ByteBufferImageContainer(byteBuffer, targetFormat)); + return byteBuffer; + } else { + throw new IllegalArgumentException( + "Extracting ByteBuffer from an Image created by objects other than Bitmap or" + + " Bytebuffer is not supported"); + } + } + + /** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */ + @AutoValue + abstract static class Result { + /** Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(Image)}. */ + public abstract ByteBuffer buffer(); + + /** Gets the {@link ImageFormat} in the result of {@link ByteBufferExtractor#extract(Image)}. */ + @ImageFormat + public abstract int format(); + + static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) { + return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat); + } + } + + /** + * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link Image}. + * + *

It will make the best effort to return an already existed {@link ByteBuffer} to avoid copy. + * + * @return the readonly {@link ByteBuffer} stored in {@link Image} + * @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with + * given {@code imageFormat} + */ + static Result extractInRecommendedFormat(Image image) { + ImageContainer container; + if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) { + Bitmap bitmap = ((BitmapImageContainer) container).getBitmap(); + @ImageFormat int format = adviseImageFormat(bitmap); + Result result = + Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format); + + boolean unused = + image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format())); + return result; + } else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) { + ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; + return Result.create( + byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(), + byteBufferImageContainer.getImageFormat()); + } else { + throw new IllegalArgumentException( + "Extract ByteBuffer from an Image created by objects other than Bitmap or Bytebuffer" + + " is not supported"); + } + } + + @ImageFormat + private static int adviseImageFormat(Bitmap bitmap) { + if (bitmap.getConfig() == Config.ARGB_8888) { + return Image.IMAGE_FORMAT_RGBA; + } else { + throw new IllegalArgumentException( + String.format( + "Extracting ByteBuffer from an Image created by a Bitmap in config %s is not" + + " supported", + bitmap.getConfig())); + } + } + + private static ByteBuffer extractByteBufferFromBitmap( + Bitmap bitmap, @ImageFormat int imageFormat) { + if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) { + throw new IllegalArgumentException( + "Extracting ByteBuffer from an Image created by a premultiplied Bitmap is not" + + " supported"); + } + if (bitmap.getConfig() == Config.ARGB_8888) { + if (imageFormat == Image.IMAGE_FORMAT_RGBA) { + ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount()); + bitmap.copyPixelsToBuffer(buffer); + buffer.rewind(); + return buffer; + } else if (imageFormat == Image.IMAGE_FORMAT_RGB) { + // TODO: Try Use RGBA buffer to create RGB buffer which might be faster. + int w = bitmap.getWidth(); + int h = bitmap.getHeight(); + int[] pixels = new int[w * h]; + bitmap.getPixels(pixels, 0, w, 0, 0, w, h); + ByteBuffer buffer = ByteBuffer.allocateDirect(w * h * 3); + buffer.order(ByteOrder.nativeOrder()); + for (int pixel : pixels) { + // getPixels returns Color in ARGB rather than copyPixelsToBuffer which returns RGBA + buffer.put((byte) ((pixel >> 16) & 0xff)); + buffer.put((byte) ((pixel >> 8) & 0xff)); + buffer.put((byte) (pixel & 0xff)); + } + buffer.rewind(); + return buffer; + } + } + throw new IllegalArgumentException( + String.format( + "Extracting ByteBuffer from an Image created by Bitmap and convert from %s to format" + + " %d is not supported", + bitmap.getConfig(), imageFormat)); + } + + private static ByteBuffer convertByteBuffer( + ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) { + if (sourceFormat == Image.IMAGE_FORMAT_RGB && targetFormat == Image.IMAGE_FORMAT_RGBA) { + ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4); + // Extend the buffer when the target is longer than the source. Use two cursors and sweep the + // array reversely to convert in-place. 
+ byte[] array = new byte[target.capacity()]; + source.get(array, 0, source.capacity()); + source.rewind(); + int rgbCursor = source.capacity(); + int rgbaCursor = target.capacity(); + while (rgbCursor != rgbaCursor) { + array[--rgbaCursor] = (byte) 0xff; // A + array[--rgbaCursor] = array[--rgbCursor]; // B + array[--rgbaCursor] = array[--rgbCursor]; // G + array[--rgbaCursor] = array[--rgbCursor]; // R + } + target.put(array, 0, target.capacity()); + target.rewind(); + return target; + } else if (sourceFormat == Image.IMAGE_FORMAT_RGBA && targetFormat == Image.IMAGE_FORMAT_RGB) { + ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3); + // Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the + // array to convert in-place. + byte[] array = new byte[source.capacity()]; + source.get(array, 0, source.capacity()); + source.rewind(); + int rgbaCursor = 0; + int rgbCursor = 0; + while (rgbaCursor < array.length) { + array[rgbCursor++] = array[rgbaCursor++]; // R + array[rgbCursor++] = array[rgbaCursor++]; // G + array[rgbCursor++] = array[rgbaCursor++]; // B + rgbaCursor++; + } + target.put(array, 0, target.capacity()); + target.rewind(); + return target; + } else { + throw new IllegalArgumentException( + String.format( + Locale.ENGLISH, + "Convert bytebuffer image format from %d to %d is not supported", + sourceFormat, + targetFormat)); + } + } + + // ByteBuffer is not able to be instantiated. + private ByteBufferExtractor() {} +} diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageBuilder.java b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageBuilder.java new file mode 100644 index 000000000..07871da38 --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageBuilder.java @@ -0,0 +1,71 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.mediapipe.framework.image; + +import com.google.mediapipe.framework.image.Image.ImageFormat; +import java.nio.ByteBuffer; + +/** + * Builds a {@link Image} from a {@link ByteBuffer}. + * + *

+ * <p>You can pass in either mutable or immutable {@link ByteBuffer}. However once {@link
+ * ByteBuffer} is passed in, to keep data integrity you shouldn't modify content in it.
+ *

Use {@link ByteBufferExtractor} to get {@link ByteBuffer} you passed in. + */ +public class ByteBufferImageBuilder { + + // Mandatory fields. + private final ByteBuffer buffer; + private final int width; + private final int height; + @ImageFormat private final int imageFormat; + + // Optional fields. + private long timestamp; + + /** + * Creates the builder with mandatory {@link ByteBuffer} and the represented image. + * + *

We will validate the size of the {@code byteBuffer} with given {@code width}, {@code height} + * and {@code imageFormat}. + * + * @param byteBuffer image data object. + * @param width the width of the represented image. + * @param height the height of the represented image. + * @param imageFormat how the data encode the image. + */ + public ByteBufferImageBuilder( + ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) { + this.buffer = byteBuffer; + this.width = width; + this.height = height; + this.imageFormat = imageFormat; + // TODO: Validate bytebuffer size with width, height and image format + this.timestamp = 0; + } + + /** Sets value for {@link Image#getTimestamp()}. */ + ByteBufferImageBuilder setTimestamp(long timestamp) { + this.timestamp = timestamp; + return this; + } + + /** Builds an {@link Image} instance. */ + public Image build() { + return new Image(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height); + } +} diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageContainer.java b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageContainer.java new file mode 100644 index 000000000..1c24c1dfd --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageContainer.java @@ -0,0 +1,58 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.mediapipe.framework.image; + +import com.google.mediapipe.framework.image.Image.ImageFormat; +import java.nio.ByteBuffer; + +class ByteBufferImageContainer implements ImageContainer { + + private final ByteBuffer buffer; + private final ImageProperties properties; + + public ByteBufferImageContainer( + ByteBuffer buffer, + @ImageFormat int imageFormat) { + this.buffer = buffer; + this.properties = + ImageProperties.builder() + .setStorageType(Image.STORAGE_TYPE_BYTEBUFFER) + .setImageFormat(imageFormat) + .build(); + } + + public ByteBuffer getByteBuffer() { + return buffer; + } + + @Override + public ImageProperties getImageProperties() { + return properties; + } + + /** + * Returns the image format. + */ + @ImageFormat + public int getImageFormat() { + return properties.getImageFormat(); + } + + @Override + public void close() { + // No op for ByteBuffer. + } +} diff --git a/mediapipe/java/com/google/mediapipe/framework/image/Image.java b/mediapipe/java/com/google/mediapipe/framework/image/Image.java new file mode 100644 index 000000000..49e63bcc0 --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/framework/image/Image.java @@ -0,0 +1,241 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.mediapipe.framework.image; + +import androidx.annotation.IntDef; +import androidx.annotation.Nullable; +import java.io.Closeable; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +/** + * The wrapper class for image objects. + * + *

{@link Image} is designed to be an immutable image container, which can be shared + * across platforms. + * + *

To construct an {@link Image}, use the provided builders, as sketched below: + * + *

+ * <ul>
+ *   <li>{@link ByteBufferImageBuilder}
+ *   <li>{@link BitmapImageBuilder}
+ *   <li>{@link MediaImageBuilder}
+ * </ul>
+ *
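As a rough usage sketch of the builders listed above (the wrapper class and method names are illustrative only, and BitmapImageBuilder is assumed to take the Bitmap directly in its constructor, by analogy with MediaImageBuilder later in this patch):

    package com.google.mediapipe.framework.image;

    import android.graphics.Bitmap;
    import java.nio.ByteBuffer;

    final class ImageBuildingSketch {
      // Wraps raw RGBA pixels; the buffer is expected to hold width * height * 4 bytes
      // and must not be modified after it is handed to the builder.
      static Image fromRgbaBuffer(ByteBuffer rgbaBuffer, int width, int height) {
        return new ByteBufferImageBuilder(rgbaBuffer, width, height, Image.IMAGE_FORMAT_RGBA)
            .build();
      }

      // Wraps an already-decoded Bitmap without copying.
      static Image fromBitmap(Bitmap bitmap) {
        return new BitmapImageBuilder(bitmap).build();
      }
    }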

{@link Image} uses reference counting to maintain its internal storage. When it is created, the + * reference count is 1. Developers can call {@link #close()} to reduce the reference count and release + * the internal storage earlier; otherwise, Java garbage collection will release the storage eventually. + * + *
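A minimal sketch of the creator-side contract described above, assuming the image is used synchronously and is no longer needed afterwards:

    package com.google.mediapipe.framework.image;

    import android.graphics.Bitmap;

    final class ImageLifetimeSketch {
      static void runOnce(Bitmap bitmap) {
        Image image = new BitmapImageBuilder(bitmap).build();  // reference count starts at 1
        try {
          // ... use the image synchronously ...
        } finally {
          image.close();  // count drops to 0; storage is released without waiting for GC
        }
      }
    }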

To extract the concrete image, first check the {@link StorageType} and then use the provided + * extractors, as sketched below: + * + *

+ * <ul>
+ *   <li>{@link ByteBufferExtractor}
+ *   <li>{@link BitmapExtractor}
+ *   <li>{@link MediaImageExtractor}
+ * </ul>
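A rough sketch of that check-then-extract flow (a ByteBuffer-backed image is assumed here, and ByteBufferExtractor is assumed to expose a static extract(Image) method analogous to MediaImageExtractor.extract shown later in this patch):

    static java.nio.ByteBuffer readPixels(Image image) {
      ImageProperties properties = image.getContainedImageProperties().get(0);
      if (properties.getStorageType() == Image.STORAGE_TYPE_BYTEBUFFER) {
        return ByteBufferExtractor.extract(image);
      }
      throw new IllegalArgumentException(
          "Unsupported storage type: " + properties.getStorageType());
    }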
+ */ +public class Image implements Closeable { + + /** Specifies the image format of an image. */ + @IntDef({ + IMAGE_FORMAT_UNKNOWN, + IMAGE_FORMAT_RGBA, + IMAGE_FORMAT_RGB, + IMAGE_FORMAT_NV12, + IMAGE_FORMAT_NV21, + IMAGE_FORMAT_YV12, + IMAGE_FORMAT_YV21, + IMAGE_FORMAT_YUV_420_888, + IMAGE_FORMAT_ALPHA, + IMAGE_FORMAT_JPEG, + }) + @Retention(RetentionPolicy.SOURCE) + public @interface ImageFormat {} + + public static final int IMAGE_FORMAT_UNKNOWN = 0; + public static final int IMAGE_FORMAT_RGBA = 1; + public static final int IMAGE_FORMAT_RGB = 2; + public static final int IMAGE_FORMAT_NV12 = 3; + public static final int IMAGE_FORMAT_NV21 = 4; + public static final int IMAGE_FORMAT_YV12 = 5; + public static final int IMAGE_FORMAT_YV21 = 6; + public static final int IMAGE_FORMAT_YUV_420_888 = 7; + public static final int IMAGE_FORMAT_ALPHA = 8; + public static final int IMAGE_FORMAT_JPEG = 9; + + /** Specifies the image container type. Would be useful for choosing extractors. */ + @IntDef({ + STORAGE_TYPE_BITMAP, + STORAGE_TYPE_BYTEBUFFER, + STORAGE_TYPE_MEDIA_IMAGE, + STORAGE_TYPE_IMAGE_PROXY, + }) + @Retention(RetentionPolicy.SOURCE) + public @interface StorageType {} + + public static final int STORAGE_TYPE_BITMAP = 1; + public static final int STORAGE_TYPE_BYTEBUFFER = 2; + public static final int STORAGE_TYPE_MEDIA_IMAGE = 3; + public static final int STORAGE_TYPE_IMAGE_PROXY = 4; + + /** + * Returns a list of supported image properties for this {@link Image}. + * + *

Currently {@link Image} only supports a single storage type, so the size of the returned list will + * always be 1. + * + * @see ImageProperties + */ + public List<ImageProperties> getContainedImageProperties() { + return Collections.singletonList(getContainer().getImageProperties()); + } + + /** Returns the timestamp attached to the image. */ + long getTimestamp() { + return timestamp; + } + + /** Returns the width of the image. */ + public int getWidth() { + return width; + } + + /** Returns the height of the image. */ + public int getHeight() { + return height; + } + + /** Acquires a reference on this {@link Image}. This will increase the reference count by 1. */ + private synchronized void acquire() { + referenceCount += 1; + } + + /** + * Removes a reference that was previously acquired or initialized. + * + *

When an {@link Image} is created, its reference count is 1. + * + *

When the reference count becomes 0, it will release the resource under the hood. + */ + @Override + // TODO: Create an internal flag to indicate image is closed, or use referenceCount + public synchronized void close() { + referenceCount -= 1; + if (referenceCount == 0) { + for (ImageContainer imageContainer : containerMap.values()) { + imageContainer.close(); + } + } + } + + /** Advanced API access for {@link Image}. */ + static final class Internal { + + /** + * Acquires a reference on this {@link Image}. This will increase the reference count by 1. + * + *

This method is mostly useful for an image consumer to acquire a reference so that the image + * resource will not be closed accidentally. Image creators normally do not need to call this + * method. + * + *
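A sketch of that consumer-side pattern; note that getInternal() and acquire() are package-private, so code like this only works from inside the framework.image package, and the executor and handleImage() hook are placeholders:

    final class AsyncImageConsumer implements ImageConsumer {
      private final java.util.concurrent.Executor executor;

      AsyncImageConsumer(java.util.concurrent.Executor executor) {
        this.executor = executor;
      }

      @Override
      public void onNewImage(Image image) {
        image.getInternal().acquire();  // keep the image alive beyond this callback
        executor.execute(
            () -> {
              try {
                handleImage(image);  // placeholder for asynchronous processing
              } finally {
                image.close();  // drop the extra reference once processing is done
              }
            });
      }

      private void handleImage(Image image) {}
    }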

The reference count is 1 when {@link Image} is created. Developer can call {@link + * #close()} to indicate it doesn't need this {@link Image} anymore. + * + * @see #close() + */ + void acquire() { + image.acquire(); + } + + private final Image image; + + // Only Image creates the internal helper. + private Internal(Image image) { + this.image = image; + } + } + + /** Gets {@link Internal} object which contains internal APIs. */ + Internal getInternal() { + return new Internal(this); + } + + private final Map containerMap; + private final long timestamp; + private final int width; + private final int height; + + private int referenceCount; + + /** Constructs an {@link Image} with a built container. */ + Image(ImageContainer container, long timestamp, int width, int height) { + this.containerMap = new HashMap<>(); + containerMap.put(container.getImageProperties(), container); + this.timestamp = timestamp; + this.width = width; + this.height = height; + this.referenceCount = 1; + } + + /** + * Gets one available container. + * + * @return the current container. + */ + ImageContainer getContainer() { + // According to the design, in the future we will support multiple containers in one image. + // Currently just return the original container. + // TODO: Cache multiple containers in Image. + return containerMap.values().iterator().next(); + } + + /** + * Gets container from required {@code storageType}. Returns {@code null} if not existed. + * + *

If there are multiple containers with required {@code storageType}, returns the first one. + */ + @Nullable + ImageContainer getContainer(@StorageType int storageType) { + for (Entry entry : containerMap.entrySet()) { + if (entry.getKey().getStorageType() == storageType) { + return entry.getValue(); + } + } + return null; + } + + /** Gets container from required {@code imageProperties}. Returns {@code null} if non existed. */ + @Nullable + ImageContainer getContainer(ImageProperties imageProperties) { + return containerMap.get(imageProperties); + } + + /** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */ + boolean addContainer(ImageContainer container) { + ImageProperties imageProperties = container.getImageProperties(); + if (containerMap.containsKey(imageProperties)) { + return false; + } + containerMap.put(imageProperties, container); + return true; + } +} diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ImageConsumer.java b/mediapipe/java/com/google/mediapipe/framework/image/ImageConsumer.java new file mode 100644 index 000000000..18eed68c6 --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/framework/image/ImageConsumer.java @@ -0,0 +1,27 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package com.google.mediapipe.framework.image; + +/** Lightweight abstraction for an object that can receive {@link Image} */ +public interface ImageConsumer { + + /** + * Called when an {@link Image} is available. + * + *

The argument is only guaranteed to be available until this method returns. if you need to + * extend its life time, acquire it, then release it when done. + */ + void onNewImage(Image image); +} diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ImageContainer.java b/mediapipe/java/com/google/mediapipe/framework/image/ImageContainer.java new file mode 100644 index 000000000..727ec0893 --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/framework/image/ImageContainer.java @@ -0,0 +1,25 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.mediapipe.framework.image; + +/** Manages internal image data storage. The interface is package-private. */ +interface ImageContainer { + /** Returns the properties of the contained image. */ + ImageProperties getImageProperties(); + + /** Close the image container and releases the image resource inside. */ + void close(); +} diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ImageProducer.java b/mediapipe/java/com/google/mediapipe/framework/image/ImageProducer.java new file mode 100644 index 000000000..4f3641d6f --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/framework/image/ImageProducer.java @@ -0,0 +1,22 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package com.google.mediapipe.framework.image; + +/** Lightweight abstraction for an object that produce {@link Image} */ +public interface ImageProducer { + + /** Sets the consumer that receives the {@link Image}. */ + void setImageConsumer(ImageConsumer imageConsumer); +} diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ImageProperties.java b/mediapipe/java/com/google/mediapipe/framework/image/ImageProperties.java new file mode 100644 index 000000000..e33b33e7f --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/framework/image/ImageProperties.java @@ -0,0 +1,80 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.mediapipe.framework.image; + +import com.google.auto.value.AutoValue; +import com.google.auto.value.extension.memoized.Memoized; +import com.google.mediapipe.framework.image.Image.ImageFormat; +import com.google.mediapipe.framework.image.Image.StorageType; + +/** Groups a set of properties to describe how an image is stored. */ +@AutoValue +public abstract class ImageProperties { + + /** + * Gets the pixel format of the image. + * + * @see Image.ImageFormat + */ + @ImageFormat + public abstract int getImageFormat(); + + /** + * Gets the storage type of the image. + * + * @see Image.StorageType + */ + @StorageType + public abstract int getStorageType(); + + @Memoized + @Override + public abstract int hashCode(); + + /** + * Creates a builder of {@link ImageProperties}. + * + * @see ImageProperties.Builder + */ + static Builder builder() { + return new AutoValue_ImageProperties.Builder(); + } + + /** Builds a {@link ImageProperties}. */ + @AutoValue.Builder + abstract static class Builder { + + /** + * Sets the {@link Image.ImageFormat}. + * + * @see ImageProperties#getImageFormat + */ + abstract Builder setImageFormat(@ImageFormat int value); + + /** + * Sets the {@link Image.StorageType}. + * + * @see ImageProperties#getStorageType + */ + abstract Builder setStorageType(@StorageType int value); + + /** Builds the {@link ImageProperties}. */ + abstract ImageProperties build(); + } + + // Hide the constructor. + ImageProperties() {} +} diff --git a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageBuilder.java b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageBuilder.java new file mode 100644 index 000000000..e351a87fd --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageBuilder.java @@ -0,0 +1,62 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.mediapipe.framework.image; + +import android.os.Build.VERSION_CODES; +import androidx.annotation.RequiresApi; + +/** + * Builds {@link Image} from {@link android.media.Image}. + * + *

Once an {@link android.media.Image} is passed in, you should not modify its content, in order + * to keep data integrity. + * + *
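A small round-trip sketch (the helper name is invented; both the builder constructor and MediaImageExtractor.extract appear elsewhere in this patch):

    static android.media.Image wrapAndUnwrap(android.media.Image mediaImage) {
      // Wrap without copying; the android.media.Image must stay unmodified while wrapped.
      Image image = new MediaImageBuilder(mediaImage).build();
      // MediaImageExtractor returns the exact android.media.Image that was passed in.
      return MediaImageExtractor.extract(image);
    }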

Use {@link MediaImageExtractor} to get {@link android.media.Image} you passed in. + */ +@RequiresApi(VERSION_CODES.KITKAT) +public class MediaImageBuilder { + + // Mandatory fields. + private final android.media.Image mediaImage; + + // Optional fields. + private long timestamp; + + /** + * Creates the builder with a mandatory {@link android.media.Image}. + * + * @param mediaImage image data object. + */ + public MediaImageBuilder(android.media.Image mediaImage) { + this.mediaImage = mediaImage; + this.timestamp = 0; + } + + /** Sets value for {@link Image#getTimestamp()}. */ + MediaImageBuilder setTimestamp(long timestamp) { + this.timestamp = timestamp; + return this; + } + + /** Builds an {@link Image} instance. */ + public Image build() { + return new Image( + new MediaImageContainer(mediaImage), + timestamp, + mediaImage.getWidth(), + mediaImage.getHeight()); + } +} diff --git a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageContainer.java b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageContainer.java new file mode 100644 index 000000000..144b64def --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageContainer.java @@ -0,0 +1,73 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package com.google.mediapipe.framework.image; + +import android.os.Build; +import android.os.Build.VERSION; +import android.os.Build.VERSION_CODES; +import androidx.annotation.RequiresApi; +import com.google.mediapipe.framework.image.Image.ImageFormat; + +@RequiresApi(VERSION_CODES.KITKAT) +class MediaImageContainer implements ImageContainer { + + private final android.media.Image mediaImage; + private final ImageProperties properties; + + public MediaImageContainer(android.media.Image mediaImage) { + this.mediaImage = mediaImage; + this.properties = + ImageProperties.builder() + .setStorageType(Image.STORAGE_TYPE_MEDIA_IMAGE) + .setImageFormat(convertFormatCode(mediaImage.getFormat())) + .build(); + } + + public android.media.Image getImage() { + return mediaImage; + } + + @Override + public ImageProperties getImageProperties() { + return properties; + } + + @Override + public void close() { + mediaImage.close(); + } + + @ImageFormat + static int convertFormatCode(int graphicsFormat) { + // We only cover the format mentioned in + // https://developer.android.com/reference/android/media/Image#getFormat() + if (VERSION.SDK_INT >= Build.VERSION_CODES.M) { + if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) { + return Image.IMAGE_FORMAT_RGBA; + } else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) { + return Image.IMAGE_FORMAT_RGB; + } + } + switch (graphicsFormat) { + case android.graphics.ImageFormat.JPEG: + return Image.IMAGE_FORMAT_JPEG; + case android.graphics.ImageFormat.YUV_420_888: + return Image.IMAGE_FORMAT_YUV_420_888; + default: + return Image.IMAGE_FORMAT_UNKNOWN; + } + } +} diff --git a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageExtractor.java b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageExtractor.java new file mode 100644 index 000000000..718cb471f --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageExtractor.java @@ -0,0 +1,49 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.mediapipe.framework.image; + +import android.os.Build.VERSION_CODES; +import androidx.annotation.RequiresApi; + +/** + * Utility for extracting {@link android.media.Image} from {@link Image}. + * + *

Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_MEDIA_IMAGE}, + * otherwise {@link IllegalArgumentException} will be thrown. + */ +@RequiresApi(VERSION_CODES.KITKAT) +public class MediaImageExtractor { + + private MediaImageExtractor() {} + + /** + * Extracts a {@link android.media.Image} from an {@link Image}. Currently it only works for + * {@link Image} that built from {@link MediaImageBuilder}. + * + * @param image the image to extract {@link android.media.Image} from. + * @return {@link android.media.Image} that stored in {@link Image}. + * @throws IllegalArgumentException if the extraction failed. + */ + public static android.media.Image extract(Image image) { + ImageContainer container; + if ((container = image.getContainer(Image.STORAGE_TYPE_MEDIA_IMAGE)) != null) { + return ((MediaImageContainer) container).getImage(); + } + throw new IllegalArgumentException( + "Extract Media Image from an Image created by objects other than Media Image" + + " is not supported"); + } +} diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc index f391d0daf..84df89260 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc @@ -73,9 +73,14 @@ JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( // TODO: get the graph's main context from the packet context? // Or clean up in some other way? if (context_for_deletion) { - token = new mediapipe::GlSyncToken( - mediapipe::GlContext::CreateSyncTokenForCurrentExternalContext( - context_for_deletion)); + auto sync = mediapipe::GlContext::CreateSyncTokenForCurrentExternalContext( + context_for_deletion); + // A Java handle to a token is a raw pointer to a std::shared_ptr on the + // heap, cast to a long. If the shared_ptr itself is null, leave the token + // null too. 
+ if (sync) { + token = new mediapipe::GlSyncToken(std::move(sync)); + } } return reinterpret_cast(token); } diff --git a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl index 4ffcee042..ed1686954 100644 --- a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl +++ b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl @@ -145,6 +145,7 @@ EOF "//mediapipe/java/com/google/mediapipe/components:android_components", "//mediapipe/java/com/google/mediapipe/components:android_camerax_helper", "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/java/com/google/mediapipe/glutil", "//third_party:androidx_annotation", "//third_party:androidx_appcompat", diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc index 0536b1116..db939f341 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc @@ -76,9 +76,10 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) { options_proto->mutable_base_options()->Swap(base_options_proto.get()); options_proto->mutable_base_options()->set_use_stream_mode( options->running_mode == core::RunningMode::AUDIO_STREAM); - auto classifier_options_proto = std::make_unique( - components::ConvertClassifierOptionsToProto( - &(options->classifier_options))); + auto classifier_options_proto = + std::make_unique( + components::ConvertClassifierOptionsToProto( + &(options->classifier_options))); options_proto->mutable_classifier_options()->Swap( classifier_options_proto.get()); if (options->sample_rate > 0) { diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc index 52af20cb6..0f40b59a4 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc @@ -136,6 +136,11 @@ void ConfigureAudioToTensorCalculator( // options { // [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierOptions.ext] // { +// base_options { +// model_asset { +// file_name: "/path/to/model.tflite" +// } +// } // max_results: 4 // score_threshold: 0.5 // category_allowlist: "foo" @@ -225,15 +230,17 @@ class AudioClassifierGraph : public core::ModelTaskGraph { // Adds inference subgraph and connects its input stream to the output // tensors produced by the AudioToTensorCalculator. - auto& inference = AddInference(model_resources, graph); + auto& inference = AddInference( + model_resources, task_options.base_options().acceleration(), graph); audio_to_tensor.Out(kTensorsTag) >> inference.In(kTensorsTag); // Adds postprocessing calculators and connects them to the graph output. 
- auto& postprocessing = - graph.AddNode("mediapipe.tasks.ClassificationPostprocessingSubgraph"); + auto& postprocessing = graph.AddNode( + "mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( model_resources, task_options.classifier_options(), - &postprocessing.GetOptions())); + &postprocessing.GetOptions< + tasks::components::ClassificationPostprocessingOptions>())); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); // Time aggregation is only needed for performing audio classification on diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc index c59671b77..dd56c4ff1 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc @@ -37,7 +37,6 @@ limitations under the License. #include "mediapipe/tasks/cc/audio/core/running_mode.h" #include "mediapipe/tasks/cc/audio/utils/test_utils.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/containers/category.pb.h" #include "mediapipe/tasks/cc/components/containers/classifications.pb.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" @@ -168,7 +167,7 @@ class CreateFromOptionsTest : public tflite_shims::testing::Test {}; TEST_F(CreateFromOptionsTest, SucceedsForModelWithMetadata) { auto options = std::make_unique(); options->classifier_options.max_results = 3; - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kModelWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr audio_classifier, AudioClassifier::Create(std::move(options))); @@ -192,7 +191,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) { auto options = std::make_unique(); options->classifier_options.max_results = 0; - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kModelWithMetadata); StatusOr> audio_classifier_or = AudioClassifier::Create(std::move(options)); @@ -208,7 +207,7 @@ TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) { TEST_F(CreateFromOptionsTest, FailsWithCombinedAllowlistAndDenylist) { auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kModelWithMetadata); options->classifier_options.category_allowlist.push_back("foo"); options->classifier_options.category_denylist.push_back("bar"); @@ -226,7 +225,7 @@ TEST_F(CreateFromOptionsTest, FailsWithCombinedAllowlistAndDenylist) { TEST_F(CreateFromOptionsTest, FailsWithMissingMetadata) { auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); StatusOr> audio_classifier_or = AudioClassifier::Create(std::move(options)); @@ -242,7 +241,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingMetadata) { TEST_F(CreateFromOptionsTest, FailsWithMissingCallback) { auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); options->running_mode = core::RunningMode::AUDIO_STREAM; options->sample_rate = 16000; @@ -260,7 
+259,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingCallback) { TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) { auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); options->result_callback = [](absl::StatusOr status_or_result) {}; @@ -279,7 +278,7 @@ TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) { TEST_F(CreateFromOptionsTest, FailsWithMissingDefaultInputAudioSampleRate) { auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); options->running_mode = core::RunningMode::AUDIO_STREAM; options->result_callback = @@ -301,7 +300,7 @@ class ClassifyTest : public tflite_shims::testing::Test {}; TEST_F(ClassifyTest, Succeeds) { auto audio_buffer = GetAudioData(k16kTestWavFilename); auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kModelWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr audio_classifier, AudioClassifier::Create(std::move(options))); @@ -315,7 +314,7 @@ TEST_F(ClassifyTest, Succeeds) { TEST_F(ClassifyTest, SucceedsWithResampling) { auto audio_buffer = GetAudioData(k48kTestWavFilename); auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kModelWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr audio_classifier, AudioClassifier::Create(std::move(options))); @@ -330,7 +329,7 @@ TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) { auto audio_buffer_16k_hz = GetAudioData(k16kTestWavFilename); auto audio_buffer_48k_hz = GetAudioData(k48kTestWavFilename); auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kModelWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr audio_classifier, AudioClassifier::Create(std::move(options))); @@ -349,7 +348,7 @@ TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) { TEST_F(ClassifyTest, SucceedsWithInsufficientData) { auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kModelWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr audio_classifier, AudioClassifier::Create(std::move(options))); @@ -374,7 +373,7 @@ TEST_F(ClassifyTest, SucceedsWithInsufficientData) { TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) { auto audio_buffer = GetAudioData(k16kTestWavForTwoHeadsFilename); auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kTwoHeadsModelWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr audio_classifier, AudioClassifier::Create(std::move(options))); @@ -388,7 +387,7 @@ TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) { TEST_F(ClassifyTest, SucceedsWithMultiheadsModelAndResampling) { auto audio_buffer = GetAudioData(k44kTestWavForTwoHeadsFilename); auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kTwoHeadsModelWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr audio_classifier, 
AudioClassifier::Create(std::move(options))); @@ -404,7 +403,7 @@ TEST_F(ClassifyTest, auto audio_buffer_44k_hz = GetAudioData(k44kTestWavForTwoHeadsFilename); auto audio_buffer_16k_hz = GetAudioData(k16kTestWavForTwoHeadsFilename); auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kTwoHeadsModelWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr audio_classifier, AudioClassifier::Create(std::move(options))); @@ -424,7 +423,7 @@ TEST_F(ClassifyTest, TEST_F(ClassifyTest, SucceedsWithMaxResultOption) { auto audio_buffer = GetAudioData(k48kTestWavFilename); auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kModelWithMetadata); options->classifier_options.max_results = 1; options->classifier_options.score_threshold = 0.35f; @@ -440,7 +439,7 @@ TEST_F(ClassifyTest, SucceedsWithMaxResultOption) { TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) { auto audio_buffer = GetAudioData(k48kTestWavFilename); auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kModelWithMetadata); options->classifier_options.score_threshold = 0.35f; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr audio_classifier, @@ -455,7 +454,7 @@ TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) { TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) { auto audio_buffer = GetAudioData(k48kTestWavFilename); auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kModelWithMetadata); options->classifier_options.score_threshold = 0.1f; options->classifier_options.category_allowlist.push_back("Speech"); @@ -471,7 +470,7 @@ TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) { TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) { auto audio_buffer = GetAudioData(k48kTestWavFilename); auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kModelWithMetadata); options->classifier_options.score_threshold = 0.9f; options->classifier_options.category_denylist.push_back("Speech"); @@ -499,7 +498,7 @@ TEST_F(ClassifyAsyncTest, Succeeds) { constexpr int kSampleRateHz = 48000; auto audio_buffer = GetAudioData(k48kTestWavFilename); auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kModelWithMetadata); options->classifier_options.max_results = 1; options->classifier_options.score_threshold = 0.3f; @@ -529,7 +528,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) { constexpr int kSampleRateHz = 48000; auto audio_buffer = GetAudioData(k48kTestWavFilename); auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kModelWithMetadata); options->classifier_options.max_results = 1; options->classifier_options.score_threshold = 0.3f; diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD index 1bb26f5c1..7b1952e06 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD @@ -24,7 +24,7 @@ 
mediapipe_proto_library( deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components:classifier_options_proto", + "//mediapipe/tasks/cc/components/proto:classifier_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.proto b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.proto index 9dd65a265..a76ccdcab 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.proto +++ b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.proto @@ -18,7 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.audio.audio_classifier.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/classifier_options.proto"; +import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; message AudioClassifierOptions { @@ -31,7 +31,7 @@ message AudioClassifierOptions { // Options for configuring the classifier behavior, such as score threshold, // number of results, etc. - optional ClassifierOptions classifier_options = 2; + optional components.proto.ClassifierOptions classifier_options = 2; // The default sample rate of the input audio. Must be set when the // AudioClassifier is configured to process audio stream data. diff --git a/mediapipe/tasks/cc/components/BUILD b/mediapipe/tasks/cc/components/BUILD index 1f9fc607b..4de32ce9b 100644 --- a/mediapipe/tasks/cc/components/BUILD +++ b/mediapipe/tasks/cc/components/BUILD @@ -35,6 +35,8 @@ cc_library( deps = [ ":image_preprocessing_options_cc_proto", "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/calculators/image:image_clone_calculator", + "//mediapipe/calculators/image:image_clone_calculator_cc_proto", "//mediapipe/calculators/image:image_properties_calculator", "//mediapipe/calculators/tensor:image_to_tensor_calculator", "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", @@ -56,21 +58,11 @@ cc_library( # TODO: Enable this test -mediapipe_proto_library( - name = "segmenter_options_proto", - srcs = ["segmenter_options.proto"], -) - cc_library( name = "classifier_options", srcs = ["classifier_options.cc"], hdrs = ["classifier_options.h"], - deps = [":classifier_options_cc_proto"], -) - -mediapipe_proto_library( - name = "classifier_options_proto", - srcs = ["classifier_options.proto"], + deps = ["//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto"], ) mediapipe_proto_library( @@ -81,6 +73,7 @@ mediapipe_proto_library( "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto", ], ) @@ -90,7 +83,6 @@ cc_library( hdrs = ["classification_postprocessing.h"], deps = [ ":classification_postprocessing_options_cc_proto", - ":classifier_options_cc_proto", "//mediapipe/calculators/core:split_vector_calculator", "//mediapipe/calculators/core:split_vector_calculator_cc_proto", "//mediapipe/calculators/tensor:tensors_dequantization_calculator", @@ -104,7 +96,12 @@ cc_library( "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator", 
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto", + "//mediapipe/tasks/cc/components/calculators:score_calibration_utils", "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/components/utils:source_or_node_output", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/metadata:metadata_extractor", "//mediapipe/tasks/metadata:metadata_schema_cc", @@ -119,3 +116,38 @@ cc_library( ], alwayslink = 1, ) + +cc_library( + name = "embedder_options", + srcs = ["embedder_options.cc"], + hdrs = ["embedder_options.h"], + deps = ["//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto"], +) + +cc_library( + name = "embedding_postprocessing_graph", + srcs = ["embedding_postprocessing_graph.cc"], + hdrs = ["embedding_postprocessing_graph.h"], + deps = [ + "//mediapipe/calculators/tensor:tensors_dequantization_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/tool:options_map", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", + "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto", + "//mediapipe/tasks/cc/components/proto:embedding_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/utils:source_or_node_output", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/components/calculators/BUILD b/mediapipe/tasks/cc/components/calculators/BUILD index 8b553dea4..13ca6b496 100644 --- a/mediapipe/tasks/cc/components/calculators/BUILD +++ b/mediapipe/tasks/cc/components/calculators/BUILD @@ -113,3 +113,66 @@ cc_test( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "end_loop_calculator", + srcs = ["end_loop_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/calculators/core:end_loop_calculator", + "//mediapipe/framework:calculator_context", + "//mediapipe/framework:calculator_contract", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:collection_item_id", + "//mediapipe/framework:packet", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + ], + alwayslink = 1, +) + +mediapipe_proto_library( + name = "tensors_to_embeddings_calculator_proto", + srcs = ["tensors_to_embeddings_calculator.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/components/proto:embedder_options_proto", + ], +) + +cc_library( + name = "tensors_to_embeddings_calculator", + srcs = 
["tensors_to_embeddings_calculator.cc"], + deps = [ + ":tensors_to_embeddings_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:ret_check", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + ], + alwayslink = 1, +) + +cc_test( + name = "tensors_to_embeddings_calculator_test", + srcs = ["tensors_to_embeddings_calculator_test.cc"], + deps = [ + ":tensors_to_embeddings_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework:packet", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "@com_google_absl//absl/status", + ], +) diff --git a/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc b/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc new file mode 100644 index 000000000..b688cda91 --- /dev/null +++ b/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc @@ -0,0 +1,29 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/calculators/core/end_loop_calculator.h" + +#include + +#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" + +// Specialized EndLoopCalculator for Tasks specific types. 
+namespace mediapipe::tasks { + +typedef EndLoopCalculator> + EndLoopClassificationResultCalculator; +REGISTER_CALCULATOR(::mediapipe::tasks::EndLoopClassificationResultCalculator); + +} // namespace mediapipe::tasks diff --git a/mediapipe/tasks/cc/components/calculators/tensor/BUILD b/mediapipe/tasks/cc/components/calculators/tensor/BUILD index de94724b6..6e4322a8f 100644 --- a/mediapipe/tasks/cc/components/calculators/tensor/BUILD +++ b/mediapipe/tasks/cc/components/calculators/tensor/BUILD @@ -25,7 +25,7 @@ mediapipe_proto_library( "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/framework/formats:image_format_proto", - "//mediapipe/tasks/cc/components:segmenter_options_proto", + "//mediapipe/tasks/cc/components/proto:segmenter_options_proto", "//mediapipe/util:label_map_proto", ], ) @@ -45,7 +45,7 @@ cc_library( "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:status", - "//mediapipe/tasks/cc/components:segmenter_options_cc_proto", + "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/vision/utils:image_utils", "//mediapipe/util:label_map_cc_proto", "@com_google_absl//absl/status", diff --git a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc b/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc index 4ea41b163..40585848f 100644 --- a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc @@ -36,19 +36,22 @@ limitations under the License. #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" -#include "mediapipe/tasks/cc/components/segmenter_options.pb.h" +#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/util/label_map.pb.h" namespace mediapipe { -namespace api2 { +namespace tasks { namespace { using ::mediapipe::Image; using ::mediapipe::ImageFrameSharedPtr; -using ::mediapipe::tasks::SegmenterOptions; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Node; +using ::mediapipe::api2::Output; using ::mediapipe::tasks::TensorsToSegmentationCalculatorOptions; +using ::mediapipe::tasks::components::proto::SegmenterOptions; using ::mediapipe::tasks::vision::GetImageLikeTensorShape; using ::mediapipe::tasks::vision::Shape; @@ -254,7 +257,7 @@ std::vector TensorsToSegmentationCalculator::GetSegmentationResult( return segmented_masks; } -MEDIAPIPE_REGISTER_NODE(TensorsToSegmentationCalculator); +MEDIAPIPE_REGISTER_NODE(::mediapipe::tasks::TensorsToSegmentationCalculator); -} // namespace api2 +} // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.proto b/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.proto index 4691c283e..c26cf910a 100644 --- a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.proto +++ b/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.proto @@ -18,7 +18,7 @@ syntax = "proto2"; package mediapipe.tasks; import "mediapipe/framework/calculator.proto"; -import 
"mediapipe/tasks/cc/components/segmenter_options.proto"; +import "mediapipe/tasks/cc/components/proto/segmenter_options.proto"; import "mediapipe/util/label_map.proto"; message TensorsToSegmentationCalculatorOptions { @@ -26,7 +26,7 @@ message TensorsToSegmentationCalculatorOptions { optional TensorsToSegmentationCalculatorOptions ext = 458105876; } - optional SegmenterOptions segmenter_options = 1; + optional components.proto.SegmenterOptions segmenter_options = 1; // Identifying information for each classification label. map label_items = 2; diff --git a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator_test.cc b/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator_test.cc index 72c217fb2..55e46d72b 100644 --- a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator_test.cc +++ b/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator_test.cc @@ -117,7 +117,7 @@ TEST(TensorsToSegmentationCalculatorTest, FailsInvalidTensorDimensionOne) { CalculatorRunner runner( mediapipe::ParseTextProtoOrDie( R"pb( - calculator: "TensorsToSegmentationCalculator" + calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" output_stream: "SEGMENTATION:segmentation" options { @@ -144,7 +144,7 @@ TEST(TensorsToSegmentationCalculatorTest, FailsInvalidTensorDimensionFive) { CalculatorRunner runner( mediapipe::ParseTextProtoOrDie( R"pb( - calculator: "TensorsToSegmentationCalculator" + calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" output_stream: "SEGMENTATION:segmentation" options { @@ -172,7 +172,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithSoftmax) { CalculatorRunner runner( mediapipe::ParseTextProtoOrDie( R"pb( - calculator: "TensorsToSegmentationCalculator" + calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" output_stream: "SEGMENTATION:0:segmented_mask_0" output_stream: "SEGMENTATION:1:segmented_mask_1" @@ -217,7 +217,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithNone) { CalculatorRunner runner( mediapipe::ParseTextProtoOrDie( R"pb( - calculator: "TensorsToSegmentationCalculator" + calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" output_stream: "SEGMENTATION:0:segmented_mask_0" output_stream: "SEGMENTATION:1:segmented_mask_1" @@ -258,7 +258,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithSigmoid) { CalculatorRunner runner( mediapipe::ParseTextProtoOrDie( R"pb( - calculator: "TensorsToSegmentationCalculator" + calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" output_stream: "SEGMENTATION:0:segmented_mask_0" output_stream: "SEGMENTATION:1:segmented_mask_1" @@ -300,7 +300,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMask) { CalculatorRunner runner( mediapipe::ParseTextProtoOrDie( R"pb( - calculator: "TensorsToSegmentationCalculator" + calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" output_stream: "SEGMENTATION:segmentation" options { @@ -333,7 +333,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMaskResize) { CalculatorRunner runner( mediapipe::ParseTextProtoOrDie( R"pb( - calculator: "TensorsToSegmentationCalculator" + calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" 
input_stream: "OUTPUT_SIZE:size" output_stream: "SEGMENTATION:segmentation" diff --git a/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.cc b/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.cc new file mode 100644 index 000000000..05b3e1f3f --- /dev/null +++ b/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.cc @@ -0,0 +1,158 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" +#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" + +namespace mediapipe { +namespace api2 { + +namespace { + +using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry; +using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; + +// Computes the inverse L2 norm of the provided array of values. Returns 1.0 in +// case all values are 0. +float GetInverseL2Norm(const float* values, int size) { + float squared_l2_norm = 0.0f; + for (int i = 0; i < size; ++i) { + squared_l2_norm += values[i] * values[i]; + } + float inv_l2_norm = 1.0f; + if (squared_l2_norm > 0.0f) { + inv_l2_norm = 1.0f / std::sqrt(squared_l2_norm); + } + return inv_l2_norm; +} + +} // namespace + +// Converts tensors into an EmbeddingResult object, performing optional +// L2-normalization and scalar-quantization on-the-fly if required through the +// options. +// +// Input: +// TENSORS - std::vector +// A vector of one or more Tensors of type kFloat32. +// Output: +// EMBEDDINGS - EmbeddingResult +// The contents of the input tensors converted into an EmbeddingResult +// proto. 
+class TensorsToEmbeddingsCalculator : public Node { + public: + static constexpr Input> kTensorsIn{"TENSORS"}; + static constexpr Output kEmbeddingsOut{"EMBEDDING_RESULT"}; + MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kEmbeddingsOut); + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + + private: + bool l2_normalize_; + bool quantize_; + std::vector head_names_; + + void FillFloatEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry); + void FillQuantizedEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry); +}; + +absl::Status TensorsToEmbeddingsCalculator::Open(CalculatorContext* cc) { + auto options = cc->Options(); + l2_normalize_ = options.embedder_options().l2_normalize(); + quantize_ = options.embedder_options().quantize(); + if (!options.head_names().empty()) { + head_names_.assign(options.head_names().begin(), + options.head_names().end()); + } + return absl::OkStatus(); +} + +absl::Status TensorsToEmbeddingsCalculator::Process(CalculatorContext* cc) { + EmbeddingResult result; + const auto& tensors = *kTensorsIn(cc); + if (!head_names_.empty() && tensors.size() != head_names_.size()) { + return absl::InvalidArgumentError(absl::StrFormat( + "Mismatch between number of provided head names (%d) and number " + "of input tensors (%d).", + head_names_.size(), tensors.size())); + } + for (int i = 0; i < tensors.size(); ++i) { + const auto& tensor = tensors[i]; + RET_CHECK(tensor.element_type() == Tensor::ElementType::kFloat32); + auto* embeddings = result.add_embeddings(); + embeddings->set_head_index(i); + if (!head_names_.empty()) { + embeddings->set_head_name(head_names_[i]); + } + if (quantize_) { + FillQuantizedEmbeddingEntry(tensor, embeddings->add_entries()); + } else { + FillFloatEmbeddingEntry(tensor, embeddings->add_entries()); + } + } + kEmbeddingsOut(cc).Send(result); + return absl::OkStatus(); +} + +void TensorsToEmbeddingsCalculator::FillFloatEmbeddingEntry( + const Tensor& tensor, EmbeddingEntry* entry) { + int size = tensor.shape().num_elements(); + auto tensor_view = tensor.GetCpuReadView(); + const float* tensor_buffer = tensor_view.buffer(); + float inv_l2_norm = + l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f; + auto* float_embedding = entry->mutable_float_embedding(); + for (int i = 0; i < size; ++i) { + float_embedding->add_values(tensor_buffer[i] * inv_l2_norm); + } +} + +void TensorsToEmbeddingsCalculator::FillQuantizedEmbeddingEntry( + const Tensor& tensor, EmbeddingEntry* entry) { + int size = tensor.shape().num_elements(); + auto tensor_view = tensor.GetCpuReadView(); + const float* tensor_buffer = tensor_view.buffer(); + float inv_l2_norm = + l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f; + auto* values = entry->mutable_quantized_embedding()->mutable_values(); + values->resize(size); + for (int i = 0; i < size; ++i) { + // Normalize. + float normalized = tensor_buffer[i] * inv_l2_norm; + // Quantize. + int unclamped_value = static_cast(roundf(normalized * 128)); + // Clamp and assign. 
+ (*values)[i] = + static_cast(std::max(-128, std::min(unclamped_value, 127))); + } +} + +MEDIAPIPE_REGISTER_NODE(TensorsToEmbeddingsCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto b/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto new file mode 100644 index 000000000..2f088c503 --- /dev/null +++ b/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto @@ -0,0 +1,35 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/components/proto/embedder_options.proto"; + +message TensorsToEmbeddingsCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional TensorsToEmbeddingsCalculatorOptions ext = 474762326; + } + + // The embedder options defining whether to L2-normalize or scalar-quantize + // the outputs. + optional mediapipe.tasks.components.proto.EmbedderOptions embedder_options = + 1; + + // The embedder head names. + repeated string head_names = 2; +} diff --git a/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator_test.cc b/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator_test.cc new file mode 100644 index 000000000..b6d319121 --- /dev/null +++ b/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator_test.cc @@ -0,0 +1,249 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/status/status.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" + +namespace mediapipe { +namespace { + +using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; +using ::testing::HasSubstr; +using Node = ::mediapipe::CalculatorGraphConfig::Node; + +// Builds the graph and feeds inputs. 
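+// Each input vector becomes a float32 Tensor of shape {1, N}, and all tensors
+// are fed to the runner as a single "TENSORS" packet at Timestamp(0).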
+void BuildGraph(CalculatorRunner* runner, + std::vector> tensors) { + auto inputs = std::make_unique>(); + for (const auto& tensor : tensors) { + inputs->emplace_back(Tensor::ElementType::kFloat32, + Tensor::Shape{1, static_cast(tensor.size())}); + auto view = inputs->back().GetCpuWriteView(); + float* buffer = view.buffer(); + ASSERT_NE(buffer, nullptr); + for (int i = 0; i < tensor.size(); ++i) { + buffer[i] = tensor[i]; + } + } + auto& input_packets = runner->MutableInputs()->Tag("TENSORS").packets; + input_packets.push_back(Adopt(inputs.release()).At(Timestamp(0))); +} + +TEST(TensorsToEmbeddingsCalculatorTest, FailsWithInvalidHeadNamesNumber) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "TensorsToEmbeddingsCalculator" + input_stream: "TENSORS:tensors" + output_stream: "EMBEDDING_RESULT:embeddings" + options { + [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { head_names: "foo" } + } + )pb")); + + BuildGraph(&runner, {{0.1, 0.2}, {0.2, 0.3}}); + auto status = runner.Run(); + + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), + HasSubstr("Mismatch between number of provided head names")); +} + +TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithoutHeadNames) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "TensorsToEmbeddingsCalculator" + input_stream: "TENSORS:tensors" + output_stream: "EMBEDDING_RESULT:embeddings" + options { + [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { + embedder_options { l2_normalize: false quantize: false } + } + } + )pb")); + + BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); + MP_ASSERT_OK(runner.Run()); + + const EmbeddingResult& result = runner.Outputs() + .Get("EMBEDDING_RESULT", 0) + .packets[0] + .Get(); + EXPECT_THAT( + result, + EqualsProto(ParseTextProtoOrDie( + R"pb(embeddings { + entries { float_embedding { values: 0.1 values: 0.2 } } + head_index: 0 + } + embeddings { + entries { float_embedding { values: -0.2 values: -0.3 } } + head_index: 1 + })pb"))); +} + +TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "TensorsToEmbeddingsCalculator" + input_stream: "TENSORS:tensors" + output_stream: "EMBEDDING_RESULT:embeddings" + options { + [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { + embedder_options { l2_normalize: false quantize: false } + head_names: "foo" + head_names: "bar" + } + } + )pb")); + + BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); + MP_ASSERT_OK(runner.Run()); + + const EmbeddingResult& result = runner.Outputs() + .Get("EMBEDDING_RESULT", 0) + .packets[0] + .Get(); + EXPECT_THAT( + result, + EqualsProto(ParseTextProtoOrDie( + R"pb(embeddings { + entries { float_embedding { values: 0.1 values: 0.2 } } + head_index: 0 + head_name: "foo" + } + embeddings { + entries { float_embedding { values: -0.2 values: -0.3 } } + head_index: 1 + head_name: "bar" + })pb"))); +} + +TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "TensorsToEmbeddingsCalculator" + input_stream: "TENSORS:tensors" + output_stream: "EMBEDDING_RESULT:embeddings" + options { + [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { + embedder_options { l2_normalize: true quantize: false } + } + } + )pb")); + + BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); + MP_ASSERT_OK(runner.Run()); + + const EmbeddingResult& result = runner.Outputs() + .Get("EMBEDDING_RESULT", 0) + .packets[0] + .Get(); + 
EXPECT_THAT( + result, + EqualsProto(ParseTextProtoOrDie( + R"pb(embeddings { + entries { + float_embedding { values: 0.44721356 values: 0.8944271 } + } + head_index: 0 + } + embeddings { + entries { + float_embedding { values: -0.5547002 values: -0.8320503 } + } + head_index: 1 + })pb"))); +} + +TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithQuantization) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "TensorsToEmbeddingsCalculator" + input_stream: "TENSORS:tensors" + output_stream: "EMBEDDING_RESULT:embeddings" + options { + [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { + embedder_options { l2_normalize: false quantize: true } + } + } + )pb")); + + BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); + MP_ASSERT_OK(runner.Run()); + + const EmbeddingResult& result = runner.Outputs() + .Get("EMBEDDING_RESULT", 0) + .packets[0] + .Get(); + EXPECT_THAT(result, + EqualsProto(ParseTextProtoOrDie( + R"pb(embeddings { + entries { + quantized_embedding { values: "\x0d\x1a" } # 13,26 + } + head_index: 0 + } + embeddings { + entries { + quantized_embedding { values: "\xe6\xda" } # -26,-38 + } + head_index: 1 + })pb"))); +} + +TEST(TensorsToEmbeddingsCalculatorTest, + SucceedsWithNormalizationAndQuantization) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "TensorsToEmbeddingsCalculator" + input_stream: "TENSORS:tensors" + output_stream: "EMBEDDING_RESULT:embeddings" + options { + [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { + embedder_options { l2_normalize: true quantize: true } + } + } + )pb")); + + BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); + MP_ASSERT_OK(runner.Run()); + + const EmbeddingResult& result = runner.Outputs() + .Get("EMBEDDING_RESULT", 0) + .packets[0] + .Get(); + EXPECT_THAT( + result, + EqualsProto(ParseTextProtoOrDie( + R"pb(embeddings { + entries { + quantized_embedding { values: "\x39\x72" } # 57,114 + } + head_index: 0 + } + embeddings { + entries { + quantized_embedding { values: "\xb9\x95" } # -71,-107 + } + head_index: 1 + })pb"))); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/classification_postprocessing.cc b/mediapipe/tasks/cc/components/classification_postprocessing.cc index fc28391bb..871476e8f 100644 --- a/mediapipe/tasks/cc/components/classification_postprocessing.cc +++ b/mediapipe/tasks/cc/components/classification_postprocessing.cc @@ -35,9 +35,12 @@ limitations under the License. #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" +#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" +#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h" #include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" @@ -47,6 +50,7 @@ limitations under the License. 
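For reference, the quantized-embedding byte strings expected by the TensorsToEmbeddingsCalculator tests above follow directly from the calculator's arithmetic: optionally L2-normalize, multiply by 128, round, then clamp to [-128, 127] and store each value as a byte. A small sketch reproducing the expected constants (assumed arithmetic only, not code from this change):

#include <algorithm>
#include <cmath>
#include <cstdio>

// Scalar quantization as in FillQuantizedEmbeddingEntry: round(x * 128), clamped to int8.
int Quantize(float x) {
  return std::max(-128, std::min(static_cast<int>(std::roundf(x * 128)), 127));
}

int main() {
  // Without normalization: {0.1, 0.2} -> {13, 26}, i.e. "\x0d\x1a".
  std::printf("%d %d\n", Quantize(0.1f), Quantize(0.2f));
  // With L2 normalization first: {0.1, 0.2} -> {0.4472, 0.8944} -> {57, 114}, i.e. "\x39\x72".
  const float inv_norm = 1.0f / std::sqrt(0.1f * 0.1f + 0.2f * 0.2f);
  std::printf("%d %d\n", Quantize(0.1f * inv_norm), Quantize(0.2f * inv_norm));
  return 0;
}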
namespace mediapipe { namespace tasks { +namespace components { namespace { @@ -57,18 +61,21 @@ using ::mediapipe::api2::Timestamp; using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::proto::ClassifierOptions; using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; using ::tflite::ProcessUnit; -using ::tflite::ProcessUnitOptions_ScoreThresholdingOptions; using ::tflite::TensorMetadata; using LabelItems = mediapipe::proto_ns::Map; +using TensorsSource = mediapipe::tasks::SourceOrNodeOutput>; constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); -constexpr char kTensorsTag[] = "TENSORS"; +constexpr char kCalibratedScoresTag[] = "CALIBRATED_SCORES"; constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; +constexpr char kScoresTag[] = "SCORES"; +constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTimestampsTag[] = "TIMESTAMPS"; // Performs sanity checks on provided ClassifierOptions. @@ -183,10 +190,10 @@ absl::StatusOr GetLabelItemsIfAny( absl::StatusOr GetScoreThreshold( const ModelMetadataExtractor& metadata_extractor, const TensorMetadata& tensor_metadata) { - ASSIGN_OR_RETURN( - const ProcessUnit* score_thresholding_process_unit, - metadata_extractor.FindFirstProcessUnit( - tensor_metadata, ProcessUnitOptions_ScoreThresholdingOptions)); + ASSIGN_OR_RETURN(const ProcessUnit* score_thresholding_process_unit, + metadata_extractor.FindFirstProcessUnit( + tensor_metadata, + tflite::ProcessUnitOptions_ScoreThresholdingOptions)); if (score_thresholding_process_unit == nullptr) { return kDefaultScoreThreshold; } @@ -230,8 +237,51 @@ absl::StatusOr> GetAllowOrDenyCategoryIndicesIfAny( return category_indices; } -// Fills in the TensorsToClassificationCalculatorOptions based on the classifier -// options and the (optional) output tensor metadata. +absl::Status ConfigureScoreCalibrationIfAny( + const ModelMetadataExtractor& metadata_extractor, int tensor_index, + ClassificationPostprocessingOptions* options) { + const auto* tensor_metadata = + metadata_extractor.GetOutputTensorMetadata(tensor_index); + if (tensor_metadata == nullptr) { + return absl::OkStatus(); + } + // Get ScoreCalibrationOptions, if any. + ASSIGN_OR_RETURN(const ProcessUnit* score_calibration_process_unit, + metadata_extractor.FindFirstProcessUnit( + *tensor_metadata, + tflite::ProcessUnitOptions_ScoreCalibrationOptions)); + if (score_calibration_process_unit == nullptr) { + return absl::OkStatus(); + } + auto* score_calibration_options = + score_calibration_process_unit->options_as_ScoreCalibrationOptions(); + // Get corresponding AssociatedFile. 
+ auto score_calibration_filename = + metadata_extractor.FindFirstAssociatedFileName( + *tensor_metadata, + tflite::AssociatedFileType_TENSOR_AXIS_SCORE_CALIBRATION); + if (score_calibration_filename.empty()) { + return CreateStatusWithPayload( + absl::StatusCode::kNotFound, + "Found ScoreCalibrationOptions but missing required associated " + "parameters file with type TENSOR_AXIS_SCORE_CALIBRATION.", + MediaPipeTasksStatus::kMetadataAssociatedFileNotFoundError); + } + ASSIGN_OR_RETURN( + absl::string_view score_calibration_file, + metadata_extractor.GetAssociatedFile(score_calibration_filename)); + ScoreCalibrationCalculatorOptions calculator_options; + MP_RETURN_IF_ERROR(ConfigureScoreCalibration( + score_calibration_options->score_transformation(), + score_calibration_options->default_score(), score_calibration_file, + &calculator_options)); + (*options->mutable_score_calibration_options())[tensor_index] = + calculator_options; + return absl::OkStatus(); +} + +// Fills in the TensorsToClassificationCalculatorOptions based on the +// classifier options and the (optional) output tensor metadata. absl::Status ConfigureTensorsToClassificationCalculator( const ClassifierOptions& options, const ModelMetadataExtractor& metadata_extractor, int tensor_index, @@ -303,6 +353,8 @@ absl::Status ConfigureClassificationPostprocessing( ASSIGN_OR_RETURN(const auto heads_properties, GetClassificationHeadsProperties(model_resources)); for (int i = 0; i < heads_properties.num_heads; ++i) { + MP_RETURN_IF_ERROR(ConfigureScoreCalibrationIfAny( + *model_resources.GetMetadataExtractor(), i, options)); MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator( classifier_options, *model_resources.GetMetadataExtractor(), i, options->add_tensors_to_classifications_options())); @@ -314,8 +366,8 @@ absl::Status ConfigureClassificationPostprocessing( return absl::OkStatus(); } -// A "mediapipe.tasks.ClassificationPostprocessingSubgraph" converts raw -// tensors into ClassificationResult objects. +// A "mediapipe.tasks.components.ClassificationPostprocessingSubgraph" converts +// raw tensors into ClassificationResult objects. // - Accepts CPU input tensors. // // Inputs: @@ -376,18 +428,21 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { } // If output tensors are quantized, they must be dequantized first. - GenericNode* tensors_dequantization_node; + TensorsSource dequantized_tensors(&tensors_in); if (options.has_quantized_outputs()) { - tensors_dequantization_node = + GenericNode* tensors_dequantization_node = &graph.AddNode("TensorsDequantizationCalculator"); tensors_in >> tensors_dequantization_node->In(kTensorsTag); + dequantized_tensors = {tensors_dequantization_node, kTensorsTag}; } // If there are multiple classification heads, the output tensors need to be // split. 
- GenericNode* split_tensor_vector_node; + std::vector split_tensors; + split_tensors.reserve(num_heads); if (num_heads > 1) { - split_tensor_vector_node = &graph.AddNode("SplitTensorVectorCalculator"); + GenericNode* split_tensor_vector_node = + &graph.AddNode("SplitTensorVectorCalculator"); auto& split_tensor_vector_options = split_tensor_vector_node ->GetOptions(); @@ -395,12 +450,27 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { auto* range = split_tensor_vector_options.add_ranges(); range->set_begin(i); range->set_end(i + 1); + split_tensors.emplace_back(split_tensor_vector_node, i); } - if (options.has_quantized_outputs()) { - tensors_dequantization_node->Out(kTensorsTag) >> - split_tensor_vector_node->In(0); + dequantized_tensors >> split_tensor_vector_node->In(0); + } else { + split_tensors.emplace_back(dequantized_tensors); + } + + // Adds score calibration for heads that specify it, if any. + std::vector calibrated_tensors; + calibrated_tensors.reserve(num_heads); + for (int i = 0; i < num_heads; ++i) { + if (options.score_calibration_options().contains(i)) { + GenericNode* score_calibration_node = + &graph.AddNode("ScoreCalibrationCalculator"); + score_calibration_node->GetOptions() + .CopyFrom(options.score_calibration_options().at(i)); + split_tensors[i] >> score_calibration_node->In(kScoresTag); + calibrated_tensors.emplace_back(score_calibration_node, + kCalibratedScoresTag); } else { - tensors_in >> split_tensor_vector_node->In(0); + calibrated_tensors.emplace_back(split_tensors[i]); } } @@ -413,17 +483,8 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { tensors_to_classification_nodes.back() ->GetOptions() .CopyFrom(options.tensors_to_classifications_options(i)); - if (num_heads == 1) { - if (options.has_quantized_outputs()) { - tensors_dequantization_node->Out(kTensorsTag) >> - tensors_to_classification_nodes.back()->In(kTensorsTag); - } else { - tensors_in >> tensors_to_classification_nodes.back()->In(kTensorsTag); - } - } else { - split_tensor_vector_node->Out(i) >> - tensors_to_classification_nodes.back()->In(kTensorsTag); - } + calibrated_tensors[i] >> + tensors_to_classification_nodes.back()->In(kTensorsTag); } // Aggregates Classifications into a single ClassificationResult. @@ -444,7 +505,8 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { } }; REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::ClassificationPostprocessingSubgraph); + ::mediapipe::tasks::components::ClassificationPostprocessingSubgraph); +} // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/classification_postprocessing.h b/mediapipe/tasks/cc/components/classification_postprocessing.h index 5ae12e93a..eb638bd60 100644 --- a/mediapipe/tasks/cc/components/classification_postprocessing.h +++ b/mediapipe/tasks/cc/components/classification_postprocessing.h @@ -18,11 +18,12 @@ limitations under the License. #include "absl/status/status.h" #include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" namespace mediapipe { namespace tasks { +namespace components { // Configures a ClassificationPostprocessing subgraph using the provided model // resources and ClassifierOptions. 
@@ -31,7 +32,7 @@ namespace tasks { // Example usage: // // auto& postprocessing = -// graph.AddNode("mediapipe.tasks.ClassificationPostprocessingSubgraph"); +// graph.AddNode("mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); // MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( // model_resources, // classifier_options, @@ -49,10 +50,11 @@ namespace tasks { // CLASSIFICATION_RESULT - ClassificationResult // The output aggregated classification results. absl::Status ConfigureClassificationPostprocessing( - const core::ModelResources& model_resources, - const ClassifierOptions& classifier_options, + const tasks::core::ModelResources& model_resources, + const tasks::components::proto::ClassifierOptions& classifier_options, ClassificationPostprocessingOptions* options); +} // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/classification_postprocessing_options.proto b/mediapipe/tasks/cc/components/classification_postprocessing_options.proto index 3f96d5bde..9b67e2f75 100644 --- a/mediapipe/tasks/cc/components/classification_postprocessing_options.proto +++ b/mediapipe/tasks/cc/components/classification_postprocessing_options.proto @@ -15,17 +15,22 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks; +package mediapipe.tasks.components; import "mediapipe/calculators/tensor/tensors_to_classification_calculator.proto"; import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto"; +import "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto"; message ClassificationPostprocessingOptions { extend mediapipe.CalculatorOptions { optional ClassificationPostprocessingOptions ext = 460416950; } + // Optional mapping between output tensor index and corresponding score + // calibration options. + map score_calibration_options = 4; + // Options for the TensorsToClassification calculators (one per classification // head) encapsulated by the ClassificationPostprocessing subgraph. repeated mediapipe.TensorsToClassificationCalculatorOptions diff --git a/mediapipe/tasks/cc/components/classification_postprocessing_test.cc b/mediapipe/tasks/cc/components/classification_postprocessing_test.cc index 4cba24fd9..67223050f 100644 --- a/mediapipe/tasks/cc/components/classification_postprocessing_test.cc +++ b/mediapipe/tasks/cc/components/classification_postprocessing_test.cc @@ -41,9 +41,10 @@ limitations under the License. #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/timestamp.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" +#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" #include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/util/label_map.pb.h" @@ -51,6 +52,7 @@ limitations under the License. 
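To illustrate the new score_calibration_options map field above, a ClassificationPostprocessingOptions message could hypothetically be populated as follows; the field values are made up for illustration and only reuse field names that appear elsewhere in this change.

#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"

// Sketch: one calibrated head at tensor index 0, dummy sigmoid parameters.
auto options = mediapipe::ParseTextProtoOrDie<
    mediapipe::tasks::components::ClassificationPostprocessingOptions>(R"pb(
  score_calibration_options {
    key: 0
    value {
      score_transformation: IDENTITY
      default_score: 0.5
      sigmoids { scale: 1.0 slope: 1.0 offset: 0.0 }
    }
  }
  tensors_to_classifications_options {
    min_score_threshold: 0.5
    top_k: 3
    sort_by_descending_score: true
  }
  classification_aggregation_options { head_names: "probability" }
  has_quantized_outputs: true
)pb");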
namespace mediapipe { namespace tasks { +namespace components { namespace { using ::mediapipe::api2::Input; @@ -58,6 +60,7 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::proto::ClassifierOptions; using ::mediapipe::tasks::core::ModelResources; using ::testing::HasSubstr; using ::testing::proto::Approximately; @@ -65,6 +68,8 @@ using ::testing::proto::Approximately; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/"; constexpr char kQuantizedImageClassifierWithMetadata[] = "vision/mobilenet_v1_0.25_224_quant.tflite"; +constexpr char kQuantizedImageClassifierWithDummyScoreCalibration[] = + "vision/mobilenet_v1_0.25_224_quant_with_dummy_score_calibration.tflite"; constexpr char kQuantizedImageClassifierWithoutMetadata[] = "vision/mobilenet_v1_0.25_192_quantized_1_default_1.tflite"; constexpr char kFloatTwoHeadsAudioClassifierWithMetadata[] = @@ -147,11 +152,12 @@ TEST_F(ConfigureTest, SucceedsWithoutMetadata) { ClassifierOptions options_in; ClassificationPostprocessingOptions options_out; - MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources, + MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, options_in, &options_out)); EXPECT_THAT(options_out, Approximately(EqualsProto( - R"pb(tensors_to_classifications_options { + R"pb(score_calibration_options: [] + tensors_to_classifications_options { min_score_threshold: -3.4028235e+38 top_k: -1 sort_by_descending_score: true @@ -169,11 +175,12 @@ TEST_F(ConfigureTest, SucceedsWithMaxResults) { options_in.set_max_results(3); ClassificationPostprocessingOptions options_out; - MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources, + MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, options_in, &options_out)); EXPECT_THAT(options_out, Approximately(EqualsProto( - R"pb(tensors_to_classifications_options { + R"pb(score_calibration_options: [] + tensors_to_classifications_options { min_score_threshold: -3.4028235e+38 top_k: 3 sort_by_descending_score: true @@ -191,11 +198,12 @@ TEST_F(ConfigureTest, SucceedsWithScoreThreshold) { options_in.set_score_threshold(0.5); ClassificationPostprocessingOptions options_out; - MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources, + MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, options_in, &options_out)); EXPECT_THAT(options_out, Approximately(EqualsProto( - R"pb(tensors_to_classifications_options { + R"pb(score_calibration_options: [] + tensors_to_classifications_options { min_score_threshold: 0.5 top_k: -1 sort_by_descending_score: true @@ -212,7 +220,7 @@ TEST_F(ConfigureTest, SucceedsWithMetadata) { ClassifierOptions options_in; ClassificationPostprocessingOptions options_out; - MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources, + MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, options_in, &options_out)); // Check label map size and two first elements. 
@@ -229,7 +237,8 @@ TEST_F(ConfigureTest, SucceedsWithMetadata) { options_out.mutable_tensors_to_classifications_options(0) ->clear_label_items(); EXPECT_THAT(options_out, Approximately(EqualsProto( - R"pb(tensors_to_classifications_options { + R"pb(score_calibration_options: [] + tensors_to_classifications_options { min_score_threshold: -3.4028235e+38 top_k: -1 sort_by_descending_score: true @@ -249,14 +258,15 @@ TEST_F(ConfigureTest, SucceedsWithAllowlist) { options_in.add_category_allowlist("tench"); ClassificationPostprocessingOptions options_out; - MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources, + MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, options_in, &options_out)); // Clear label map and compare the rest of the options. options_out.mutable_tensors_to_classifications_options(0) ->clear_label_items(); EXPECT_THAT(options_out, Approximately(EqualsProto( - R"pb(tensors_to_classifications_options { + R"pb(score_calibration_options: [] + tensors_to_classifications_options { min_score_threshold: -3.4028235e+38 top_k: -1 sort_by_descending_score: true @@ -277,14 +287,15 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) { options_in.add_category_denylist("background"); ClassificationPostprocessingOptions options_out; - MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources, + MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, options_in, &options_out)); // Clear label map and compare the rest of the options. options_out.mutable_tensors_to_classifications_options(0) ->clear_label_items(); EXPECT_THAT(options_out, Approximately(EqualsProto( - R"pb(tensors_to_classifications_options { + R"pb(score_calibration_options: [] + tensors_to_classifications_options { min_score_threshold: -3.4028235e+38 top_k: -1 sort_by_descending_score: true @@ -297,6 +308,56 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) { )pb"))); } +TEST_F(ConfigureTest, SucceedsWithScoreCalibration) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel( + kQuantizedImageClassifierWithDummyScoreCalibration)); + ClassifierOptions options_in; + + ClassificationPostprocessingOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, + options_in, &options_out)); + + // Check label map size and two first elements. + EXPECT_EQ( + options_out.tensors_to_classifications_options(0).label_items_size(), + kMobileNetNumClasses); + EXPECT_THAT( + options_out.tensors_to_classifications_options(0).label_items().at(0), + EqualsProto(R"pb(name: "background")pb")); + EXPECT_THAT( + options_out.tensors_to_classifications_options(0).label_items().at(1), + EqualsProto(R"pb(name: "tench")pb")); + // Clear label map. + options_out.mutable_tensors_to_classifications_options(0) + ->clear_label_items(); + // Check sigmoids size and first element. + EXPECT_EQ(options_out.score_calibration_options_size(), 1); + auto score_calibration_options = + options_out.score_calibration_options().at(0); + EXPECT_EQ(score_calibration_options.sigmoids_size(), kMobileNetNumClasses); + EXPECT_THAT(score_calibration_options.sigmoids(0), + EqualsProto(R"pb(scale: 1.0 slope: 1.0 offset: 0.0)pb")); + options_out.mutable_score_calibration_options()->at(0).clear_sigmoids(); + // Compare the rest of the options. 
+ EXPECT_THAT( + options_out, + Approximately(EqualsProto( + R"pb(score_calibration_options { + key: 0 + value { score_transformation: IDENTITY default_score: 0.5 } + } + tensors_to_classifications_options { + min_score_threshold: -3.4028235e+38 + top_k: -1 + sort_by_descending_score: true + } + classification_aggregation_options { head_names: "probability" } + has_quantized_outputs: true + )pb"))); +} + TEST_F(ConfigureTest, SucceedsWithMultipleHeads) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, @@ -304,7 +365,7 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) { ClassifierOptions options_in; ClassificationPostprocessingOptions options_out; - MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources, + MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, options_in, &options_out)); // Check label maps sizes and first two elements. EXPECT_EQ( @@ -331,7 +392,8 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) { options_out.mutable_tensors_to_classifications_options(1) ->clear_label_items(); EXPECT_THAT(options_out, Approximately(EqualsProto( - R"pb(tensors_to_classifications_options { + R"pb(score_calibration_options: [] + tensors_to_classifications_options { min_score_threshold: -3.4028235e+38 top_k: -1 sort_by_descending_score: true @@ -358,8 +420,8 @@ class PostprocessingTest : public tflite_shims::testing::Test { CreateModelResourcesForModel(model_name)); Graph graph; - auto& postprocessing = - graph.AddNode("mediapipe.tasks.ClassificationPostprocessingSubgraph"); + auto& postprocessing = graph.AddNode( + "mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( *model_resources, options, &postprocessing.GetOptions())); @@ -503,6 +565,52 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) { })pb")); } +TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) { + // Build graph. + ClassifierOptions options; + options.set_max_results(3); + MP_ASSERT_OK_AND_ASSIGN( + auto poller, + BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options)); + // Build input tensors. + std::vector tensor(kMobileNetNumClasses, 0); + tensor[1] = 12; + tensor[2] = 14; + tensor[3] = 16; + tensor[4] = 18; + + // Send tensors and get results. + AddTensor(tensor, Tensor::ElementType::kUInt8, + /*quantization_parameters=*/{0.1, 10}); + MP_ASSERT_OK(Run()); + MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); + + // Validate results. + EXPECT_THAT(results, EqualsProto( + R"pb(classifications { + entries { + categories { + index: 4 + score: 0.6899744811 + category_name: "tiger shark" + } + categories { + index: 3 + score: 0.6456563062 + category_name: "great white shark" + } + categories { + index: 2 + score: 0.5986876601 + category_name: "goldfish" + } + timestamp_ms: 0 + } + head_index: 0 + head_name: "probability" + })pb")); +} + TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) { // Build graph. ClassifierOptions options; @@ -621,5 +729,6 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) { } } // namespace +} // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/classifier_options.cc b/mediapipe/tasks/cc/components/classifier_options.cc index 17650db26..c54db5f88 100644 --- a/mediapipe/tasks/cc/components/classifier_options.cc +++ b/mediapipe/tasks/cc/components/classifier_options.cc @@ -15,15 +15,15 @@ limitations under the License. 
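A note on the expected scores in the SucceedsWithScoreCalibration postprocessing test above (a reconstruction of the arithmetic, not code from this change): the uint8 scores are dequantized with (value - zero_point) * scale using the test's {0.1, 10} quantization parameters, and the dummy calibration (scale 1.0, slope 1.0, offset 0.0, IDENTITY transformation) then reduces to a plain sigmoid.

#include <cmath>
#include <cstdio>

int main() {
  const float scale = 0.1f;
  const int zero_point = 10;
  const int raw_scores[] = {18, 16, 14};  // tensor[4], tensor[3], tensor[2] in the test
  for (int raw : raw_scores) {
    const float dequantized = (raw - zero_point) * scale;             // 0.8, 0.6, 0.4
    const float calibrated = 1.0f / (1.0f + std::exp(-dequantized));  // ~0.690, ~0.646, ~0.599
    std::printf("%.10f\n", calibrated);
  }
  return 0;
}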
#include "mediapipe/tasks/cc/components/classifier_options.h" -#include "mediapipe/tasks/cc/components/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" namespace mediapipe { namespace tasks { namespace components { -tasks::ClassifierOptions ConvertClassifierOptionsToProto( +tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto( ClassifierOptions* options) { - tasks::ClassifierOptions options_proto; + tasks::components::proto::ClassifierOptions options_proto; options_proto.set_display_names_locale(options->display_names_locale); options_proto.set_max_results(options->max_results); options_proto.set_score_threshold(options->score_threshold); diff --git a/mediapipe/tasks/cc/components/classifier_options.h b/mediapipe/tasks/cc/components/classifier_options.h index d5d1a54f3..e15bf5e69 100644 --- a/mediapipe/tasks/cc/components/classifier_options.h +++ b/mediapipe/tasks/cc/components/classifier_options.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_ -#include "mediapipe/tasks/cc/components/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" namespace mediapipe { namespace tasks { @@ -49,7 +49,7 @@ struct ClassifierOptions { }; // Converts a ClassifierOptions to a ClassifierOptionsProto. -tasks::ClassifierOptions ConvertClassifierOptionsToProto( +tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto( ClassifierOptions* classifier_options); } // namespace components diff --git a/mediapipe/tasks/cc/components/containers/proto/BUILD b/mediapipe/tasks/cc/components/containers/proto/BUILD index b6e98d72f..9c6402e64 100644 --- a/mediapipe/tasks/cc/components/containers/proto/BUILD +++ b/mediapipe/tasks/cc/components/containers/proto/BUILD @@ -29,3 +29,8 @@ mediapipe_proto_library( "//mediapipe/framework/formats:rect_proto", ], ) + +mediapipe_proto_library( + name = "embeddings_proto", + srcs = ["embeddings.proto"], +) diff --git a/mediapipe/tasks/cc/components/containers/proto/embeddings.proto b/mediapipe/tasks/cc/components/containers/proto/embeddings.proto new file mode 100644 index 000000000..d57b08b53 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/proto/embeddings.proto @@ -0,0 +1,56 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.components.containers.proto; + +// Defines a dense floating-point embedding. +message FloatEmbedding { + repeated float values = 1 [packed = true]; +} + +// Defines a dense scalar-quantized embedding. +message QuantizedEmbedding { + optional bytes values = 1; +} + +// Floating-point or scalar-quantized embedding with an optional timestamp. 
+message EmbeddingEntry { + // The actual embedding, either floating-point or scalar-quantized. + oneof embedding { + FloatEmbedding float_embedding = 1; + QuantizedEmbedding quantized_embedding = 2; + } + // The optional timestamp (in milliseconds) associated to the embedding entry. + // This is useful for time series use cases, e.g. audio embedding. + optional int64 timestamp_ms = 3; +} + +// Embeddings for a given embedder head. +message Embeddings { + repeated EmbeddingEntry entries = 1; + // The index of the embedder head that produced this embedding. This is useful + // for multi-head models. + optional int32 head_index = 2; + // The name of the embedder head, which is the corresponding tensor metadata + // name (if any). This is useful for multi-head models. + optional string head_name = 3; +} + +// Contains one set of results per embedder head. +message EmbeddingResult { + repeated Embeddings embeddings = 1; +} diff --git a/mediapipe/tasks/cc/components/embedder_options.cc b/mediapipe/tasks/cc/components/embedder_options.cc new file mode 100644 index 000000000..9cc399f7b --- /dev/null +++ b/mediapipe/tasks/cc/components/embedder_options.cc @@ -0,0 +1,34 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/components/embedder_options.h" + +#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace components { + +tasks::components::proto::EmbedderOptions ConvertEmbedderOptionsToProto( + EmbedderOptions* embedder_options) { + tasks::components::proto::EmbedderOptions options_proto; + options_proto.set_l2_normalize(embedder_options->l2_normalize); + options_proto.set_quantize(embedder_options->quantize); + return options_proto; +} + +} // namespace components +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/embedder_options.h b/mediapipe/tasks/cc/components/embedder_options.h new file mode 100644 index 000000000..9ed0fee87 --- /dev/null +++ b/mediapipe/tasks/cc/components/embedder_options.h @@ -0,0 +1,47 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_ + +#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace components { + +// Embedder options for MediaPipe C++ embedding extraction tasks. +struct EmbedderOptions { + // Whether to normalize the returned feature vector with L2 norm. Use this + // option only if the model does not already contain a native L2_NORMALIZATION + // TF Lite Op. In most cases, this is already the case and L2 norm is thus + // achieved through TF Lite inference. + bool l2_normalize; + + // Whether the returned embedding should be quantized to bytes via scalar + // quantization. Embeddings are implicitly assumed to be unit-norm and + // therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + // the l2_normalize option if this is not the case. + bool quantize; +}; + +tasks::components::proto::EmbedderOptions ConvertEmbedderOptionsToProto( + EmbedderOptions* embedder_options); + +} // namespace components +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_ diff --git a/mediapipe/tasks/cc/components/embedding_postprocessing_graph.cc b/mediapipe/tasks/cc/components/embedding_postprocessing_graph.cc new file mode 100644 index 000000000..4ea009cb8 --- /dev/null +++ b/mediapipe/tasks/cc/components/embedding_postprocessing_graph.cc @@ -0,0 +1,232 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/tool/options_map.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" +#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" +#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace mediapipe { +namespace tasks { +namespace components { + +namespace { + +using ::mediapipe::Tensor; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::GenericNode; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; +using ::mediapipe::tasks::components::proto::EmbedderOptions; +using ::mediapipe::tasks::core::ModelResources; +using TensorsSource = + ::mediapipe::tasks::SourceOrNodeOutput>; + +constexpr char kTensorsTag[] = "TENSORS"; +constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT"; + +// Identifies whether or not the model has quantized outputs, and performs +// sanity checks. +absl::StatusOr HasQuantizedOutputs( + const ModelResources& model_resources) { + const tflite::Model& model = *model_resources.GetTfLiteModel(); + if (model.subgraphs()->size() != 1) { + return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, + "Embedding tflite models are " + "assumed to have a single subgraph.", + MediaPipeTasksStatus::kInvalidArgumentError); + } + const auto* primary_subgraph = (*model.subgraphs())[0]; + int num_output_tensors = primary_subgraph->outputs()->size(); + // Sanity check tensor types and check if model outputs are quantized or not. 
+ int num_quantized_tensors = 0; + for (int i = 0; i < num_output_tensors; ++i) { + const auto* tensor = + primary_subgraph->tensors()->Get(primary_subgraph->outputs()->Get(i)); + if (tensor->type() != tflite::TensorType_FLOAT32 && + tensor->type() != tflite::TensorType_UINT8) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Expected output tensor at index %d to have type " + "UINT8 or FLOAT32, found %s instead.", + i, tflite::EnumNameTensorType(tensor->type())), + MediaPipeTasksStatus::kInvalidOutputTensorTypeError); + } + if (tensor->type() == tflite::TensorType_UINT8) { + num_quantized_tensors++; + } + } + if (num_quantized_tensors != num_output_tensors && + num_quantized_tensors != 0) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat( + "Expected either all or none of the output tensors to be " + "quantized, but found %d quantized outputs for %d total outputs.", + num_quantized_tensors, num_output_tensors), + MediaPipeTasksStatus::kInvalidOutputTensorTypeError); + } + // Check if metadata is consistent with model topology. + const auto* output_tensors_metadata = + model_resources.GetMetadataExtractor()->GetOutputTensorMetadata(); + if (output_tensors_metadata != nullptr && + num_output_tensors != output_tensors_metadata->size()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Mismatch between number of output tensors (%d) and " + "output tensors metadata (%d).", + num_output_tensors, output_tensors_metadata->size()), + MediaPipeTasksStatus::kMetadataInconsistencyError); + } + return num_quantized_tensors > 0; +} + +// Extracts head names from model resources. Returns an empty vector if none are +// available. If partially available, the name for heads that don't specify a +// metadata name will be set to the empty string. +absl::StatusOr> GetHeadNames( + const ModelResources& model_resources) { + std::vector head_names; + const auto* output_tensors_metadata = + model_resources.GetMetadataExtractor()->GetOutputTensorMetadata(); + if (output_tensors_metadata == nullptr) { + return head_names; + } + head_names.reserve(output_tensors_metadata->size()); + bool names_available = false; + for (const auto& metadata : *output_tensors_metadata) { + if (metadata->name() != nullptr) { + names_available = true; + head_names.push_back(metadata->name()->str()); + } else { + head_names.push_back(""); + } + } + if (!names_available) { + head_names.clear(); + } + return head_names; +} + +} // namespace + +absl::Status ConfigureEmbeddingPostprocessing( + const ModelResources& model_resources, + const EmbedderOptions& embedder_options, + proto::EmbeddingPostprocessingGraphOptions* options) { + ASSIGN_OR_RETURN(bool has_quantized_outputs, + HasQuantizedOutputs(model_resources)); + options->set_has_quantized_outputs(has_quantized_outputs); + auto* tensors_to_embeddings_options = + options->mutable_tensors_to_embeddings_options(); + *tensors_to_embeddings_options->mutable_embedder_options() = embedder_options; + ASSIGN_OR_RETURN(auto head_names, GetHeadNames(model_resources)); + if (!head_names.empty()) { + *tensors_to_embeddings_options->mutable_head_names() = {head_names.begin(), + head_names.end()}; + } + return absl::OkStatus(); +} + +// An EmbeddingPostprocessingGraph converts raw tensors into EmbeddingResult +// objects. +// - Accepts CPU input tensors. 
+// +// Inputs: +// TENSORS - std::vector +// The output tensors of an InferenceCalculator, to convert into +// EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8. +// Outputs: +// EMBEDDING_RESULT - EmbeddingResult +// The output EmbeddingResult. +// +// The recommended way of using this graph is through the GraphBuilder API using +// the 'ConfigureEmbeddingPostprocessing()' function. See header file for more +// details. +// +// TODO: add support for additional optional "TIMESTAMPS" input for +// embeddings aggregation. +class EmbeddingPostprocessingGraph : public mediapipe::Subgraph { + public: + absl::StatusOr GetConfig( + mediapipe::SubgraphContext* sc) override { + Graph graph; + ASSIGN_OR_RETURN( + auto embedding_result_out, + BuildEmbeddingPostprocessing( + sc->Options(), + graph[Input>(kTensorsTag)], graph)); + embedding_result_out >> graph[Output(kEmbeddingResultTag)]; + return graph.GetConfig(); + } + + private: + // Adds an on-device embedding postprocessing graph into the provided + // builder::Graph instance. The embedding postprocessing graph takes tensors + // (std::vector) as input and returns one output stream + // containing the output embedding results (EmbeddingResult). + // + // options: the on-device EmbeddingPostprocessingGraphOptions + // tensors_in: (std::vector) tensors to postprocess. + // graph: the mediapipe builder::Graph instance to be updated. + absl::StatusOr> BuildEmbeddingPostprocessing( + const proto::EmbeddingPostprocessingGraphOptions options, + Source> tensors_in, Graph& graph) { + // If output tensors are quantized, they must be dequantized first. + TensorsSource dequantized_tensors(&tensors_in); + if (options.has_quantized_outputs()) { + GenericNode& tensors_dequantization_node = + graph.AddNode("TensorsDequantizationCalculator"); + tensors_in >> tensors_dequantization_node.In(kTensorsTag); + dequantized_tensors = {&tensors_dequantization_node, kTensorsTag}; + } + + // Adds TensorsToEmbeddingsCalculator. + GenericNode& tensors_to_embeddings_node = + graph.AddNode("TensorsToEmbeddingsCalculator"); + tensors_to_embeddings_node + .GetOptions() + .CopyFrom(options.tensors_to_embeddings_options()); + dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag); + return tensors_to_embeddings_node[Output( + kEmbeddingResultTag)]; + } +}; +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::components::EmbeddingPostprocessingGraph); + +} // namespace components +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/embedding_postprocessing_graph.h b/mediapipe/tasks/cc/components/embedding_postprocessing_graph.h new file mode 100644 index 000000000..af8fa6706 --- /dev/null +++ b/mediapipe/tasks/cc/components/embedding_postprocessing_graph.h @@ -0,0 +1,61 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_ + +#include "absl/status/status.h" +#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" +#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" + +namespace mediapipe { +namespace tasks { +namespace components { + +// Configures an EmbeddingPostprocessingGraph using the provided model resources +// and EmbedderOptions. +// - Accepts CPU input tensors. +// +// Example usage: +// +// auto& postprocessing = +// graph.AddNode("mediapipe.tasks.components.EmbeddingPostprocessingGraph"); +// MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessing( +// model_resources, +// embedder_options, +// &postprocessing.GetOptions())); +// +// The result EmbeddingPostprocessingGraph has the following I/O: +// Inputs: +// TENSORS - std::vector +// The output tensors of an InferenceCalculator, to convert into +// EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8. +// Outputs: +// EMBEDDING_RESULT - EmbeddingResult +// The output EmbeddingResult. +// +// TODO: add support for additional optional "TIMESTAMPS" input for +// embeddings aggregation. +absl::Status ConfigureEmbeddingPostprocessing( + const tasks::core::ModelResources& model_resources, + const tasks::components::proto::EmbedderOptions& embedder_options, + proto::EmbeddingPostprocessingGraphOptions* options); + +} // namespace components +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_ diff --git a/mediapipe/tasks/cc/components/embedding_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/embedding_postprocessing_graph_test.cc new file mode 100644 index 000000000..9c0d21ab2 --- /dev/null +++ b/mediapipe/tasks/cc/components/embedding_postprocessing_graph_test.cc @@ -0,0 +1,136 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h" + +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" +#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" + +namespace mediapipe { +namespace tasks { +namespace components { +namespace { + +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::proto::EmbedderOptions; +using ::mediapipe::tasks::components::proto:: + EmbeddingPostprocessingGraphOptions; +using ::mediapipe::tasks::core::ModelResources; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/"; +constexpr char kMobileNetV3Embedder[] = + "vision/mobilenet_v3_small_100_224_embedder.tflite"; +// Abusing a few classifiers (topologically similar to embedders) in order to +// add coverage. +constexpr char kQuantizedImageClassifierWithMetadata[] = + "vision/mobilenet_v1_0.25_224_quant.tflite"; +constexpr char kQuantizedImageClassifierWithoutMetadata[] = + "vision/mobilenet_v1_0.25_192_quantized_1_default_1.tflite"; + +constexpr char kTestModelResourcesTag[] = "test_model_resources"; + +// Helper function to get ModelResources. 
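+// The helper wraps the test model path in an ExternalFile proto and hands it
+// off to ModelResources::Create() under a fixed resources tag.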
+absl::StatusOr> CreateModelResourcesForModel( + absl::string_view model_name) { + auto external_file = std::make_unique(); + external_file->set_file_name(JoinPath("./", kTestDataDirectory, model_name)); + return ModelResources::Create(kTestModelResourcesTag, + std::move(external_file)); +} + +class ConfigureTest : public tflite_shims::testing::Test {}; + +TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); + EmbedderOptions options_in; + options_in.set_l2_normalize(true); + + EmbeddingPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, + &options_out)); + + EXPECT_THAT( + options_out, + EqualsProto(ParseTextProtoOrDie( + R"pb(tensors_to_embeddings_options { + embedder_options { l2_normalize: true } + head_names: "probability" + } + has_quantized_outputs: true)pb"))); +} + +TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); + EmbedderOptions options_in; + options_in.set_quantize(true); + + EmbeddingPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, + &options_out)); + + EXPECT_THAT( + options_out, + EqualsProto(ParseTextProtoOrDie( + R"pb(tensors_to_embeddings_options { + embedder_options { quantize: true } + } + has_quantized_outputs: true)pb"))); +} + +TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) { + MP_ASSERT_OK_AND_ASSIGN(auto model_resources, + CreateModelResourcesForModel(kMobileNetV3Embedder)); + EmbedderOptions options_in; + options_in.set_quantize(true); + options_in.set_l2_normalize(true); + + EmbeddingPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, + &options_out)); + + EXPECT_THAT( + options_out, + EqualsProto(ParseTextProtoOrDie( + R"pb(tensors_to_embeddings_options { + embedder_options { quantize: true l2_normalize: true } + head_names: "feature" + } + has_quantized_outputs: false)pb"))); +} + +// TODO: add E2E Postprocessing tests once timestamp aggregation is +// supported. + +} // namespace +} // namespace components +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/image_preprocessing.cc b/mediapipe/tasks/cc/components/image_preprocessing.cc index 18958a911..046a97e4d 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing.cc +++ b/mediapipe/tasks/cc/components/image_preprocessing.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "mediapipe/calculators/image/image_clone_calculator.pb.h" #include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/port.h" @@ -37,6 +38,7 @@ limitations under the License. namespace mediapipe { namespace tasks { +namespace components { namespace { using ::mediapipe::Tensor; @@ -137,10 +139,25 @@ absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources, BuildImageTensorSpecs(model_resources)); MP_RETURN_IF_ERROR(ConfigureImageToTensorCalculator( image_tensor_specs, options->mutable_image_to_tensor_options())); + // The GPU backend isn't able to process int data. 
If the input tensor is + // quantized, forces the image preprocessing graph to use CPU backend. + if (image_tensor_specs.tensor_type == tflite::TensorType_UINT8) { + options->set_backend(ImagePreprocessingOptions::CPU_BACKEND); + } return absl::OkStatus(); } -// A "mediapipe.tasks.ImagePreprocessingSubgraph" performs image preprocessing. +Source AddDataConverter(Source image_in, Graph& graph, + bool output_on_gpu) { + auto& image_converter = graph.AddNode("ImageCloneCalculator"); + image_converter.GetOptions() + .set_output_on_gpu(output_on_gpu); + image_in >> image_converter.In(""); + return image_converter[Output("")]; +} + +// A "mediapipe.tasks.components.ImagePreprocessingSubgraph" performs image +// preprocessing. // - Accepts CPU input images and outputs CPU tensors. // // Inputs: @@ -212,7 +229,22 @@ class ImagePreprocessingSubgraph : public Subgraph { auto& image_to_tensor = graph.AddNode("ImageToTensorCalculator"); image_to_tensor.GetOptions() .CopyFrom(options.image_to_tensor_options()); - image_in >> image_to_tensor.In(kImageTag); + switch (options.backend()) { + case ImagePreprocessingOptions::CPU_BACKEND: { + auto cpu_image = + AddDataConverter(image_in, graph, /*output_on_gpu=*/false); + cpu_image >> image_to_tensor.In(kImageTag); + break; + } + case ImagePreprocessingOptions::GPU_BACKEND: { + auto gpu_image = + AddDataConverter(image_in, graph, /*output_on_gpu=*/true); + gpu_image >> image_to_tensor.In(kImageTag); + break; + } + default: + image_in >> image_to_tensor.In(kImageTag); + } norm_rect_in >> image_to_tensor.In(kNormRectTag); // Extract optional image properties. @@ -237,7 +269,9 @@ class ImagePreprocessingSubgraph : public Subgraph { }; } }; -REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::ImagePreprocessingSubgraph); +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::components::ImagePreprocessingSubgraph); +} // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/image_preprocessing.h b/mediapipe/tasks/cc/components/image_preprocessing.h index 097045d2e..a5b767f3a 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing.h +++ b/mediapipe/tasks/cc/components/image_preprocessing.h @@ -22,6 +22,7 @@ limitations under the License. namespace mediapipe { namespace tasks { +namespace components { // Configures an ImagePreprocessing subgraph using the provided model resources. // - Accepts CPU input images and outputs CPU tensors. @@ -29,7 +30,7 @@ namespace tasks { // Example usage: // // auto& preprocessing = -// graph.AddNode("mediapipe.tasks.ImagePreprocessingSubgraph"); +// graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); // MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( // model_resources, // &preprocessing.GetOptions())); @@ -38,6 +39,9 @@ namespace tasks { // Inputs: // IMAGE - Image // The image to preprocess. +// NORM_RECT - NormalizedRect @Optional +// Describes region of image to extract. +// @Optional: rect covering the whole image is used if not specified. 
// Outputs: // TENSORS - std::vector // Vector containing a single Tensor populated with the converted and @@ -55,6 +59,7 @@ absl::Status ConfigureImagePreprocessing( const core::ModelResources& model_resources, ImagePreprocessingOptions* options); +} // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/image_preprocessing_options.proto b/mediapipe/tasks/cc/components/image_preprocessing_options.proto index 0b2c77975..d1685c319 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing_options.proto +++ b/mediapipe/tasks/cc/components/image_preprocessing_options.proto @@ -15,7 +15,7 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks; +package mediapipe.tasks.components; import "mediapipe/calculators/tensor/image_to_tensor_calculator.proto"; import "mediapipe/framework/calculator.proto"; @@ -28,4 +28,13 @@ message ImagePreprocessingOptions { // Options for the ImageToTensor calculator encapsulated by the // ImagePreprocessing subgraph. optional mediapipe.ImageToTensorCalculatorOptions image_to_tensor_options = 1; + + // The required image processing backend type. If not specified or set to + // default, use the backend that the input image data is already on. + enum Backend { + DEFAULT = 0; + CPU_BACKEND = 1; + GPU_BACKEND = 2; + } + optional Backend backend = 2 [default = DEFAULT]; } diff --git a/mediapipe/tasks/cc/components/proto/BUILD b/mediapipe/tasks/cc/components/proto/BUILD new file mode 100644 index 000000000..8c4dcdad9 --- /dev/null +++ b/mediapipe/tasks/cc/components/proto/BUILD @@ -0,0 +1,53 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
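To make the new backend selection concrete, here is a minimal sketch (not part of this patch) of how a task graph might add the renamed subgraph and let `ConfigureImagePreprocessing` pick the backend. The function name `AddImagePreprocessing`, the graph wiring, and the helper macros are assumptions drawn from the usage comment above, not library code:

```cpp
// Sketch only: wiring "mediapipe.tasks.components.ImagePreprocessingSubgraph"
// into a task graph. ConfigureImagePreprocessing() fills in the
// image_to_tensor options and, for uint8 (quantized) models, forces
// CPU_BACKEND as described above.
#include "absl/status/status.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h"

absl::Status AddImagePreprocessing(
    const mediapipe::tasks::core::ModelResources& model_resources,
    mediapipe::api2::builder::Graph& graph) {
  auto& preprocessing = graph.AddNode(
      "mediapipe.tasks.components.ImagePreprocessingSubgraph");
  auto* options = &preprocessing.GetOptions<
      mediapipe::tasks::components::ImagePreprocessingOptions>();
  MP_RETURN_IF_ERROR(mediapipe::tasks::components::ConfigureImagePreprocessing(
      model_resources, options));
  // To override the automatic choice, e.g. force the GPU path:
  // options->set_backend(
  //     mediapipe::tasks::components::ImagePreprocessingOptions::GPU_BACKEND);
  graph.In("IMAGE") >> preprocessing.In("IMAGE");
  preprocessing.Out("TENSORS") >> graph.Out("TENSORS");
  return absl::OkStatus();
}
```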
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_proto_library( + name = "segmenter_options_proto", + srcs = ["segmenter_options.proto"], +) + +mediapipe_proto_library( + name = "classifier_options_proto", + srcs = ["classifier_options.proto"], +) + +mediapipe_proto_library( + name = "embedder_options_proto", + srcs = ["embedder_options.proto"], +) + +mediapipe_proto_library( + name = "embedding_postprocessing_graph_options_proto", + srcs = ["embedding_postprocessing_graph_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto", + ], +) + +mediapipe_proto_library( + name = "text_preprocessing_graph_options_proto", + srcs = ["text_preprocessing_graph_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) diff --git a/mediapipe/tasks/cc/components/classifier_options.proto b/mediapipe/tasks/cc/components/proto/classifier_options.proto similarity index 97% rename from mediapipe/tasks/cc/components/classifier_options.proto rename to mediapipe/tasks/cc/components/proto/classifier_options.proto index 99dc9d026..ea1491bb8 100644 --- a/mediapipe/tasks/cc/components/classifier_options.proto +++ b/mediapipe/tasks/cc/components/proto/classifier_options.proto @@ -15,7 +15,7 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks; +package mediapipe.tasks.components.proto; // Shared options used by all classification tasks. message ClassifierOptions { diff --git a/mediapipe/tasks/cc/components/proto/embedder_options.proto b/mediapipe/tasks/cc/components/proto/embedder_options.proto new file mode 100644 index 000000000..8a60a1398 --- /dev/null +++ b/mediapipe/tasks/cc/components/proto/embedder_options.proto @@ -0,0 +1,33 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.components.proto; + +// Shared options used by all embedding extraction tasks. +message EmbedderOptions { + // Whether to normalize the returned feature vector with L2 norm. Use this + // option only if the model does not already contain a native L2_NORMALIZATION + // TF Lite Op. In most cases, this is already the case and L2 norm is thus + // achieved through TF Lite inference. + optional bool l2_normalize = 1; + + // Whether the returned embedding should be quantized to bytes via scalar + // quantization. Embeddings are implicitly assumed to be unit-norm and + // therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + // the l2_normalize option if this is not the case. 
+ optional bool quantize = 2; +} diff --git a/mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.proto b/mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.proto new file mode 100644 index 000000000..4e79f8178 --- /dev/null +++ b/mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.proto @@ -0,0 +1,38 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.components.proto; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto"; + +message EmbeddingPostprocessingGraphOptions { + extend mediapipe.CalculatorOptions { + optional EmbeddingPostprocessingGraphOptions ext = 476346926; + } + + // Options for the TensorsToEmbeddings calculator encapsulated by the + // EmbeddingPostprocessingGraph. + optional mediapipe.TensorsToEmbeddingsCalculatorOptions + tensors_to_embeddings_options = 1; + + // Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32). + optional bool has_quantized_outputs = 2; + + // TODO: add options to control whether timestamp aggregation + // should be used or not. +} diff --git a/mediapipe/tasks/cc/components/segmenter_options.proto b/mediapipe/tasks/cc/components/proto/segmenter_options.proto similarity index 97% rename from mediapipe/tasks/cc/components/segmenter_options.proto rename to mediapipe/tasks/cc/components/proto/segmenter_options.proto index c70b4af47..a2f37d3a0 100644 --- a/mediapipe/tasks/cc/components/segmenter_options.proto +++ b/mediapipe/tasks/cc/components/proto/segmenter_options.proto @@ -15,7 +15,7 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks; +package mediapipe.tasks.components.proto; // Shared options used by image segmentation tasks. message SegmenterOptions { diff --git a/mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto b/mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto new file mode 100644 index 000000000..c0c207543 --- /dev/null +++ b/mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto @@ -0,0 +1,40 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
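The `quantize` flag in the EmbedderOptions above refers to scalar quantization of an (assumed unit-norm) embedding into signed bytes. The standalone snippet below is only an illustration of that mapping under the stated [-1.0, 1.0] assumption; it is not the TensorsToEmbeddings calculator's actual implementation:

```cpp
// Illustration only: scalar quantization of a unit-norm float embedding to
// bytes, as described by the `quantize` option above.
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<int8_t> ScalarQuantize(const std::vector<float>& embedding) {
  std::vector<int8_t> bytes;
  bytes.reserve(embedding.size());
  for (float v : embedding) {
    // Assumes |v| <= 1.0; callers would l2-normalize first if that does not
    // hold (see the l2_normalize option above).
    bytes.push_back(static_cast<int8_t>(std::round(v * 127.0f)));
  }
  return bytes;
}
```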
+==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.components.proto; + +import "mediapipe/framework/calculator.proto"; + +message TextPreprocessingGraphOptions { + extend mediapipe.CalculatorOptions { + optional TextPreprocessingGraphOptions ext = 476978751; + } + + // The type of text preprocessor required for the TFLite model. + enum PreprocessorType { + UNSPECIFIED_PREPROCESSOR = 0; + // Used for the BertPreprocessorCalculator. + BERT_PREPROCESSOR = 1; + // Used for the RegexPreprocessorCalculator. + REGEX_PREPROCESSOR = 2; + } + optional PreprocessorType preprocessor_type = 1; + + // The maximum input sequence length for the TFLite model. Used with + // BERT_PREPROCESSOR and REGEX_PREPROCESSOR. + optional int32 max_seq_len = 2; +} diff --git a/mediapipe/tasks/cc/components/text_preprocessing_graph.cc b/mediapipe/tasks/cc/components/text_preprocessing_graph.cc new file mode 100644 index 000000000..2c4c1b866 --- /dev/null +++ b/mediapipe/tasks/cc/components/text_preprocessing_graph.cc @@ -0,0 +1,266 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/substitute.h" +#include "mediapipe/calculators/tensor/bert_preprocessor_calculator.pb.h" +#include "mediapipe/calculators/tensor/regex_preprocessor_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/subgraph.h" +#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" + +namespace mediapipe { +namespace tasks { +namespace components { + +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::SideInput; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::SideSource; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::proto::TextPreprocessingGraphOptions; +using ::mediapipe::tasks::core::ModelResources; +using ::mediapipe::tasks::metadata::ModelMetadataExtractor; + +constexpr char kTextTag[] = "TEXT"; +constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; +constexpr char kTensorsTag[] = "TENSORS"; + +constexpr int kNumInputTensorsForBert = 3; +constexpr int kNumInputTensorsForRegex = 1; + +// Gets the name of the MediaPipe calculator associated with +// `preprocessor_type`. 
+absl::StatusOr GetCalculatorNameFromPreprocessorType( + TextPreprocessingGraphOptions::PreprocessorType preprocessor_type) { + switch (preprocessor_type) { + case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, "Unspecified preprocessor type", + MediaPipeTasksStatus::kInvalidArgumentError); + case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: + return "BertPreprocessorCalculator"; + case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: + return "RegexPreprocessorCalculator"; + } +} + +// Determines the PreprocessorType for the model based on its metadata as well +// as its input tensors' type and count. Returns an error if there is no +// compatible preprocessor. +absl::StatusOr +GetPreprocessorType(const ModelResources& model_resources) { + const tflite::SubGraph& model_graph = + *(*model_resources.GetTfLiteModel()->subgraphs())[0]; + bool all_int32_tensors = + absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) { + return (*model_graph.tensors())[i]->type() == tflite::TensorType_INT32; + }); + bool all_string_tensors = + absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) { + return (*model_graph.tensors())[i]->type() == tflite::TensorType_STRING; + }); + if (!all_int32_tensors && !all_string_tensors) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "All input tensors should have type int32 or all should have type " + "string", + MediaPipeTasksStatus::kInvalidInputTensorTypeError); + } + if (all_string_tensors) { + // TODO: Support a TextToTensor calculator for string tensors. + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "String tensors are not supported yet", + MediaPipeTasksStatus::kInvalidInputTensorTypeError); + } + + // Otherwise, all tensors should have type int32 + const ModelMetadataExtractor* metadata_extractor = + model_resources.GetMetadataExtractor(); + if (metadata_extractor->GetModelMetadata() == nullptr || + metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Text models with int32 input tensors require TFLite Model " + "Metadata but none was found", + MediaPipeTasksStatus::kMetadataNotFoundError); + } + + if (model_graph.inputs()->size() == kNumInputTensorsForBert) { + return TextPreprocessingGraphOptions::BERT_PREPROCESSOR; + } + + if (model_graph.inputs()->size() == kNumInputTensorsForRegex) { + return TextPreprocessingGraphOptions::REGEX_PREPROCESSOR; + } + + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::Substitute("Models with int32 input tensors should take exactly $0 " + "or $1 input tensors, but found $2", + kNumInputTensorsForBert, kNumInputTensorsForRegex, + model_graph.inputs()->size()), + MediaPipeTasksStatus::kInvalidNumInputTensorsError); +} + +// Returns the maximum input sequence length accepted by the TFLite +// model that owns `model graph` or returns an error if the model's input +// tensors' shape is invalid for text preprocessing. This util assumes that the +// model has the correct input tensors type and count for the +// BertPreprocessorCalculator or the RegexPreprocessorCalculator. 
+absl::StatusOr GetMaxSeqLen(const tflite::SubGraph& model_graph) { + const flatbuffers::Vector& input_indices = *model_graph.inputs(); + const flatbuffers::Vector>& + model_tensors = *model_graph.tensors(); + for (int i : input_indices) { + const tflite::Tensor* tensor = model_tensors[i]; + + if (tensor->shape()->size() != 2) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::Substitute( + "Model should take 2-D input tensors, got dimension: $0", + tensor->shape()->size()), + MediaPipeTasksStatus::kInvalidInputTensorDimensionsError); + } + + if ((*tensor->shape())[0] != 1) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::Substitute( + "Input tensors should all have batch size 1, got: $0", + (*tensor->shape())[0]), + MediaPipeTasksStatus::kInvalidInputTensorSizeError); + } + } + + int max_seq_len = (*model_tensors[input_indices[0]]->shape())[1]; + if (!absl::c_all_of(input_indices, [&model_tensors, max_seq_len](int i) { + return (*model_tensors[i]->shape())[1] == max_seq_len; + })) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Input tensors don't have the same size", + MediaPipeTasksStatus::kInvalidInputTensorSizeError); + } + return max_seq_len; +} +} // namespace + +absl::Status ConfigureTextPreprocessingSubgraph( + const ModelResources& model_resources, + TextPreprocessingGraphOptions& options) { + if (model_resources.GetTfLiteModel()->subgraphs()->size() != 1) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Text tflite models are assumed to have a single subgraph.", + MediaPipeTasksStatus::kInvalidArgumentError); + } + + ASSIGN_OR_RETURN( + TextPreprocessingGraphOptions::PreprocessorType preprocessor_type, + GetPreprocessorType(model_resources)); + options.set_preprocessor_type(preprocessor_type); + ASSIGN_OR_RETURN( + int max_seq_len, + GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0])); + options.set_max_seq_len(max_seq_len); + + return absl::OkStatus(); +} + +// A "mediapipe.tasks.components.TextPreprocessingSubgraph" performs text +// preprocessing. +// - Accepts a std::string input and outputs CPU tensors. +// +// Inputs: +// TEXT - std::string +// The text to preprocess. +// Side inputs: +// METADATA_EXTRACTOR - ModelMetadataExtractor +// The metadata extractor for the TFLite model. Used to determine the order +// for input tensors and to extract tokenizer information. +// Outputs: +// TENSORS - std::vector +// Vector containing the preprocessed input tensors for the TFLite model. +// +// The recommended way of using this subgraph is through the GraphBuilder API +// using the 'ConfigureTextPreprocessing()' function. See header file for more +// details. 
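Following the usage comment above, a caller would typically configure and wire this subgraph roughly as sketched below. The enclosing graph, the `model_resources`, and the helper function name are assumptions; the subgraph name, option type, and stream tags are the ones introduced in this patch:

```cpp
// Sketch only, using the names introduced in this patch.
#include "absl/status/status.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h"

absl::Status AddTextPreprocessing(
    const mediapipe::tasks::core::ModelResources& model_resources,
    mediapipe::api2::builder::Graph& graph) {
  auto& preprocessing = graph.AddNode(
      "mediapipe.tasks.components.TextPreprocessingSubgraph");
  // Picks BERT_PREPROCESSOR or REGEX_PREPROCESSOR and fills max_seq_len
  // based on the model's input tensors, as implemented above.
  MP_RETURN_IF_ERROR(
      mediapipe::tasks::components::ConfigureTextPreprocessingSubgraph(
          model_resources,
          preprocessing.GetOptions<mediapipe::tasks::components::proto::
                                       TextPreprocessingGraphOptions>()));
  graph.In("TEXT") >> preprocessing.In("TEXT");
  graph.SideIn("METADATA_EXTRACTOR") >>
      preprocessing.SideIn("METADATA_EXTRACTOR");
  preprocessing.Out("TENSORS") >> graph.Out("TENSORS");
  return absl::OkStatus();
}
```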
+class TextPreprocessingSubgraph : public mediapipe::Subgraph { + public: + absl::StatusOr GetConfig( + mediapipe::SubgraphContext* sc) override { + Graph graph; + ASSIGN_OR_RETURN( + Source> tensors_in, + BuildTextPreprocessing( + sc->Options(), + graph[Input(kTextTag)], + graph[SideInput(kMetadataExtractorTag)], + graph)); + tensors_in >> graph[Output>(kTensorsTag)]; + return graph.GetConfig(); + } + + private: + absl::StatusOr>> BuildTextPreprocessing( + const TextPreprocessingGraphOptions& options, Source text_in, + SideSource metadata_extractor_in, Graph& graph) { + ASSIGN_OR_RETURN( + std::string preprocessor_name, + GetCalculatorNameFromPreprocessorType(options.preprocessor_type())); + auto& text_preprocessor = graph.AddNode(preprocessor_name); + switch (options.preprocessor_type()) { + case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: { + break; + } + case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: { + text_preprocessor.GetOptions() + .set_bert_max_seq_len(options.max_seq_len()); + metadata_extractor_in >> + text_preprocessor.SideIn(kMetadataExtractorTag); + break; + } + case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: { + text_preprocessor.GetOptions() + .set_max_seq_len(options.max_seq_len()); + metadata_extractor_in >> + text_preprocessor.SideIn(kMetadataExtractorTag); + break; + } + } + text_in >> text_preprocessor.In(kTextTag); + return text_preprocessor[Output>(kTensorsTag)]; + } +}; +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::components::TextPreprocessingSubgraph); + +} // namespace components +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/text_preprocessing_graph.h b/mediapipe/tasks/cc/components/text_preprocessing_graph.h new file mode 100644 index 000000000..b031a5550 --- /dev/null +++ b/mediapipe/tasks/cc/components/text_preprocessing_graph.h @@ -0,0 +1,58 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_ + +#include "absl/status/status.h" +#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" + +// Configures a TextPreprocessing subgraph using the provided `model_resources` +// and TextPreprocessingGraphOptions. +// - Accepts a std::string input and outputs CPU tensors. +// +// Example usage: +// +// auto& preprocessing = +// graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); +// MP_RETURN_IF_ERROR(ConfigureTextPreprocessingSubgraph( +// model_resources, +// &preprocessing.GetOptions())); +// +// The resulting TextPreprocessing subgraph has the following I/O: +// Inputs: +// TEXT - std::string +// The text to preprocess. +// Side inputs: +// METADATA_EXTRACTOR - ModelMetadataExtractor +// The metadata extractor for the TFLite model. 
Used to determine the order +// for input tensors and to extract tokenizer information. +// Outputs: +// TENSORS - std::vector +// Vector containing the preprocessed input tensors for the TFLite model. +namespace mediapipe { +namespace tasks { +namespace components { + +absl::Status ConfigureTextPreprocessingSubgraph( + const tasks::core::ModelResources& model_resources, + tasks::components::proto::TextPreprocessingGraphOptions& options); + +} // namespace components +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_ diff --git a/mediapipe/tasks/cc/components/utils/BUILD b/mediapipe/tasks/cc/components/utils/BUILD new file mode 100644 index 000000000..0ec7ac945 --- /dev/null +++ b/mediapipe/tasks/cc/components/utils/BUILD @@ -0,0 +1,44 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +cc_library( + name = "source_or_node_output", + hdrs = ["source_or_node_output.h"], + deps = ["//mediapipe/framework/api2:builder"], +) + +cc_library( + name = "cosine_similarity", + srcs = ["cosine_similarity.cc"], + hdrs = ["cosine_similarity.h"], + deps = [ + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_test( + name = "cosine_similarity_test", + srcs = ["cosine_similarity_test.cc"], + deps = [ + ":cosine_similarity", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + ], +) diff --git a/mediapipe/tasks/cc/components/utils/cosine_similarity.cc b/mediapipe/tasks/cc/components/utils/cosine_similarity.cc new file mode 100644 index 000000000..af471a2d8 --- /dev/null +++ b/mediapipe/tasks/cc/components/utils/cosine_similarity.cc @@ -0,0 +1,112 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" + +namespace mediapipe { +namespace tasks { +namespace components { +namespace utils { + +namespace { + +using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry; + +template +absl::StatusOr ComputeCosineSimilarity(const T& u, const T& v, + int num_elements) { + if (num_elements <= 0) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Cannot compute cosing similarity on empty embeddings", + MediaPipeTasksStatus::kInvalidArgumentError); + } + double dot_product = 0.0; + double norm_u = 0.0; + double norm_v = 0.0; + for (int i = 0; i < num_elements; ++i) { + dot_product += u[i] * v[i]; + norm_u += u[i] * u[i]; + norm_v += v[i] * v[i]; + } + if (norm_u <= 0.0 || norm_v <= 0.0) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Cannot compute cosine similarity on embedding with 0 norm", + MediaPipeTasksStatus::kInvalidArgumentError); + } + return dot_product / std::sqrt(norm_u * norm_v); +} + +} // namespace + +// Utility function to compute cosine similarity [1] between two embedding +// entries. May return an InvalidArgumentError if e.g. the feature vectors are +// of different types (quantized vs. float), have different sizes, or have a +// an L2-norm of 0. +// +// [1]: https://en.wikipedia.org/wiki/Cosine_similarity +absl::StatusOr CosineSimilarity(const EmbeddingEntry& u, + const EmbeddingEntry& v) { + if (u.has_float_embedding() && v.has_float_embedding()) { + if (u.float_embedding().values().size() != + v.float_embedding().values().size()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Cannot compute cosine similarity between embeddings " + "of different sizes (%d vs. %d)", + u.float_embedding().values().size(), + v.float_embedding().values().size()), + MediaPipeTasksStatus::kInvalidArgumentError); + } + return ComputeCosineSimilarity(u.float_embedding().values().data(), + v.float_embedding().values().data(), + u.float_embedding().values().size()); + } + if (u.has_quantized_embedding() && v.has_quantized_embedding()) { + if (u.quantized_embedding().values().size() != + v.quantized_embedding().values().size()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Cannot compute cosine similarity between embeddings " + "of different sizes (%d vs. 
%d)", + u.quantized_embedding().values().size(), + v.quantized_embedding().values().size()), + MediaPipeTasksStatus::kInvalidArgumentError); + } + return ComputeCosineSimilarity(reinterpret_cast( + u.quantized_embedding().values().data()), + reinterpret_cast( + v.quantized_embedding().values().data()), + u.quantized_embedding().values().size()); + } + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Cannot compute cosine similarity between quantized and float embeddings", + MediaPipeTasksStatus::kInvalidArgumentError); +} + +} // namespace utils +} // namespace components +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/utils/cosine_similarity.h b/mediapipe/tasks/cc/components/utils/cosine_similarity.h new file mode 100644 index 000000000..4356811cd --- /dev/null +++ b/mediapipe/tasks/cc/components/utils/cosine_similarity.h @@ -0,0 +1,42 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_COSINE_SIMILARITY_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_COSINE_SIMILARITY_H_ + +#include "absl/status/statusor.h" +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" + +namespace mediapipe { +namespace tasks { +namespace components { +namespace utils { + +// Utility function to compute cosine similarity [1] between two embedding +// entries. May return an InvalidArgumentError if e.g. the feature vectors are +// of different types (quantized vs. float), have different sizes, or have a +// an L2-norm of 0. +// +// [1]: https://en.wikipedia.org/wiki/Cosine_similarity +absl::StatusOr CosineSimilarity( + const containers::proto::EmbeddingEntry& u, + const containers::proto::EmbeddingEntry& v); + +} // namespace utils +} // namespace components +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_COSINE_SIMILARITY_H_ diff --git a/mediapipe/tasks/cc/components/utils/cosine_similarity_test.cc b/mediapipe/tasks/cc/components/utils/cosine_similarity_test.cc new file mode 100644 index 000000000..176f7f7a6 --- /dev/null +++ b/mediapipe/tasks/cc/components/utils/cosine_similarity_test.cc @@ -0,0 +1,111 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h" + +#include +#include +#include + +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" + +namespace mediapipe { +namespace tasks { +namespace components { +namespace utils { +namespace { + +using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry; +using ::testing::HasSubstr; + +// Helper function to generate float EmbeddingEntry. +EmbeddingEntry BuildFloatEntry(std::vector values) { + EmbeddingEntry entry; + for (const float value : values) { + entry.mutable_float_embedding()->add_values(value); + } + return entry; +} + +// Helper function to generate quantized EmbeddingEntry. +EmbeddingEntry BuildQuantizedEntry(std::vector values) { + EmbeddingEntry entry; + entry.mutable_quantized_embedding()->set_values( + reinterpret_cast(values.data()), values.size()); + return entry; +} + +TEST(CosineSimilarity, FailsWithQuantizedAndFloatEmbeddings) { + auto u = BuildFloatEntry({0.1, 0.2}); + auto v = BuildQuantizedEntry({0, 1}); + + auto status = CosineSimilarity(u, v); + + EXPECT_EQ(status.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(status.status().message(), + HasSubstr("Cannot compute cosine similarity between quantized " + "and float embeddings")); +} + +TEST(CosineSimilarity, FailsWithZeroNorm) { + auto u = BuildFloatEntry({0.1, 0.2}); + auto v = BuildFloatEntry({0.0, 0.0}); + + auto status = CosineSimilarity(u, v); + + EXPECT_EQ(status.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT( + status.status().message(), + HasSubstr("Cannot compute cosine similarity on embedding with 0 norm")); +} + +TEST(CosineSimilarity, FailsWithDifferentSizes) { + auto u = BuildFloatEntry({0.1, 0.2}); + auto v = BuildFloatEntry({0.1, 0.2, 0.3}); + + auto status = CosineSimilarity(u, v); + + EXPECT_EQ(status.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(status.status().message(), + HasSubstr("Cannot compute cosine similarity between embeddings " + "of different sizes")); +} + +TEST(CosineSimilarity, SucceedsWithFloatEntries) { + auto u = BuildFloatEntry({1.0, 0.0, 0.0, 0.0}); + auto v = BuildFloatEntry({0.5, 0.5, 0.5, 0.5}); + + MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v)); + + EXPECT_EQ(result, 0.5); +} + +TEST(CosineSimilarity, SucceedsWithQuantizedEntries) { + auto u = BuildQuantizedEntry({127, 0, 0, 0}); + auto v = BuildQuantizedEntry({-128, 0, 0, 0}); + + MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v)); + + EXPECT_EQ(result, -1); +} + +} // namespace +} // namespace utils +} // namespace components +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/utils/source_or_node_output.h b/mediapipe/tasks/cc/components/utils/source_or_node_output.h new file mode 100644 index 000000000..55805d5a3 --- /dev/null +++ b/mediapipe/tasks/cc/components/utils/source_or_node_output.h @@ -0,0 +1,66 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_SOURCE_OR_NODE_OUTPUT_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_SOURCE_OR_NODE_OUTPUT_H_ + +#include "mediapipe/framework/api2/builder.h" + +namespace mediapipe { +namespace tasks { + +// Helper class representing either a Source object or a GenericNode output. +// +// Source and MultiSource (the output of a GenericNode) are widely incompatible, +// but being able to represent either of these in temporary variables and +// connect them later on facilitates graph building. +template +class SourceOrNodeOutput { + public: + SourceOrNodeOutput() = delete; + // The caller is responsible for ensuring 'source' outlives this object. + explicit SourceOrNodeOutput(mediapipe::api2::builder::Source* source) + : source_(source) {} + // The caller is responsible for ensuring 'node' outlives this object. + SourceOrNodeOutput(mediapipe::api2::builder::GenericNode* node, + std::string tag) + : node_(node), tag_(tag) {} + // The caller is responsible for ensuring 'node' outlives this object. + SourceOrNodeOutput(mediapipe::api2::builder::GenericNode* node, int index) + : node_(node), index_(index) {} + + // Connects the source or node output to the provided destination. + template + void operator>>(const U& dest) { + if (source_ != nullptr) { + *source_ >> dest; + } else { + if (index_ < 0) { + node_->Out(tag_) >> dest; + } else { + node_->Out(index_) >> dest; + } + } + } + + private: + mediapipe::api2::builder::Source* source_ = nullptr; + mediapipe::api2::builder::GenericNode* node_ = nullptr; + std::string tag_ = ""; + int index_ = -1; +}; + +} // namespace tasks +} // namespace mediapipe +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_SOURCE_OR_NODE_OUTPUT_H_ diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index 38b134e78..38030c525 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -23,6 +23,8 @@ cc_library( srcs = ["base_options.cc"], hdrs = ["base_options.h"], deps = [ + "//mediapipe/calculators/tensor:inference_calculator_cc_proto", + "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", "@com_google_absl//absl/memory", diff --git a/mediapipe/tasks/cc/core/base_options.cc b/mediapipe/tasks/cc/core/base_options.cc index d265ccad8..ec85ea753 100644 --- a/mediapipe/tasks/cc/core/base_options.cc +++ b/mediapipe/tasks/cc/core/base_options.cc @@ -18,6 +18,8 @@ limitations under the License. 
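The SourceOrNodeOutput helper introduced above is easiest to see with a small usage sketch; every name other than the helper itself is hypothetical, and the point is only that the same `>>` connection works whichever alternative was stored:

```cpp
// Sketch only: deferring the choice between a plain Source and a node output.
#include <vector>

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"

using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::Source;
using TensorsOutput =
    mediapipe::tasks::SourceOrNodeOutput<std::vector<mediapipe::Tensor>>;

// Connects whichever alternative is available to the inference node.
void ConnectTensors(Source<std::vector<mediapipe::Tensor>>* direct_source,
                    GenericNode* converter_node, GenericNode& inference) {
  TensorsOutput tensors = direct_source != nullptr
                              ? TensorsOutput(direct_source)
                              : TensorsOutput(converter_node, "TENSORS");
  tensors >> inference.In("TENSORS");
}
```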
#include #include +#include "mediapipe/calculators/tensor/inference_calculator.pb.h" +#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" namespace mediapipe { @@ -26,28 +28,37 @@ namespace core { proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) { proto::BaseOptions base_options_proto; - if (!base_options->model_file_name.empty()) { - base_options_proto.mutable_model_file()->set_file_name( - base_options->model_file_name); + if (!base_options->model_asset_path.empty()) { + base_options_proto.mutable_model_asset()->set_file_name( + base_options->model_asset_path); } - if (base_options->model_file_contents) { - base_options_proto.mutable_model_file()->mutable_file_content()->swap( - *base_options->model_file_contents.release()); + if (base_options->model_asset_buffer) { + base_options_proto.mutable_model_asset()->mutable_file_content()->swap( + *base_options->model_asset_buffer.release()); } - if (base_options->model_file_descriptor_meta.fd > 0) { - auto* file_descriptor_meta_proto = - base_options_proto.mutable_model_file()->mutable_file_descriptor_meta(); + if (base_options->model_asset_descriptor_meta.fd > 0) { + auto* file_descriptor_meta_proto = base_options_proto.mutable_model_asset() + ->mutable_file_descriptor_meta(); file_descriptor_meta_proto->set_fd( - base_options->model_file_descriptor_meta.fd); - if (base_options->model_file_descriptor_meta.length > 0) { + base_options->model_asset_descriptor_meta.fd); + if (base_options->model_asset_descriptor_meta.length > 0) { file_descriptor_meta_proto->set_length( - base_options->model_file_descriptor_meta.length); + base_options->model_asset_descriptor_meta.length); } - if (base_options->model_file_descriptor_meta.offset > 0) { + if (base_options->model_asset_descriptor_meta.offset > 0) { file_descriptor_meta_proto->set_offset( - base_options->model_file_descriptor_meta.offset); + base_options->model_asset_descriptor_meta.offset); } } + switch (base_options->delegate) { + case BaseOptions::Delegate::CPU: + base_options_proto.mutable_acceleration()->mutable_xnnpack(); + break; + case BaseOptions::Delegate::GPU: + base_options_proto.mutable_acceleration()->mutable_gpu(); + break; + } + return base_options_proto; } } // namespace core diff --git a/mediapipe/tasks/cc/core/base_options.h b/mediapipe/tasks/cc/core/base_options.h index 430726a08..67a03385b 100644 --- a/mediapipe/tasks/cc/core/base_options.h +++ b/mediapipe/tasks/cc/core/base_options.h @@ -30,11 +30,20 @@ namespace core { // Base options for MediaPipe C++ Tasks. struct BaseOptions { - // The model file contents as a string. - std::unique_ptr model_file_contents; + // The model asset file contents as as string. + std::unique_ptr model_asset_buffer; - // The path to the model file to open and mmap in memory. - std::string model_file_name = ""; + // The path to the model asset to open and mmap in memory. + std::string model_asset_path = ""; + + // The delegate to run MediaPipe. If the delegate is not set, default + // delegate CPU is used. + enum Delegate { + CPU = 0, + GPU = 1, + }; + + Delegate delegate = CPU; // The file descriptor to a file opened with open(2), with optional additional // offset and length information. @@ -49,7 +58,7 @@ struct BaseOptions { // Optional starting offset in the file referred to by the file descriptor // `fd`. 
int offset = -1; - } model_file_descriptor_meta; + } model_asset_descriptor_meta; // A non-default OpResolver to support custom Ops or specify a subset of // built-in Ops. diff --git a/mediapipe/tasks/cc/core/model_task_graph.cc b/mediapipe/tasks/cc/core/model_task_graph.cc index 547f35f2c..c6bc8f69b 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.cc +++ b/mediapipe/tasks/cc/core/model_task_graph.cc @@ -91,7 +91,7 @@ class InferenceSubgraph : public Subgraph { subgraph_options->model_resources_tag()); } else { model_resources_opts.mutable_model_file()->Swap( - subgraph_options->mutable_base_options()->mutable_model_file()); + subgraph_options->mutable_base_options()->mutable_model_asset()); } model_resources_node.SideOut(kMetadataExtractorTag) >> graph.SideOut(kMetadataExtractorTag); @@ -165,12 +165,16 @@ absl::StatusOr ModelTaskGraph::CreateModelResources( return model_resources_cache_service.GetObject().GetModelResources(tag); } -GenericNode& ModelTaskGraph::AddInference(const ModelResources& model_resources, - Graph& graph) const { +GenericNode& ModelTaskGraph::AddInference( + const ModelResources& model_resources, + const proto::Acceleration& acceleration, Graph& graph) const { auto& inference_subgraph = graph.AddNode("mediapipe.tasks.core.InferenceSubgraph"); auto& inference_subgraph_opts = inference_subgraph.GetOptions(); + inference_subgraph_opts.mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(acceleration); // When the model resources tag is available, the ModelResourcesCalculator // will retrieve the cached model resources from the graph service by tag. // Otherwise, provides the exteranal file and asks the @@ -180,7 +184,7 @@ GenericNode& ModelTaskGraph::AddInference(const ModelResources& model_resources, inference_subgraph_opts.set_model_resources_tag(model_resources.GetTag()); } else { inference_subgraph_opts.mutable_base_options() - ->mutable_model_file() + ->mutable_model_asset() ->CopyFrom(model_resources.GetModelFile()); } return inference_subgraph; diff --git a/mediapipe/tasks/cc/core/model_task_graph.h b/mediapipe/tasks/cc/core/model_task_graph.h index b13f2b5b4..36016cb89 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.h +++ b/mediapipe/tasks/cc/core/model_task_graph.h @@ -28,6 +28,7 @@ limitations under the License. #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/subgraph.h" #include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" @@ -62,7 +63,7 @@ class ModelTaskGraph : public Subgraph { auto external_file = std::make_unique(); external_file->Swap(sc->MutableOptions() ->mutable_base_options() - ->mutable_model_file()); + ->mutable_model_asset()); return CreateModelResources(sc, std::move(external_file)); } @@ -88,7 +89,9 @@ class ModelTaskGraph : public Subgraph { // engine. // - a MetadataExtractor output side packet with tag "METADATA_EXTRACTOR". 
api2::builder::GenericNode& AddInference( - const ModelResources& model_resources, api2::builder::Graph& graph) const; + const ModelResources& model_resources, + const proto::Acceleration& acceleration, + api2::builder::Graph& graph) const; private: std::unique_ptr local_model_resources_; diff --git a/mediapipe/tasks/cc/core/proto/acceleration.proto b/mediapipe/tasks/cc/core/proto/acceleration.proto index a0e522d94..a0575a5d5 100644 --- a/mediapipe/tasks/cc/core/proto/acceleration.proto +++ b/mediapipe/tasks/cc/core/proto/acceleration.proto @@ -19,7 +19,7 @@ package mediapipe.tasks.core.proto; import "mediapipe/calculators/tensor/inference_calculator.proto"; -option java_package = "com.google.mediapipe.tasks.core"; +option java_package = "com.google.mediapipe.tasks.core.proto"; option java_outer_classname = "AccelerationProto"; message Acceleration { diff --git a/mediapipe/tasks/cc/core/proto/base_options.proto b/mediapipe/tasks/cc/core/proto/base_options.proto index 07f4b9e35..b7c0629e8 100644 --- a/mediapipe/tasks/cc/core/proto/base_options.proto +++ b/mediapipe/tasks/cc/core/proto/base_options.proto @@ -26,13 +26,13 @@ option java_outer_classname = "BaseOptionsProto"; // Base options for mediapipe tasks. // Next Id: 4 message BaseOptions { - // The external model file, as a single standalone TFLite file. It could be + // The external model asset, as a single standalone TFLite file. It could be // packed with TFLite Model Metadata[1] and associated files if exist. Fail to // provide the necessary metadata and associated files might result in errors. // Check the documentation for each task about the specific requirement. // [1]: https://www.tensorflow.org/lite/convert/metadata - optional ExternalFile model_file = 1; + optional ExternalFile model_asset = 1; // Whether the mediapipe task treats the input data as a continuous data // stream, or a batch of unrelated data. Default to False. diff --git a/mediapipe/tasks/cc/metadata/metadata_extractor.cc b/mediapipe/tasks/cc/metadata/metadata_extractor.cc index 9ad4eee0a..fcec49083 100644 --- a/mediapipe/tasks/cc/metadata/metadata_extractor.cc +++ b/mediapipe/tasks/cc/metadata/metadata_extractor.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/metadata_extractor.h b/mediapipe/tasks/cc/metadata/metadata_extractor.h index d1a522a86..e74ac50a3 100644 --- a/mediapipe/tasks/cc/metadata/metadata_extractor.h +++ b/mediapipe/tasks/cc/metadata/metadata_extractor.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/metadata_parser.h.template b/mediapipe/tasks/cc/metadata/metadata_parser.h.template index 8ee0b4a28..f5ebfa04d 100644 --- a/mediapipe/tasks/cc/metadata/metadata_parser.h.template +++ b/mediapipe/tasks/cc/metadata/metadata_parser.h.template @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/mediapipe/tasks/cc/metadata/metadata_populator.cc b/mediapipe/tasks/cc/metadata/metadata_populator.cc index 9892b7fe9..a6fd496a3 100644 --- a/mediapipe/tasks/cc/metadata/metadata_populator.cc +++ b/mediapipe/tasks/cc/metadata/metadata_populator.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/metadata_populator.h b/mediapipe/tasks/cc/metadata/metadata_populator.h index 47d0cb273..024ad785f 100644 --- a/mediapipe/tasks/cc/metadata/metadata_populator.h +++ b/mediapipe/tasks/cc/metadata/metadata_populator.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/metadata_version.cc b/mediapipe/tasks/cc/metadata/metadata_version.cc index 056c78a6b..7b9f123cb 100644 --- a/mediapipe/tasks/cc/metadata/metadata_version.cc +++ b/mediapipe/tasks/cc/metadata/metadata_version.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/metadata_version.h b/mediapipe/tasks/cc/metadata/metadata_version.h index a92a3b61c..a53caa547 100644 --- a/mediapipe/tasks/cc/metadata/metadata_version.h +++ b/mediapipe/tasks/cc/metadata/metadata_version.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/python/metadata_version.cc b/mediapipe/tasks/cc/metadata/python/metadata_version.cc index fa5b1e592..860a00e4f 100644 --- a/mediapipe/tasks/cc/metadata/python/metadata_version.cc +++ b/mediapipe/tasks/cc/metadata/python/metadata_version.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/tests/metadata_extractor_test.cc b/mediapipe/tasks/cc/metadata/tests/metadata_extractor_test.cc index e6f718e8f..4dacc7b8c 100644 --- a/mediapipe/tasks/cc/metadata/tests/metadata_extractor_test.cc +++ b/mediapipe/tasks/cc/metadata/tests/metadata_extractor_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/mediapipe/tasks/cc/metadata/tests/metadata_parser_test.cc b/mediapipe/tasks/cc/metadata/tests/metadata_parser_test.cc index e1738d099..3605648e9 100644 --- a/mediapipe/tasks/cc/metadata/tests/metadata_parser_test.cc +++ b/mediapipe/tasks/cc/metadata/tests/metadata_parser_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc b/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc index 74938d17f..bf6206f38 100644 --- a/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc +++ b/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.cc b/mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.cc index 49a2c2926..a231afc40 100644 --- a/mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.cc +++ b/mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.h b/mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.h index f43d0dd55..fcd22d6d6 100644 --- a/mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.h +++ b/mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/utils/zip_writable_mem_file.cc b/mediapipe/tasks/cc/metadata/utils/zip_writable_mem_file.cc index 20318947b..3dc1f1950 100644 --- a/mediapipe/tasks/cc/metadata/utils/zip_writable_mem_file.cc +++ b/mediapipe/tasks/cc/metadata/utils/zip_writable_mem_file.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/utils/zip_writable_mem_file.h b/mediapipe/tasks/cc/metadata/utils/zip_writable_mem_file.h index f540d059f..ca06476ec 100644 --- a/mediapipe/tasks/cc/metadata/utils/zip_writable_mem_file.h +++ b/mediapipe/tasks/cc/metadata/utils/zip_writable_mem_file.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/mediapipe/tasks/cc/components/tokenizers/BUILD b/mediapipe/tasks/cc/text/tokenizers/BUILD similarity index 100% rename from mediapipe/tasks/cc/components/tokenizers/BUILD rename to mediapipe/tasks/cc/text/tokenizers/BUILD diff --git a/mediapipe/tasks/cc/components/tokenizers/bert_tokenizer.cc b/mediapipe/tasks/cc/text/tokenizers/bert_tokenizer.cc similarity index 96% rename from mediapipe/tasks/cc/components/tokenizers/bert_tokenizer.cc rename to mediapipe/tasks/cc/text/tokenizers/bert_tokenizer.cc index 4def30cfe..3348abff5 100644 --- a/mediapipe/tasks/cc/components/tokenizers/bert_tokenizer.cc +++ b/mediapipe/tasks/cc/text/tokenizers/bert_tokenizer.cc @@ -13,14 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/tokenizers/bert_tokenizer.h" +#include "mediapipe/tasks/cc/text/tokenizers/bert_tokenizer.h" #include "mediapipe/framework/port/integral_types.h" #include "tensorflow_text/core/kernels/regex_split.h" namespace mediapipe { namespace tasks { -namespace tokenizer { +namespace text { +namespace tokenizers { FlatHashMapBackedWordpiece::FlatHashMapBackedWordpiece( const std::vector& vocab) @@ -102,6 +103,7 @@ WordpieceTokenizerResult BertTokenizer::TokenizeWordpiece( return result; } -} // namespace tokenizer +} // namespace tokenizers +} // namespace text } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/tokenizers/bert_tokenizer.h b/mediapipe/tasks/cc/text/tokenizers/bert_tokenizer.h similarity index 92% rename from mediapipe/tasks/cc/components/tokenizers/bert_tokenizer.h rename to mediapipe/tasks/cc/text/tokenizers/bert_tokenizer.h index ca362c304..d655fcadd 100644 --- a/mediapipe/tasks/cc/components/tokenizers/bert_tokenizer.h +++ b/mediapipe/tasks/cc/text/tokenizers/bert_tokenizer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_TOKENIZERS_BERT_TOKENIZER_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_TOKENIZERS_BERT_TOKENIZER_H_ +#ifndef MEDIAPIPE_TASKS_CC_TEXT_TOKENIZERS_BERT_TOKENIZER_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_TOKENIZERS_BERT_TOKENIZER_H_ #include #include @@ -23,14 +23,15 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" -#include "mediapipe/tasks/cc/components/tokenizers/tokenizer.h" +#include "mediapipe/tasks/cc/text/tokenizers/tokenizer.h" #include "mediapipe/tasks/cc/text/utils/vocab_utils.h" #include "re2/re2.h" #include "tensorflow_text/core/kernels/wordpiece_tokenizer.h" namespace mediapipe { namespace tasks { -namespace tokenizer { +namespace text { +namespace tokenizers { constexpr char kDefaultDelimRe[] = R"((\s+|[!-/]|[:-@]|[\[-`]|[{-~]|[\p{P}]|[\x{4E00}-\x{9FFF}]|[\x{3400}-\x{4DBF}]|[\x{20000}-\x{2A6DF}]|[\x{2A700}-\x{2B73F}]|[\x{2B740}-\x{2B81F}]|[\x{2B820}-\x{2CEAF}]|[\x{F900}-\x{FAFF}]|[\x{2F800}-\x{2FA1F}]))"; @@ -86,7 +87,7 @@ class FlatHashMapBackedWordpiece : public tensorflow::text::WordpieceVocab { }; // Wordpiece tokenizer for bert models. Initialized with a vocab file or vector. 
-class BertTokenizer : public mediapipe::tasks::tokenizer::Tokenizer { +class BertTokenizer : public mediapipe::tasks::text::tokenizers::Tokenizer { public: // Initialize the tokenizer from vocab vector and tokenizer configs. explicit BertTokenizer(const std::vector& vocab, @@ -136,14 +137,15 @@ class BertTokenizer : public mediapipe::tasks::tokenizer::Tokenizer { int VocabularySize() const { return vocab_.VocabularySize(); } private: - mediapipe::tasks::tokenizer::FlatHashMapBackedWordpiece vocab_; + mediapipe::tasks::text::tokenizers::FlatHashMapBackedWordpiece vocab_; BertTokenizerOptions options_; RE2 delim_re_; RE2 include_delim_re_; }; -} // namespace tokenizer +} // namespace tokenizers +} // namespace text } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_TOKENIZERS_BERT_TOKENIZER_H_ +#endif // MEDIAPIPE_TASKS_CC_TEXT_TOKENIZERS_BERT_TOKENIZER_H_ diff --git a/mediapipe/tasks/cc/components/tokenizers/bert_tokenizer_test.cc b/mediapipe/tasks/cc/text/tokenizers/bert_tokenizer_test.cc similarity index 97% rename from mediapipe/tasks/cc/components/tokenizers/bert_tokenizer_test.cc rename to mediapipe/tasks/cc/text/tokenizers/bert_tokenizer_test.cc index ceb754ea2..6970c5365 100644 --- a/mediapipe/tasks/cc/components/tokenizers/bert_tokenizer_test.cc +++ b/mediapipe/tasks/cc/text/tokenizers/bert_tokenizer_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/tokenizers/bert_tokenizer.h" +#include "mediapipe/tasks/cc/text/tokenizers/bert_tokenizer.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" @@ -21,7 +21,8 @@ limitations under the License. namespace mediapipe { namespace tasks { -namespace tokenizer { +namespace text { +namespace tokenizers { using ::mediapipe::tasks::core::LoadBinaryContent; using ::testing::ElementsAre; @@ -168,6 +169,7 @@ TEST(TokenizerTest, TestLVocabularySize) { ASSERT_EQ(tokenizer->VocabularySize(), 4); } -} // namespace tokenizer +} // namespace tokenizers +} // namespace text } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/tokenizers/regex_tokenizer.cc b/mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.cc similarity index 96% rename from mediapipe/tasks/cc/components/tokenizers/regex_tokenizer.cc rename to mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.cc index 002a40086..6a1dc2506 100644 --- a/mediapipe/tasks/cc/components/tokenizers/regex_tokenizer.cc +++ b/mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/tokenizers/regex_tokenizer.h" +#include "mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.h" #include @@ -22,7 +22,8 @@ limitations under the License. 
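A minimal usage sketch of the relocated BertTokenizer under its new mediapipe::tasks::text::tokenizers namespace. This is an illustration, not part of the change: it assumes the template arguments elided in the hunks above are std::string, that the vocab-vector constructor takes default BertTokenizerOptions, and that TokenizeWordpiece accepts a plain string; only the class, header path, and method names come from the diff.

#include <memory>
#include <string>
#include <vector>

#include "mediapipe/tasks/cc/text/tokenizers/bert_tokenizer.h"

namespace {

void RunBertTokenizerSketch() {
  using ::mediapipe::tasks::text::tokenizers::BertTokenizer;
  // Assumed: constructor taking a std::vector<std::string> vocabulary with
  // default options; the exact signature may differ in the actual header.
  std::vector<std::string> vocab = {"[UNK]", "hel", "##lo", "world"};
  auto tokenizer = std::make_unique<BertTokenizer>(vocab);
  // TokenizeWordpiece splits the input into wordpiece subwords; for this toy
  // vocab the result.subwords would be {"hel", "##lo", "world"}.
  auto result = tokenizer->TokenizeWordpiece("hello world");
  // VocabularySize() (declared in the hunk above) reports 4 for this vocab.
}

}  // namespace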
namespace mediapipe { namespace tasks { -namespace tokenizer { +namespace text { +namespace tokenizers { namespace { @@ -122,6 +123,7 @@ bool RegexTokenizer::GetUnknownToken(int* unknown_token) { return LookupId(kUnknown, unknown_token); } -} // namespace tokenizer +} // namespace tokenizers +} // namespace text } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/tokenizers/regex_tokenizer.h b/mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.h similarity index 84% rename from mediapipe/tasks/cc/components/tokenizers/regex_tokenizer.h rename to mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.h index dc09803ee..c84dd33d2 100644 --- a/mediapipe/tasks/cc/components/tokenizers/regex_tokenizer.h +++ b/mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.h @@ -13,20 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_TOKENIZERS_REGEX_TOKENIZER_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_TOKENIZERS_REGEX_TOKENIZER_H_ +#ifndef MEDIAPIPE_TASKS_CC_TEXT_TOKENIZERS_REGEX_TOKENIZER_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_TOKENIZERS_REGEX_TOKENIZER_H_ #include #include #include "absl/container/node_hash_map.h" #include "absl/strings/string_view.h" -#include "mediapipe/tasks/cc/components/tokenizers/tokenizer.h" +#include "mediapipe/tasks/cc/text/tokenizers/tokenizer.h" #include "re2/re2.h" namespace mediapipe { namespace tasks { -namespace tokenizer { +namespace text { +namespace tokenizers { // Tokenizer to load a vocabulary and split text by regular expressions. class RegexTokenizer : public Tokenizer { @@ -54,8 +55,9 @@ class RegexTokenizer : public Tokenizer { absl::node_hash_map index_token_map_; }; -} // namespace tokenizer +} // namespace tokenizers +} // namespace text } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_TOKENIZERS_REGEX_TOKENIZER_H_ +#endif // MEDIAPIPE_TASKS_CC_TEXT_TOKENIZERS_REGEX_TOKENIZER_H_ diff --git a/mediapipe/tasks/cc/components/tokenizers/regex_tokenizer_test.cc b/mediapipe/tasks/cc/text/tokenizers/regex_tokenizer_test.cc similarity index 96% rename from mediapipe/tasks/cc/components/tokenizers/regex_tokenizer_test.cc rename to mediapipe/tasks/cc/text/tokenizers/regex_tokenizer_test.cc index 0831532f6..f0ae6497c 100644 --- a/mediapipe/tasks/cc/components/tokenizers/regex_tokenizer_test.cc +++ b/mediapipe/tasks/cc/text/tokenizers/regex_tokenizer_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/tokenizers/regex_tokenizer.h" +#include "mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" @@ -21,7 +21,8 @@ limitations under the License. 
namespace mediapipe { namespace tasks { -namespace tokenizer { +namespace text { +namespace tokenizers { using ::mediapipe::tasks::core::LoadBinaryContent; using ::testing::ElementsAre; @@ -117,6 +118,7 @@ TEST(RegexTokenizerTest, TestGetSpecialTokensFailure) { } // namespace -} // namespace tokenizer +} // namespace tokenizers +} // namespace text } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/tokenizers/sentencepiece_tokenizer.h b/mediapipe/tasks/cc/text/tokenizers/sentencepiece_tokenizer.h similarity index 85% rename from mediapipe/tasks/cc/components/tokenizers/sentencepiece_tokenizer.h rename to mediapipe/tasks/cc/text/tokenizers/sentencepiece_tokenizer.h index 4349c4520..9798f7bde 100644 --- a/mediapipe/tasks/cc/components/tokenizers/sentencepiece_tokenizer.h +++ b/mediapipe/tasks/cc/text/tokenizers/sentencepiece_tokenizer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_TOKENIZERS_SENTENCEPIECE_TOKENIZER_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_TOKENIZERS_SENTENCEPIECE_TOKENIZER_H_ +#ifndef MEDIAPIPE_TASKS_CC_TEXT_TOKENIZERS_SENTENCEPIECE_TOKENIZER_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_TOKENIZERS_SENTENCEPIECE_TOKENIZER_H_ #include #include @@ -23,12 +23,13 @@ limitations under the License. #include "absl/strings/string_view.h" #include "mediapipe/framework/port/logging.h" -#include "mediapipe/tasks/cc/components/tokenizers/tokenizer.h" +#include "mediapipe/tasks/cc/text/tokenizers/tokenizer.h" #include "src/sentencepiece_processor.h" namespace mediapipe { namespace tasks { -namespace tokenizer { +namespace text { +namespace tokenizers { // SentencePiece tokenizer. Initialized with a model file. class SentencePieceTokenizer : public Tokenizer { @@ -68,8 +69,9 @@ class SentencePieceTokenizer : public Tokenizer { sentencepiece::SentencePieceProcessor sp_; }; -} // namespace tokenizer +} // namespace tokenizers +} // namespace text } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_TOKENIZERS_SENTENCEPIECE_TOKENIZER_H_ +#endif // MEDIAPIPE_TASKS_CC_TEXT_TOKENIZERS_SENTENCEPIECE_TOKENIZER_H_ diff --git a/mediapipe/tasks/cc/components/tokenizers/sentencepiece_tokenizer_test.cc b/mediapipe/tasks/cc/text/tokenizers/sentencepiece_tokenizer_test.cc similarity index 94% rename from mediapipe/tasks/cc/components/tokenizers/sentencepiece_tokenizer_test.cc rename to mediapipe/tasks/cc/text/tokenizers/sentencepiece_tokenizer_test.cc index e7e1e3f64..ed7decbd9 100644 --- a/mediapipe/tasks/cc/components/tokenizers/sentencepiece_tokenizer_test.cc +++ b/mediapipe/tasks/cc/text/tokenizers/sentencepiece_tokenizer_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/tokenizers/sentencepiece_tokenizer.h" +#include "mediapipe/tasks/cc/text/tokenizers/sentencepiece_tokenizer.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" @@ -21,7 +21,8 @@ limitations under the License. 
namespace mediapipe { namespace tasks { -namespace tokenizer { +namespace text { +namespace tokenizers { using ::mediapipe::tasks::core::LoadBinaryContent; using ::testing::ElementsAre; @@ -71,6 +72,7 @@ TEST(SentencePieceTokenizerTest, TestLookupWord) { } } -} // namespace tokenizer +} // namespace tokenizers +} // namespace text } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/tokenizers/tokenizer.h b/mediapipe/tasks/cc/text/tokenizers/tokenizer.h similarity index 84% rename from mediapipe/tasks/cc/components/tokenizers/tokenizer.h rename to mediapipe/tasks/cc/text/tokenizers/tokenizer.h index 107bdd5d3..ae984808e 100644 --- a/mediapipe/tasks/cc/components/tokenizers/tokenizer.h +++ b/mediapipe/tasks/cc/text/tokenizers/tokenizer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_TOKENIZERS_TOKENIZER_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_TOKENIZERS_TOKENIZER_H_ +#ifndef MEDIAPIPE_TASKS_CC_TEXT_TOKENIZERS_TOKENIZER_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_TOKENIZERS_TOKENIZER_H_ #include #include @@ -24,7 +24,8 @@ limitations under the License. namespace mediapipe { namespace tasks { -namespace tokenizer { +namespace text { +namespace tokenizers { struct TokenizerResult { std::vector subwords; @@ -46,8 +47,9 @@ class Tokenizer { virtual ~Tokenizer() = default; }; -} // namespace tokenizer +} // namespace tokenizers +} // namespace text } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_TOKENIZERS_TOKENIZER_H_ +#endif // MEDIAPIPE_TASKS_CC_TEXT_TOKENIZERS_TOKENIZER_H_ diff --git a/mediapipe/tasks/cc/components/tokenizers/tokenizer_utils.cc b/mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.cc similarity index 95% rename from mediapipe/tasks/cc/components/tokenizers/tokenizer_utils.cc rename to mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.cc index 1553db2ee..839c0818e 100644 --- a/mediapipe/tasks/cc/components/tokenizers/tokenizer_utils.cc +++ b/mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/tokenizers/tokenizer_utils.h" +#include "mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h" #include #include @@ -26,13 +26,14 @@ limitations under the License. 
#include "flatbuffers/flatbuffers.h" #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/tokenizers/bert_tokenizer.h" -#include "mediapipe/tasks/cc/components/tokenizers/sentencepiece_tokenizer.h" +#include "mediapipe/tasks/cc/text/tokenizers/bert_tokenizer.h" +#include "mediapipe/tasks/cc/text/tokenizers/sentencepiece_tokenizer.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" namespace mediapipe { namespace tasks { -namespace tokenizer { +namespace text { +namespace tokenizers { using ::mediapipe::tasks::CreateStatusWithPayload; using ::mediapipe::tasks::MediaPipeTasksStatus; @@ -137,6 +138,7 @@ absl::StatusOr> CreateTokenizerFromProcessUnit( } } -} // namespace tokenizer +} // namespace tokenizers +} // namespace text } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/tokenizers/tokenizer_utils.h b/mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h similarity index 78% rename from mediapipe/tasks/cc/components/tokenizers/tokenizer_utils.h rename to mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h index f60edb27b..c6bea1418 100644 --- a/mediapipe/tasks/cc/components/tokenizers/tokenizer_utils.h +++ b/mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h @@ -13,20 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_TOKENIZERS_TOKENIZER_UTILS_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_TOKENIZERS_TOKENIZER_UTILS_H_ +#ifndef MEDIAPIPE_TASKS_CC_TEXT_TOKENIZERS_TOKENIZER_UTILS_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_TOKENIZERS_TOKENIZER_UTILS_H_ #include #include "absl/status/statusor.h" -#include "mediapipe/tasks/cc/components/tokenizers/regex_tokenizer.h" -#include "mediapipe/tasks/cc/components/tokenizers/tokenizer.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.h" +#include "mediapipe/tasks/cc/text/tokenizers/tokenizer.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" namespace mediapipe { namespace tasks { -namespace tokenizer { +namespace text { +namespace tokenizers { // Creates a RegexTokenizer by extracting vocab files from the metadata. absl::StatusOr> CreateRegexTokenizerFromOptions( @@ -38,8 +39,9 @@ absl::StatusOr> CreateTokenizerFromProcessUnit( const tflite::ProcessUnit* tokenizer_process_unit, const metadata::ModelMetadataExtractor* metadata_extractor); -} // namespace tokenizer +} // namespace tokenizers +} // namespace text } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_TOKENIZERS_TOKENIZER_UTILS_H_ +#endif // MEDIAPIPE_TASKS_CC_TEXT_TOKENIZERS_TOKENIZER_UTILS_H_ diff --git a/mediapipe/tasks/cc/components/tokenizers/tokenizer_utils_test.cc b/mediapipe/tasks/cc/text/tokenizers/tokenizer_utils_test.cc similarity index 91% rename from mediapipe/tasks/cc/components/tokenizers/tokenizer_utils_test.cc rename to mediapipe/tasks/cc/text/tokenizers/tokenizer_utils_test.cc index eae475f5a..337d5ec7d 100644 --- a/mediapipe/tasks/cc/components/tokenizers/tokenizer_utils_test.cc +++ b/mediapipe/tasks/cc/text/tokenizers/tokenizer_utils_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "mediapipe/tasks/cc/components/tokenizers/tokenizer_utils.h" +#include "mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h" #include #include @@ -23,22 +23,22 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/str_cat.h" -#include "flatbuffers/flatbuffers.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/tokenizers/bert_tokenizer.h" -#include "mediapipe/tasks/cc/components/tokenizers/regex_tokenizer.h" -#include "mediapipe/tasks/cc/components/tokenizers/sentencepiece_tokenizer.h" #include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "mediapipe/tasks/cc/text/tokenizers/bert_tokenizer.h" +#include "mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.h" +#include "mediapipe/tasks/cc/text/tokenizers/sentencepiece_tokenizer.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" namespace mediapipe { namespace tasks { -namespace tokenizer { +namespace text { +namespace tokenizers { using ::mediapipe::tasks::kMediaPipeTasksPayload; using ::mediapipe::tasks::MediaPipeTasksStatus; @@ -55,7 +55,7 @@ constexpr char kModelWithSentencePieceTokenizerPath[] = "albert_with_metadata.tflite"; constexpr char kModelWithRegexTokenizerPath[] = "mediapipe/tasks/testdata/text/" - "test_model_nl_classifier_with_regex_tokenizer.tflite"; + "test_model_text_classifier_with_regex_tokenizer.tflite"; template bool is_type(T* t) { @@ -121,6 +121,7 @@ TEST(TokenizerUtilsTest, TestCreateFailure) { MediaPipeTasksStatus::kMetadataInvalidTokenizerError)))); } -} // namespace tokenizer +} // namespace tokenizers +} // namespace text } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_detector/BUILD b/mediapipe/tasks/cc/vision/hand_detector/BUILD new file mode 100644 index 000000000..23cf5f72d --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_detector/BUILD @@ -0,0 +1,72 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = [ + "//mediapipe/tasks:internal", +]) + +licenses(["notice"]) + +cc_library( + name = "hand_detector_op_resolver", + srcs = ["hand_detector_op_resolver.cc"], + hdrs = ["hand_detector_op_resolver.h"], + deps = [ + "//mediapipe/util/tflite/operations:max_pool_argmax", + "//mediapipe/util/tflite/operations:max_unpooling", + "//mediapipe/util/tflite/operations:transpose_conv_bias", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], +) + +cc_library( + name = "hand_detector_graph", + srcs = ["hand_detector_graph.cc"], + deps = [ + "//mediapipe/calculators/core:clip_vector_size_calculator", + "//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto", + "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:tensors_to_detections_calculator", + "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto", + "//mediapipe/calculators/tflite:ssd_anchors_calculator", + "//mediapipe/calculators/tflite:ssd_anchors_calculator_cc_proto", + "//mediapipe/calculators/util:detection_label_id_to_text_calculator", + "//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto", + "//mediapipe/calculators/util:detection_letterbox_removal_calculator", + "//mediapipe/calculators/util:detections_to_rects_calculator", + "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto", + "//mediapipe/calculators/util:non_max_suppression_calculator", + "//mediapipe/calculators/util:non_max_suppression_calculator_cc_proto", + "//mediapipe/calculators/util:rect_transformation_calculator", + "//mediapipe/calculators/util:rect_transformation_calculator_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_options_cc_proto", + "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc new file mode 100644 index 000000000..7ead21bad --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc @@ -0,0 +1,320 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h" +#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" +#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h" +#include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h" +#include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h" +#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h" +#include "mediapipe/calculators/util/non_max_suppression_calculator.pb.h" +#include "mediapipe/calculators/util/rect_transformation_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.pb.h" +#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" + +namespace mediapipe { +namespace tasks { +namespace vision { + +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::vision::hand_detector::proto::HandDetectorOptions; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kDetectionsTag[] = "DETECTIONS"; +constexpr char kNormRectsTag[] = "NORM_RECTS"; + +struct HandDetectionOuts { + Source> palm_detections; + Source> hand_rects; +}; + +void ConfigureTensorsToDetectionsCalculator( + mediapipe::TensorsToDetectionsCalculatorOptions* options) { + // TODO use metadata to configure these fields. + options->set_num_classes(1); + options->set_num_boxes(2016); + options->set_num_coords(18); + options->set_box_coord_offset(0); + options->set_keypoint_coord_offset(4); + options->set_num_keypoints(7); + options->set_num_values_per_keypoint(2); + options->set_sigmoid_score(true); + options->set_score_clipping_thresh(100.0); + options->set_reverse_output_order(true); + options->set_min_score_thresh(0.5); + options->set_x_scale(192.0); + options->set_y_scale(192.0); + options->set_w_scale(192.0); + options->set_h_scale(192.0); +} + +void ConfigureNonMaxSuppressionCalculator( + mediapipe::NonMaxSuppressionCalculatorOptions* options) { + options->set_min_suppression_threshold(0.3); + options->set_overlap_type( + mediapipe::NonMaxSuppressionCalculatorOptions::INTERSECTION_OVER_UNION); + options->set_algorithm( + mediapipe::NonMaxSuppressionCalculatorOptions::WEIGHTED); + // TODO "return_empty_detections" was removed from 1P graph, + // consider setting it from metadata accordingly. + options->set_return_empty_detections(true); +} + +void ConfigureSsdAnchorsCalculator( + mediapipe::SsdAnchorsCalculatorOptions* options) { + // TODO config SSD anchors parameters from metadata. 
+ options->set_num_layers(4); + options->set_min_scale(0.1484375); + options->set_max_scale(0.75); + options->set_input_size_height(192); + options->set_input_size_width(192); + options->set_anchor_offset_x(0.5); + options->set_anchor_offset_y(0.5); + options->add_strides(8); + options->add_strides(16); + options->add_strides(16); + options->add_strides(16); + options->add_aspect_ratios(1.0); + options->set_fixed_anchor_size(true); +} + +void ConfigureDetectionsToRectsCalculator( + mediapipe::DetectionsToRectsCalculatorOptions* options) { + // Center of wrist. + options->set_rotation_vector_start_keypoint_index(0); + // MCP of middle finger. + options->set_rotation_vector_end_keypoint_index(2); + options->set_rotation_vector_target_angle(90); + options->set_output_zero_rect_for_empty_detections(true); +} + +void ConfigureRectTransformationCalculator( + mediapipe::RectTransformationCalculatorOptions* options) { + options->set_scale_x(2.6); + options->set_scale_y(2.6); + options->set_shift_y(-0.5); + options->set_square_long(true); +} + +} // namespace + +// A "mediapipe.tasks.vision.HandDetectorGraph" performs hand detection. The +// Hand Detection Graph is based on a palm detection model and scales the +// detected palm bounding box to enclose the whole hand. +// Accepts CPU input images and outputs palm detections and hand rects on CPU. +// +// Inputs: +// IMAGE - Image +// Image to perform detection on. +// +// Outputs: +// DETECTIONS - std::vector +// Detected palms with maximum `num_hands` specified in options. +// NORM_RECTS - std::vector +// Detected hand bounding boxes in normalized coordinates. +// +// Example: +// node { +// calculator: "mediapipe.tasks.vision.HandDetectorGraph" +// input_stream: "IMAGE:image" +// output_stream: "DETECTIONS:palm_detections" +// output_stream: "NORM_RECTS:hand_rects_from_palm_detections" +// options { +// [mediapipe.tasks.hand_detector.proto.HandDetectorOptions.ext] { +// base_options { +// model_asset { +// file_name: "palm_detection.tflite" +// } +// } +// min_detection_confidence: 0.5 +// num_hands: 2 +// } +// } +// } +// TODO Decouple detection part and rects part. +class HandDetectorGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + ASSIGN_OR_RETURN(const auto* model_resources, + CreateModelResources(sc)); + Graph graph; + ASSIGN_OR_RETURN(auto hand_detection_outs, + BuildHandDetectionSubgraph( + sc->Options(), *model_resources, + graph[Input(kImageTag)], graph)); + hand_detection_outs.palm_detections >> + graph[Output>(kDetectionsTag)]; + hand_detection_outs.hand_rects >> + graph[Output>(kNormRectsTag)]; + return graph.GetConfig(); + } + + private: + // Updates the graph to perform hand detection. Returns palm detections and + // the corresponding hand RoI rects. + // + // subgraph_options: the mediapipe tasks module HandDetectorOptions. + // model_resources: the ModelResources object initialized from a hand detection + // model file with model metadata. + // image_in: image stream to run hand detection on. + // graph: the mediapipe builder::Graph instance to be updated. + absl::StatusOr BuildHandDetectionSubgraph( + const HandDetectorOptions& subgraph_options, + const core::ModelResources& model_resources, Source image_in, + Graph& graph) { + // Adds the image preprocessing subgraph. The model expects the aspect + // ratio to be kept unchanged. 
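A quick sanity check on the hard-coded anchor values above (a hedged back-of-the-envelope based on the usual way SsdAnchorsCalculator expands strides into anchor grids, not on anything stated in this change): a 192x192 input with strides 8, 16, 16, 16 produces one 24x24 grid and three 12x12 grids, and with aspect_ratios 1.0 plus the default interpolated scale each cell gets 2 anchors, so

  2 * (24*24 + 3 * 12*12) = 2 * 1008 = 2016,

which matches set_num_boxes(2016) in ConfigureTensorsToDetectionsCalculator. Likewise, num_coords(18) decomposes as 4 box coordinates plus 7 keypoints * 2 values per keypoint.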
+ auto& preprocessing = + graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); + auto& image_to_tensor_options = + *preprocessing + .GetOptions() + .mutable_image_to_tensor_options(); + image_to_tensor_options.set_keep_aspect_ratio(true); + image_to_tensor_options.set_border_mode( + mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO); + MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + model_resources, + &preprocessing + .GetOptions())); + image_in >> preprocessing.In("IMAGE"); + auto preprocessed_tensors = preprocessing.Out("TENSORS"); + auto letterbox_padding = preprocessing.Out("LETTERBOX_PADDING"); + auto image_size = preprocessing.Out("IMAGE_SIZE"); + + // Adds SSD palm detection model. + auto& inference = AddInference( + model_resources, subgraph_options.base_options().acceleration(), graph); + preprocessed_tensors >> inference.In("TENSORS"); + auto model_output_tensors = inference.Out("TENSORS"); + + // Generates a single side packet containing a vector of SSD anchors. + auto& ssd_anchor = graph.AddNode("SsdAnchorsCalculator"); + ConfigureSsdAnchorsCalculator( + &ssd_anchor.GetOptions()); + auto anchors = ssd_anchor.SideOut(""); + + // Converts output tensors to Detections. + auto& tensors_to_detections = + graph.AddNode("TensorsToDetectionsCalculator"); + ConfigureTensorsToDetectionsCalculator( + &tensors_to_detections + .GetOptions()); + model_output_tensors >> tensors_to_detections.In("TENSORS"); + anchors >> tensors_to_detections.SideIn("ANCHORS"); + auto detections = tensors_to_detections.Out("DETECTIONS"); + + // Non maximum suppression removes redundant palm detections. + auto& non_maximum_suppression = + graph.AddNode("NonMaxSuppressionCalculator"); + ConfigureNonMaxSuppressionCalculator( + &non_maximum_suppression + .GetOptions()); + detections >> non_maximum_suppression.In(""); + auto nms_detections = non_maximum_suppression.Out(""); + + // Maps detection label IDs to the corresponding label text "Palm". + auto& detection_label_id_to_text = + graph.AddNode("DetectionLabelIdToTextCalculator"); + detection_label_id_to_text + .GetOptions() + .add_label("Palm"); + nms_detections >> detection_label_id_to_text.In(""); + auto detections_with_text = detection_label_id_to_text.Out(""); + + // Adjusts detection locations (already normalized to [0.f, 1.f]) on the + // letterboxed image (after image transformation with the FIT scale mode) to + // the corresponding locations on the same image with the letterbox removed + // (the input image to the graph before image transformation). + auto& detection_letterbox_removal = + graph.AddNode("DetectionLetterboxRemovalCalculator"); + detections_with_text >> detection_letterbox_removal.In("DETECTIONS"); + letterbox_padding >> detection_letterbox_removal.In("LETTERBOX_PADDING"); + auto palm_detections = + detection_letterbox_removal[Output>( + "DETECTIONS")]; + + // Converts each palm detection into a rectangle (normalized by image size) + // that encloses the palm and is rotated such that the line connecting + // center of the wrist and MCP of the middle finger is aligned with the + // Y-axis of the rectangle. 
+ auto& detections_to_rects = graph.AddNode("DetectionsToRectsCalculator"); + ConfigureDetectionsToRectsCalculator( + &detections_to_rects + .GetOptions()); + palm_detections >> detections_to_rects.In("DETECTIONS"); + image_size >> detections_to_rects.In("IMAGE_SIZE"); + auto palm_rects = detections_to_rects.Out("NORM_RECTS"); + + // Expands and shifts the rectangle that contains the palm so that it's + // likely to cover the entire hand. + auto& rect_transformation = graph.AddNode("RectTransformationCalculator"); + ConfigureRectTransformationCalculator( + &rect_transformation + .GetOptions()); + palm_rects >> rect_transformation.In("NORM_RECTS"); + image_size >> rect_transformation.In("IMAGE_SIZE"); + auto hand_rects = rect_transformation.Out(""); + + // Clips the size of the input vector to the provided max_vec_size. This + // determines the maximum number of hand instances this graph outputs. + // Note that the performance gain of clipping detections earlier in this + // graph is minimal because NMS will minimize overlapping detections and the + // number of detections isn't expected to exceed 5-10. + auto& clip_normalized_rect_vector_size = + graph.AddNode("ClipNormalizedRectVectorSizeCalculator"); + clip_normalized_rect_vector_size + .GetOptions() + .set_max_vec_size(subgraph_options.num_hands()); + hand_rects >> clip_normalized_rect_vector_size.In(""); + auto clipped_hand_rects = + clip_normalized_rect_vector_size[Output>( + "")]; + + return HandDetectionOuts{.palm_detections = palm_detections, + .hand_rects = clipped_hand_rects}; + } +}; + +REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::HandDetectorGraph); + +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc new file mode 100644 index 000000000..a2fbd7c54 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc @@ -0,0 +1,205 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h" +#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result.pb.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace { + +using ::file::Defaults; +using ::file::GetTextProto; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::core::ModelResources; +using ::mediapipe::tasks::core::TaskRunner; +using ::mediapipe::tasks::core::proto::ExternalFile; +using ::mediapipe::tasks::vision::DecodeImageFromFile; +using ::mediapipe::tasks::vision::hand_detector::proto::HandDetectorOptions; +using ::mediapipe::tasks::vision::hand_detector::proto::HandDetectorResult; +using ::testing::EqualsProto; +using ::testing::TestParamInfo; +using ::testing::TestWithParam; +using ::testing::UnorderedPointwise; +using ::testing::Values; +using ::testing::proto::Approximately; +using ::testing::proto::Partially; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kPalmDetectionModel[] = "palm_detection_full.tflite"; +constexpr char kTestRightHandsImage[] = "right_hands.jpg"; +constexpr char kTestModelResourcesTag[] = "test_model_resources"; + +constexpr char kOneHandResultFile[] = "hand_detector_result_one_hand.pbtxt"; +constexpr char kTwoHandsResultFile[] = "hand_detector_result_two_hands.pbtxt"; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageName[] = "image"; +constexpr char kPalmDetectionsTag[] = "DETECTIONS"; +constexpr char kPalmDetectionsName[] = "palm_detections"; +constexpr char kHandNormRectsTag[] = "NORM_RECTS"; +constexpr char kHandNormRectsName[] = "hand_norm_rects"; + +constexpr float kPalmDetectionBboxMaxDiff = 0.01; +constexpr float kHandRectMaxDiff = 0.02; + +// Helper function to get ModelResources. +absl::StatusOr> CreateModelResourcesForModel( + absl::string_view model_name) { + auto external_file = std::make_unique(); + external_file->set_file_name(JoinPath("./", kTestDataDirectory, model_name)); + return ModelResources::Create(kTestModelResourcesTag, + std::move(external_file)); +} + +// Helper function to create a TaskRunner from ModelResources. 
+absl::StatusOr> CreateTaskRunner( + const ModelResources& model_resources, absl::string_view model_name, + int num_hands) { + Graph graph; + + auto& hand_detection = + graph.AddNode("mediapipe.tasks.vision.HandDetectorGraph"); + + auto options = std::make_unique(); + options->mutable_base_options()->mutable_model_asset()->set_file_name( + JoinPath("./", kTestDataDirectory, model_name)); + options->set_min_detection_confidence(0.5); + options->set_num_hands(num_hands); + hand_detection.GetOptions().Swap(options.get()); + + graph[Input(kImageTag)].SetName(kImageName) >> + hand_detection.In(kImageTag); + + hand_detection.Out(kPalmDetectionsTag).SetName(kPalmDetectionsName) >> + graph[Output>(kPalmDetectionsTag)]; + hand_detection.Out(kHandNormRectsTag).SetName(kHandNormRectsName) >> + graph[Output>(kHandNormRectsTag)]; + + return TaskRunner::Create(graph.GetConfig(), + absl::make_unique()); +} + +HandDetectorResult GetExpectedHandDetectorResult(absl::string_view file_name) { + HandDetectorResult result; + CHECK_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, file_name), + &result, Defaults())) + << "Expected hand detector result does not exist."; + return result; +} + +struct TestParams { + // The name of this test, for convenience when displaying test results. + std::string test_name; + // The filename of hand landmark detection model. + std::string hand_detection_model_name; + // The filename of test image. + std::string test_image_name; + // The number of maximum detected hands. + int num_hands; + // The expected hand detector result. + HandDetectorResult expected_result; +}; + +class HandDetectionTest : public testing::TestWithParam {}; + +TEST_P(HandDetectionTest, DetectTwoHands) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + GetParam().test_image_name))); + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(GetParam().hand_detection_model_name)); + MP_ASSERT_OK_AND_ASSIGN( + auto task_runner, CreateTaskRunner(*model_resources, kPalmDetectionModel, + GetParam().num_hands)); + auto output_packets = + task_runner->Process({{kImageName, MakePacket(std::move(image))}}); + MP_ASSERT_OK(output_packets); + const std::vector& palm_detections = + (*output_packets)[kPalmDetectionsName].Get>(); + const std::vector expected_palm_detections( + GetParam().expected_result.detections().begin(), + GetParam().expected_result.detections().end()); + EXPECT_THAT(palm_detections, + UnorderedPointwise(Approximately(Partially(EqualsProto()), + kPalmDetectionBboxMaxDiff), + expected_palm_detections)); + const std::vector& hand_rects = + (*output_packets)[kHandNormRectsName].Get>(); + const std::vector expected_hand_rects( + GetParam().expected_result.hand_rects().begin(), + GetParam().expected_result.hand_rects().end()); + EXPECT_THAT(hand_rects, + UnorderedPointwise( + Approximately(Partially(EqualsProto()), kHandRectMaxDiff), + expected_hand_rects)); +} + +INSTANTIATE_TEST_SUITE_P( + HandDetectionTest, HandDetectionTest, + Values(TestParams{.test_name = "DetectOneHand", + .hand_detection_model_name = kPalmDetectionModel, + .test_image_name = kTestRightHandsImage, + .num_hands = 1, + .expected_result = + GetExpectedHandDetectorResult(kOneHandResultFile)}, + TestParams{.test_name = "DetectTwoHands", + .hand_detection_model_name = kPalmDetectionModel, + .test_image_name = kTestRightHandsImage, + .num_hands = 2, + .expected_result = + GetExpectedHandDetectorResult(kTwoHandsResultFile)}), + [](const TestParamInfo& info) { + 
return info.param.test_name; + }); + +} // namespace +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.cc new file mode 100644 index 000000000..262fb2c75 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.cc @@ -0,0 +1,35 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h" + +#include "mediapipe/util/tflite/operations/max_pool_argmax.h" +#include "mediapipe/util/tflite/operations/max_unpooling.h" +#include "mediapipe/util/tflite/operations/transpose_conv_bias.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +HandDetectorOpResolver::HandDetectorOpResolver() { + AddCustom("MaxPoolingWithArgmax2D", + mediapipe::tflite_operations::RegisterMaxPoolingWithArgmax2D()); + AddCustom("MaxUnpooling2D", + mediapipe::tflite_operations::RegisterMaxUnpooling2D()); + AddCustom("Convolution2DTransposeBias", + mediapipe::tflite_operations::RegisterConvolution2DTransposeBias()); +} +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h new file mode 100644 index 000000000..a55661fa3 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h @@ -0,0 +1,34 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_ +#define MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_ + +#include "tensorflow/lite/kernels/register.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +class HandDetectorOpResolver : public tflite::ops::builtin::BuiltinOpResolver { + public: + HandDetectorOpResolver(); + HandDetectorOpResolver(const HandDetectorOpResolver& r) = delete; +}; + +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_ diff --git a/mediapipe/tasks/cc/vision/hand_detector/proto/BUILD b/mediapipe/tasks/cc/vision/hand_detector/proto/BUILD new file mode 100644 index 000000000..2d22aab10 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_detector/proto/BUILD @@ -0,0 +1,40 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +package(default_visibility = [ + "//mediapipe/tasks:internal", +]) + +licenses(["notice"]) + +mediapipe_proto_library( + name = "hand_detector_options_proto", + srcs = ["hand_detector_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + ], +) + +mediapipe_proto_library( + name = "hand_detector_result_proto", + srcs = ["hand_detector_result.proto"], + deps = [ + "//mediapipe/framework/formats:detection_proto", + "//mediapipe/framework/formats:rect_proto", + ], +) diff --git a/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.proto b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.proto new file mode 100644 index 000000000..ae22c7991 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.proto @@ -0,0 +1,44 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.vision.hand_detector.proto; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/core/proto/base_options.proto"; + +option java_package = "com.google.mediapipe.tasks.vision.handdetector"; +option java_outer_classname = "HandDetectorOptionsProto"; + +message HandDetectorOptions { + extend mediapipe.CalculatorOptions { + optional HandDetectorOptions ext = 464864288; + } + // Base options for configuring the Task library, such as specifying the + // TfLite model file with metadata, accelerator options, etc. + optional core.proto.BaseOptions base_options = 1; + + // The locale to use for display names specified through the TFLite Model + // Metadata, if any. Defaults to English. + optional string display_names_locale = 2 [default = "en"]; + + // Minimum confidence score ([0.0, 1.0]) required for a hand to be considered + // successfully detected in the image. + optional float min_detection_confidence = 3 [default = 0.5]; + + // The maximum number of hands output by the detector. + optional int32 num_hands = 4; +} diff --git a/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result.proto b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result.proto new file mode 100644 index 000000000..00c179ca9 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result.proto @@ -0,0 +1,30 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
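A compact sketch of populating these options from C++, mirroring the accessors the hand detector graph test below uses and assuming standard proto2 C++ codegen for the message above; the model file name is illustrative.

#include <memory>

#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.pb.h"

std::unique_ptr<mediapipe::tasks::vision::hand_detector::proto::HandDetectorOptions>
MakeHandDetectorOptions() {
  auto options = std::make_unique<
      mediapipe::tasks::vision::hand_detector::proto::HandDetectorOptions>();
  // Path to the palm detection model with embedded metadata.
  options->mutable_base_options()->mutable_model_asset()->set_file_name(
      "palm_detection_full.tflite");
  // Reject palm detections scored below 0.5 and report at most two hands.
  options->set_min_detection_confidence(0.5);
  options->set_num_hands(2);
  return options;
}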
+==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.vision.hand_detector.proto; + +import "mediapipe/framework/formats/detection.proto"; +import "mediapipe/framework/formats/rect.proto"; + +message HandDetectorResult { + repeated mediapipe.Detection detections = 1; + repeated mediapipe.NormalizedRect hand_rects = 2; +} + +message HandDetectorResults { + repeated HandDetectorResult results = 1; +} diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD index 511d82d17..bb5b86212 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD +++ b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD @@ -62,10 +62,11 @@ cc_library( "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators:hand_landmarks_to_matrix_calculator", "//mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators:handedness_to_matrix_calculator", + "//mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators:landmarks_to_matrix_calculator", "//mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto:hand_gesture_recognizer_subgraph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmark:hand_landmark_detector_graph", + "//mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto:landmarks_to_matrix_calculator_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarker_subgraph", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", "//mediapipe/tasks/metadata:metadata_schema_cc", "@com_google_absl//absl/status", diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/BUILD b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/BUILD index ea4acb01c..4863c8682 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/BUILD +++ b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/BUILD @@ -50,14 +50,14 @@ cc_test( ) cc_library( - name = "hand_landmarks_to_matrix_calculator", - srcs = ["hand_landmarks_to_matrix_calculator.cc"], + name = "landmarks_to_matrix_calculator", + srcs = ["landmarks_to_matrix_calculator.cc"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:ret_check", - "@com_google_absl//absl/memory", + "//mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto:landmarks_to_matrix_calculator_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -66,16 +66,16 @@ cc_library( ) cc_test( - name = "hand_landmarks_to_matrix_calculator_test", - srcs = ["hand_landmarks_to_matrix_calculator_test.cc"], + name = "landmarks_to_matrix_calculator_test", + srcs = ["landmarks_to_matrix_calculator_test.cc"], deps = [ - ":hand_landmarks_to_matrix_calculator", + ":landmarks_to_matrix_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", - "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/hand_landmarks_to_matrix_calculator_test.cc 
b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/hand_landmarks_to_matrix_calculator_test.cc deleted file mode 100644 index f8d1b5116..000000000 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/hand_landmarks_to_matrix_calculator_test.cc +++ /dev/null @@ -1,163 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include - -#include "absl/memory/memory.h" -#include "mediapipe/framework/calculator_framework.h" -#include "mediapipe/framework/calculator_runner.h" -#include "mediapipe/framework/formats/landmark.pb.h" -#include "mediapipe/framework/formats/matrix.h" -#include "mediapipe/framework/port/gtest.h" -#include "mediapipe/framework/port/parse_text_proto.h" -#include "mediapipe/framework/port/status_matchers.h" - -namespace mediapipe { -namespace tasks { -namespace vision { - -namespace { - -constexpr char kHandLandmarksTag[] = "HAND_LANDMARKS"; -constexpr char kHandWorldLandmarksTag[] = "HAND_WORLD_LANDMARKS"; -constexpr char kImageSizeTag[] = "IMAGE_SIZE"; -constexpr char kLandmarksMatrixTag[] = "LANDMARKS_MATRIX"; -constexpr char kNumHandLandmarks = 21; - -template -LandmarkListT BuildPseudoHandLandmarks(int offset = 0) { - LandmarkListT landmarks; - for (int i = 0; i < kNumHandLandmarks; ++i) { - auto* landmark = landmarks.add_landmark(); - landmark->set_x((offset + i) * 0.01 + 0.001); - landmark->set_y((offset + i) * 0.01 + 0.002); - landmark->set_z((offset + i) * 0.01 + 0.003); - } - return landmarks; -} - -struct HandLandmarks2dToMatrixCalculatorTestCase { - std::string test_name; - int hand_offset; -}; - -using HandLandmarks2dToMatrixCalculatorTest = - testing::TestWithParam; - -TEST_P(HandLandmarks2dToMatrixCalculatorTest, OutputsCorrectResult) { - const HandLandmarks2dToMatrixCalculatorTestCase& test_case = GetParam(); - - auto node_config = ParseTextProtoOrDie( - R"pb( - calculator: "HandLandmarksToMatrixCalculator" - input_stream: "HAND_LANDMARKS:hand_landmarks" - input_stream: "IMAGE_SIZE:image_size" - output_stream: "LANDMARKS_MATRIX:landmarks_matrix" - )pb"); - CalculatorRunner runner(node_config); - - auto hand_landmarks = std::make_unique(); - *hand_landmarks = - BuildPseudoHandLandmarks(test_case.hand_offset); - - runner.MutableInputs() - ->Tag(kHandLandmarksTag) - .packets.push_back(Adopt(hand_landmarks.release()).At(Timestamp(0))); - auto image_size = std::make_unique>(640, 480); - runner.MutableInputs() - ->Tag(kImageSizeTag) - .packets.push_back(Adopt(image_size.release()).At(Timestamp(0))); - - MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; - - const auto hand = - runner.Outputs().Tag(kLandmarksMatrixTag).packets[0].Get(); - ASSERT_EQ(21, hand.cols()); - ASSERT_EQ(3, hand.rows()); - EXPECT_NEAR(hand(0, 2), 0.1f, 0.001f); - EXPECT_NEAR(hand(1, 5), 0.1875f, 0.001f); -} - -INSTANTIATE_TEST_CASE_P( - HandLandmarksToMatrixCalculatorTests, 
HandLandmarks2dToMatrixCalculatorTest, - testing::ValuesIn( - {{.test_name = "TestWithHandOffset0", .hand_offset = 0}, - {.test_name = "TestWithHandOffset21", .hand_offset = 21}}), - [](const testing::TestParamInfo< - HandLandmarks2dToMatrixCalculatorTest::ParamType>& info) { - return info.param.test_name; - }); - -struct HandLandmarksWorld3dToMatrixCalculatorTestCase { - std::string test_name; - int hand_offset; -}; - -using HandLandmarksWorld3dToMatrixCalculatorTest = - testing::TestWithParam; - -TEST_P(HandLandmarksWorld3dToMatrixCalculatorTest, OutputsCorrectResult) { - const HandLandmarksWorld3dToMatrixCalculatorTestCase& test_case = GetParam(); - - auto node_config = ParseTextProtoOrDie( - R"pb( - calculator: "HandLandmarksToMatrixCalculator" - input_stream: "HAND_WORLD_LANDMARKS:hand_landmarks" - input_stream: "IMAGE_SIZE:image_size" - output_stream: "LANDMARKS_MATRIX:landmarks_matrix" - )pb"); - CalculatorRunner runner(node_config); - - auto hand_landmarks = std::make_unique(); - *hand_landmarks = - BuildPseudoHandLandmarks(test_case.hand_offset); - - runner.MutableInputs() - ->Tag(kHandWorldLandmarksTag) - .packets.push_back(Adopt(hand_landmarks.release()).At(Timestamp(0))); - auto image_size = std::make_unique>(640, 480); - runner.MutableInputs() - ->Tag(kImageSizeTag) - .packets.push_back(Adopt(image_size.release()).At(Timestamp(0))); - - MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; - - const auto hand = - runner.Outputs().Tag(kLandmarksMatrixTag).packets[0].Get(); - ASSERT_EQ(21, hand.cols()); - ASSERT_EQ(3, hand.rows()); - EXPECT_NEAR(hand(0, 2), 0.1f, 0.001f); - EXPECT_NEAR(hand(1, 5), 0.25f, 0.001f); -} - -INSTANTIATE_TEST_CASE_P( - HandLandmarksToMatrixCalculatorTests, - HandLandmarksWorld3dToMatrixCalculatorTest, - testing::ValuesIn( - {{.test_name = "TestWithHandOffset0", .hand_offset = 0}, - {.test_name = "TestWithHandOffset21", .hand_offset = 21}}), - [](const testing::TestParamInfo< - HandLandmarksWorld3dToMatrixCalculatorTest::ParamType>& info) { - return info.param.test_name; - }); - -} // namespace - -} // namespace vision -} // namespace tasks -} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/hand_landmarks_to_matrix_calculator.cc b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc similarity index 57% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/hand_landmarks_to_matrix_calculator.cc rename to mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc index 20add83cf..990e99920 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/hand_landmarks_to_matrix_calculator.cc +++ b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc @@ -14,12 +14,12 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include +#include #include -#include -#include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -27,15 +27,18 @@ limitations under the License. 
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.pb.h" namespace mediapipe { namespace tasks { namespace vision { +using proto::LandmarksToMatrixCalculatorOptions; + namespace { -constexpr char kHandLandmarksTag[] = "HAND_LANDMARKS"; -constexpr char kHandWorldLandmarksTag[] = "HAND_WORLD_LANDMARKS"; +constexpr char kLandmarksTag[] = "LANDMARKS"; +constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kLandmarksMatrixTag[] = "LANDMARKS_MATRIX"; constexpr int kFeaturesPerLandmark = 3; @@ -62,14 +65,14 @@ absl::StatusOr NormalizeLandmarkAspectRatio( } template -absl::StatusOr CanonicalizeOffsetAndScale( - const LandmarkListT& landmarks) { +absl::StatusOr NormalizeObject(const LandmarkListT& landmarks, + int origin_offset) { if (landmarks.landmark_size() == 0) { return ::absl::InvalidArgumentError( "Expected non-zero number of input landmarks."); } LandmarkListT canonicalized_landmarks; - const auto& wrist = landmarks.landmark(0); + const auto& origin = landmarks.landmark(origin_offset); float min_x = std::numeric_limits::max(); float max_x = std::numeric_limits::min(); float min_y = std::numeric_limits::max(); @@ -77,9 +80,9 @@ absl::StatusOr CanonicalizeOffsetAndScale( for (int i = 0; i < landmarks.landmark_size(); ++i) { const auto& old_landmark = landmarks.landmark(i); auto* new_landmark = canonicalized_landmarks.add_landmark(); - new_landmark->set_x(old_landmark.x() - wrist.x()); - new_landmark->set_y(old_landmark.y() - wrist.y()); - new_landmark->set_z(old_landmark.z() - wrist.z()); + new_landmark->set_x(old_landmark.x() - origin.x()); + new_landmark->set_y(old_landmark.y() - origin.y()); + new_landmark->set_z(old_landmark.z() - origin.z()); min_x = std::min(min_x, new_landmark->x()); max_x = std::max(max_x, new_landmark->x()); min_y = std::min(min_y, new_landmark->y()); @@ -107,23 +110,42 @@ Matrix LandmarksToMatrix(const LandmarkListT& landmarks) { return matrix; } -template -absl::Status ProcessLandmarks(LandmarkListT hand_landmarks, bool is_normalized, - CalculatorContext* cc) { - const bool normalize_wrt_aspect_ratio = - is_normalized && !cc->Inputs().Tag(kImageSizeTag).IsEmpty(); +template +struct DependentFalse : std::false_type {}; - if (normalize_wrt_aspect_ratio) { +template +bool IsNormalized() { + if constexpr (std::is_same_v) { + return true; + } else if constexpr (std::is_same_v) { + return false; + } else { + static_assert(DependentFalse::value, + "Type is not supported."); + } +} + +template +absl::Status ProcessLandmarks(LandmarkListT landmarks, CalculatorContext* cc) { + if (IsNormalized()) { + RET_CHECK(cc->Inputs().HasTag(kImageSizeTag) && + !cc->Inputs().Tag(kImageSizeTag).IsEmpty()); const auto [width, height] = cc->Inputs().Tag(kImageSizeTag).Get>(); - ASSIGN_OR_RETURN(hand_landmarks, NormalizeLandmarkAspectRatio( - hand_landmarks, width, height)); + ASSIGN_OR_RETURN(landmarks, + NormalizeLandmarkAspectRatio(landmarks, width, height)); + } + + const auto& options = cc->Options(); + if (options.object_normalization()) { + ASSIGN_OR_RETURN( + landmarks, + NormalizeObject(landmarks, + options.object_normalization_origin_offset())); } - ASSIGN_OR_RETURN(auto canonicalized_landmarks, - CanonicalizeOffsetAndScale(hand_landmarks)); auto landmarks_matrix = std::make_unique(); - *landmarks_matrix = 
LandmarksToMatrix(canonicalized_landmarks); + *landmarks_matrix = LandmarksToMatrix(landmarks); cc->Outputs() .Tag(kLandmarksMatrixTag) .Add(landmarks_matrix.release(), cc->InputTimestamp()); @@ -132,34 +154,38 @@ absl::Status ProcessLandmarks(LandmarkListT hand_landmarks, bool is_normalized, } // namespace -// Convert single hand landmarks into a matrix. The landmarks are normalized -// w.r.t. the image's aspect ratio and w.r.t the wrist. This pre-processing step -// is required for the hand gesture recognition model. +// Convert landmarks into a matrix. The landmarks are normalized +// w.r.t. the image's aspect ratio (if they are NormalizedLandmarksList) +// and optionally w.r.t and "origin" landmark. This pre-processing step +// is required for the some models. // // Input: -// HAND_LANDMARKS - Single hand landmarks. Use *either* HAND_LANDMARKS or -// HAND_WORLD_LANDMARKS. -// HAND_WORLD_LANDMARKS - Single hand world 3d landmarks. Use *either* -// HAND_LANDMARKS or HAND_WORLD_LANDMARKS. +// LANDMARKS - Landmarks of one object. Use *either* LANDMARKS or +// WORLD_LANDMARKS. +// WORLD_LANDMARKS - World 3d landmarks of one object. Use *either* +// LANDMARKS or WORLD_LANDMARKS. // IMAGE_SIZE - (width, height) of the image // Output: -// LANDMARKS_MATRIX - Matrix for hand landmarks. +// LANDMARKS_MATRIX - Matrix for the landmarks. // // Usage example: // node { -// calculator: "HandLandmarksToMatrixCalculator" -// input_stream: "HAND_LANDMARKS:hand_landmarks" +// calculator: "LandmarksToMatrixCalculator" +// input_stream: "LANDMARKS:hand_landmarks" // input_stream: "IMAGE_SIZE:image_size" // output_stream: "LANDMARKS_MATRIX:landmarks_matrix" +// options { +// [mediapipe.tasks.vision.proto.LandmarksToMatrixCalculatorOptions.ext] { +// object_normalization: true +// object_normalization_origin_offset: 0 +// } +// } // } -class HandLandmarksToMatrixCalculator : public CalculatorBase { +class LandmarksToMatrixCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { - cc->Inputs() - .Tag(kHandLandmarksTag) - .Set() - .Optional(); - cc->Inputs().Tag(kHandWorldLandmarksTag).Set().Optional(); + cc->Inputs().Tag(kLandmarksTag).Set().Optional(); + cc->Inputs().Tag(kWorldLandmarksTag).Set().Optional(); cc->Inputs().Tag(kImageSizeTag).Set>().Optional(); cc->Outputs().Tag(kLandmarksMatrixTag).Set(); return absl::OkStatus(); @@ -167,28 +193,29 @@ class HandLandmarksToMatrixCalculator : public CalculatorBase { absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - RET_CHECK(cc->Inputs().HasTag(kHandLandmarksTag) ^ - cc->Inputs().HasTag(kHandWorldLandmarksTag)); + RET_CHECK(cc->Inputs().HasTag(kLandmarksTag) ^ + cc->Inputs().HasTag(kWorldLandmarksTag)); + const auto& options = cc->Options(); + RET_CHECK(options.has_object_normalization()); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) override; }; -REGISTER_CALCULATOR(HandLandmarksToMatrixCalculator); +REGISTER_CALCULATOR(LandmarksToMatrixCalculator); -absl::Status HandLandmarksToMatrixCalculator::Process(CalculatorContext* cc) { - if (cc->Inputs().HasTag(kHandLandmarksTag)) { - if (!cc->Inputs().Tag(kHandLandmarksTag).IsEmpty()) { - auto hand_landmarks = - cc->Inputs().Tag(kHandLandmarksTag).Get(); - return ProcessLandmarks(hand_landmarks, /*is_normalized=*/true, cc); +absl::Status LandmarksToMatrixCalculator::Process(CalculatorContext* cc) { + if (cc->Inputs().HasTag(kLandmarksTag)) { + if (!cc->Inputs().Tag(kLandmarksTag).IsEmpty()) { + auto landmarks = 
+ cc->Inputs().Tag(kLandmarksTag).Get(); + return ProcessLandmarks(landmarks, cc); } - } else if (cc->Inputs().HasTag(kHandWorldLandmarksTag)) { - if (!cc->Inputs().Tag(kHandWorldLandmarksTag).IsEmpty()) { - auto hand_landmarks = - cc->Inputs().Tag(kHandWorldLandmarksTag).Get(); - return ProcessLandmarks(hand_landmarks, /*is_normalized=*/false, cc); + } else if (cc->Inputs().HasTag(kWorldLandmarksTag)) { + if (!cc->Inputs().Tag(kWorldLandmarksTag).IsEmpty()) { + auto landmarks = cc->Inputs().Tag(kWorldLandmarksTag).Get(); + return ProcessLandmarks(landmarks, cc); } } return absl::OkStatus(); diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc new file mode 100644 index 000000000..05d238f66 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc @@ -0,0 +1,207 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/strings/substitute.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace tasks { +namespace vision { + +namespace { + +constexpr char kLandmarksTag[] = "LANDMARKS"; +constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; +constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +constexpr char kLandmarksMatrixTag[] = "LANDMARKS_MATRIX"; + +template +LandmarkListT BuildPseudoLandmarks(int num_landmarks, int offset = 0) { + LandmarkListT landmarks; + for (int i = 0; i < num_landmarks; ++i) { + auto* landmark = landmarks.add_landmark(); + landmark->set_x((offset + i) * 0.01 + 0.001); + landmark->set_y((offset + i) * 0.01 + 0.002); + landmark->set_z((offset + i) * 0.01 + 0.003); + } + return landmarks; +} + +struct Landmarks2dToMatrixCalculatorTestCase { + std::string test_name; + int base_offset; + int object_normalization_origin_offset = -1; + float expected_cell_0_2; + float expected_cell_1_5; +}; + +using Landmarks2dToMatrixCalculatorTest = + testing::TestWithParam; + +TEST_P(Landmarks2dToMatrixCalculatorTest, OutputsCorrectResult) { + const Landmarks2dToMatrixCalculatorTestCase& test_case = GetParam(); + + auto node_config = + ParseTextProtoOrDie(absl::Substitute( + R"pb( + calculator: "LandmarksToMatrixCalculator" + input_stream: "LANDMARKS:landmarks" + input_stream: "IMAGE_SIZE:image_size" + output_stream: "LANDMARKS_MATRIX:landmarks_matrix" + options { + [mediapipe.tasks.vision.proto.LandmarksToMatrixCalculatorOptions + .ext] { + 
object_normalization: $0 + object_normalization_origin_offset: $1 + } + } + )pb", + test_case.object_normalization_origin_offset >= 0 ? "true" : "false", + test_case.object_normalization_origin_offset)); + CalculatorRunner runner(node_config); + + auto landmarks = std::make_unique(); + *landmarks = + BuildPseudoLandmarks(21, test_case.base_offset); + + runner.MutableInputs() + ->Tag(kLandmarksTag) + .packets.push_back(Adopt(landmarks.release()).At(Timestamp(0))); + auto image_size = std::make_unique>(640, 480); + runner.MutableInputs() + ->Tag(kImageSizeTag) + .packets.push_back(Adopt(image_size.release()).At(Timestamp(0))); + + MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + + const auto matrix = + runner.Outputs().Tag(kLandmarksMatrixTag).packets[0].Get(); + ASSERT_EQ(21, matrix.cols()); + ASSERT_EQ(3, matrix.rows()); + EXPECT_NEAR(matrix(0, 2), test_case.expected_cell_0_2, 1e-4f); + EXPECT_NEAR(matrix(1, 5), test_case.expected_cell_1_5, 1e-4f); +} + +INSTANTIATE_TEST_CASE_P( + LandmarksToMatrixCalculatorTests, Landmarks2dToMatrixCalculatorTest, + testing::ValuesIn( + {{.test_name = "TestWithOffset0", + .base_offset = 0, + .object_normalization_origin_offset = 0, + .expected_cell_0_2 = 0.1f, + .expected_cell_1_5 = 0.1875f}, + {.test_name = "TestWithOffset21", + .base_offset = 21, + .object_normalization_origin_offset = 0, + .expected_cell_0_2 = 0.1f, + .expected_cell_1_5 = 0.1875f}}), + [](const testing::TestParamInfo< + Landmarks2dToMatrixCalculatorTest::ParamType>& info) { + return info.param.test_name; + }); + +struct LandmarksWorld3dToMatrixCalculatorTestCase { + std::string test_name; + int base_offset; + int object_normalization_origin_offset = -1; + float expected_cell_0_2; + float expected_cell_1_5; +}; + +using LandmarksWorld3dToMatrixCalculatorTest = + testing::TestWithParam; + +TEST_P(LandmarksWorld3dToMatrixCalculatorTest, OutputsCorrectResult) { + const LandmarksWorld3dToMatrixCalculatorTestCase& test_case = GetParam(); + + auto node_config = + ParseTextProtoOrDie(absl::Substitute( + R"pb( + calculator: "LandmarksToMatrixCalculator" + input_stream: "WORLD_LANDMARKS:landmarks" + input_stream: "IMAGE_SIZE:image_size" + output_stream: "LANDMARKS_MATRIX:landmarks_matrix" + options { + [mediapipe.tasks.vision.proto.LandmarksToMatrixCalculatorOptions + .ext] { + object_normalization: $0 + object_normalization_origin_offset: $1 + } + } + )pb", + test_case.object_normalization_origin_offset >= 0 ? 
"true" : "false", + test_case.object_normalization_origin_offset)); + CalculatorRunner runner(node_config); + + auto landmarks = std::make_unique(); + *landmarks = BuildPseudoLandmarks(21, test_case.base_offset); + + runner.MutableInputs() + ->Tag(kWorldLandmarksTag) + .packets.push_back(Adopt(landmarks.release()).At(Timestamp(0))); + auto image_size = std::make_unique>(640, 480); + runner.MutableInputs() + ->Tag(kImageSizeTag) + .packets.push_back(Adopt(image_size.release()).At(Timestamp(0))); + + MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + + const auto matrix = + runner.Outputs().Tag(kLandmarksMatrixTag).packets[0].Get(); + ASSERT_EQ(21, matrix.cols()); + ASSERT_EQ(3, matrix.rows()); + EXPECT_NEAR(matrix(0, 2), test_case.expected_cell_0_2, 1e-4f); + EXPECT_NEAR(matrix(1, 5), test_case.expected_cell_1_5, 1e-4f); +} + +INSTANTIATE_TEST_CASE_P( + LandmarksToMatrixCalculatorTests, LandmarksWorld3dToMatrixCalculatorTest, + testing::ValuesIn( + {{.test_name = "TestWithOffset0", + .base_offset = 0, + .object_normalization_origin_offset = 0, + .expected_cell_0_2 = 0.1f, + .expected_cell_1_5 = 0.25}, + {.test_name = "TestWithOffset21", + .base_offset = 21, + .object_normalization_origin_offset = 0, + .expected_cell_0_2 = 0.1f, + .expected_cell_1_5 = 0.25}, + {.test_name = "NoObjectNormalization", + .base_offset = 0, + .object_normalization_origin_offset = -1, + .expected_cell_0_2 = 0.021f, + .expected_cell_1_5 = 0.052f}}), + [](const testing::TestParamInfo< + LandmarksWorld3dToMatrixCalculatorTest::ParamType>& info) { + return info.param.test_name; + }); + +} // namespace + +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc index 4bdf38da0..e124d3410 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc +++ b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc @@ -35,6 +35,7 @@ limitations under the License. 
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" @@ -50,6 +51,23 @@ using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::vision::hand_gesture_recognizer::proto:: HandGestureRecognizerSubgraphOptions; +using ::mediapipe::tasks::vision::proto::LandmarksToMatrixCalculatorOptions; + +constexpr char kHandednessTag[] = "HANDEDNESS"; +constexpr char kLandmarksTag[] = "LANDMARKS"; +constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; +constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS"; +constexpr char kHandGesturesTag[] = "HAND_GESTURES"; +constexpr char kLandmarksMatrixTag[] = "LANDMARKS_MATRIX"; +constexpr char kTensorsTag[] = "TENSORS"; +constexpr char kHandednessMatrixTag[] = "HANDEDNESS_MATRIX"; +constexpr char kCloneTag[] = "CLONE"; +constexpr char kItemTag[] = "ITEM"; +constexpr char kVectorTag[] = "VECTOR"; +constexpr char kIndexTag[] = "INDEX"; +constexpr char kIterableTag[] = "ITERABLE"; +constexpr char kBatchEndTag[] = "BATCH_END"; absl::Status SanityCheckOptions( const HandGestureRecognizerSubgraphOptions& options) { @@ -72,8 +90,8 @@ Source> ConvertMatrixToTensor(Source matrix, } // namespace -// A "mediapipe.tasks.vision.HandGestureRecognizerSubgraph" performs single hand -// gesture recognition. This graph is used as a building block for +// A "mediapipe.tasks.vision.SingleHandGestureRecognizerSubgraph" performs +// single hand gesture recognition. This graph is used as a building block for // mediapipe.tasks.vision.HandGestureRecognizerGraph. 
// // Inputs: @@ -94,7 +112,7 @@ Source> ConvertMatrixToTensor(Source matrix, // // Example: // node { -// calculator: "mediapipe.tasks.vision.HandGestureRecognizerSubgraph" +// calculator: "mediapipe.tasks.vision.SingleHandGestureRecognizerSubgraph" // input_stream: "HANDEDNESS:handedness" // input_stream: "LANDMARKS:landmarks" // input_stream: "WORLD_LANDMARKS:world_landmarks" @@ -104,12 +122,14 @@ Source> ConvertMatrixToTensor(Source matrix, // [mediapipe.tasks.vision.hand_gesture_recognizer.proto.HandGestureRecognizerSubgraphOptions.ext] // { // base_options { -// model_file: "hand_gesture.tflite" +// model_asset { +// file_name: "hand_gesture.tflite" +// } // } // } // } // } -class HandGestureRecognizerSubgraph : public core::ModelTaskGraph { +class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { @@ -121,11 +141,11 @@ class HandGestureRecognizerSubgraph : public core::ModelTaskGraph { auto hand_gestures, BuildHandGestureRecognizerGraph( sc->Options(), - *model_resources, graph[Input("HANDEDNESS")], - graph[Input("LANDMARKS")], - graph[Input("WORLD_LANDMARKS")], - graph[Input>("IMAGE_SIZE")], graph)); - hand_gestures >> graph[Output("HAND_GESTURES")]; + *model_resources, graph[Input(kHandednessTag)], + graph[Input(kLandmarksTag)], + graph[Input(kWorldLandmarksTag)], + graph[Input>(kImageSizeTag)], graph)); + hand_gestures >> graph[Output(kHandGesturesTag)]; return graph.GetConfig(); } @@ -141,21 +161,26 @@ class HandGestureRecognizerSubgraph : public core::ModelTaskGraph { // Converts the ClassificationList to a matrix. auto& handedness_to_matrix = graph.AddNode("HandednessToMatrixCalculator"); - handedness >> handedness_to_matrix.In("HANDEDNESS"); + handedness >> handedness_to_matrix.In(kHandednessTag); auto handedness_matrix = - handedness_to_matrix[Output("HANDEDNESS_MATRIX")]; + handedness_to_matrix[Output(kHandednessMatrixTag)]; // Converts the handedness matrix to a tensor for the inference // calculator. auto handedness_tensors = ConvertMatrixToTensor(handedness_matrix, graph); // Converts the screen landmarks to a matrix. + LandmarksToMatrixCalculatorOptions landmarks_options; + landmarks_options.set_object_normalization(true); + landmarks_options.set_object_normalization_origin_offset(0); auto& hand_landmarks_to_matrix = - graph.AddNode("HandLandmarksToMatrixCalculator"); - hand_landmarks >> hand_landmarks_to_matrix.In("HAND_LANDMARKS"); - image_size >> hand_landmarks_to_matrix.In("IMAGE_SIZE"); + graph.AddNode("LandmarksToMatrixCalculator"); + hand_landmarks_to_matrix.GetOptions() = + landmarks_options; + hand_landmarks >> hand_landmarks_to_matrix.In(kLandmarksTag); + image_size >> hand_landmarks_to_matrix.In(kImageSizeTag); auto hand_landmarks_matrix = - hand_landmarks_to_matrix[Output("LANDMARKS_MATRIX")]; + hand_landmarks_to_matrix[Output(kLandmarksMatrixTag)]; // Converts the landmarks matrix to a tensor for the inference calculator. auto hand_landmarks_tensor = @@ -163,12 +188,14 @@ class HandGestureRecognizerSubgraph : public core::ModelTaskGraph { // Converts the world landmarks to a matrix. 
auto& hand_world_landmarks_to_matrix = - graph.AddNode("HandLandmarksToMatrixCalculator"); + graph.AddNode("LandmarksToMatrixCalculator"); + hand_world_landmarks_to_matrix + .GetOptions() = landmarks_options; hand_world_landmarks >> - hand_world_landmarks_to_matrix.In("HAND_WORLD_LANDMARKS"); - image_size >> hand_world_landmarks_to_matrix.In("IMAGE_SIZE"); + hand_world_landmarks_to_matrix.In(kWorldLandmarksTag); + image_size >> hand_world_landmarks_to_matrix.In(kImageSizeTag); auto hand_world_landmarks_matrix = - hand_world_landmarks_to_matrix[Output("LANDMARKS_MATRIX")]; + hand_world_landmarks_to_matrix[Output(kLandmarksMatrixTag)]; // Converts the world landmarks matrix to a tensor for the inference // calculator. @@ -185,20 +212,151 @@ class HandGestureRecognizerSubgraph : public core::ModelTaskGraph { auto concatenated_tensors = concatenate_tensor_vector.Out(""); // Inference for static hand gesture recognition. - auto& inference = AddInference(model_resources, graph); - concatenated_tensors >> inference.In("TENSORS"); - auto inference_output_tensors = inference.Out("TENSORS"); + auto& inference = AddInference( + model_resources, graph_options.base_options().acceleration(), graph); + concatenated_tensors >> inference.In(kTensorsTag); + auto inference_output_tensors = inference.Out(kTensorsTag); - auto& postprocessing = - graph.AddNode("mediapipe.tasks.ClassificationPostprocessingSubgraph"); + auto& postprocessing = graph.AddNode( + "mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( model_resources, graph_options.classifier_options(), - &postprocessing.GetOptions())); - inference_output_tensors >> postprocessing.In("TENSORS"); + &postprocessing.GetOptions< + tasks::components::ClassificationPostprocessingOptions>())); + inference_output_tensors >> postprocessing.In(kTensorsTag); auto classification_result = postprocessing[Output("CLASSIFICATION_RESULT")]; - return {classification_result}; + return classification_result; + } +}; + +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::SingleHandGestureRecognizerSubgraph); + +// A "mediapipe.tasks.vision.HandGestureRecognizerSubgraph" performs multi +// hand gesture recognition. This graph is used as a building block for +// mediapipe.tasks.vision.HandGestureRecognizerGraph. +// +// Inputs: +// HANDEDNESS - std::vector +// A vector of Classification of handedness. +// LANDMARKS - std::vector +// A vector hand landmarks in normalized image coordinates. +// WORLD_LANDMARKS - std::vector +// A vector hand landmarks in world coordinates. +// IMAGE_SIZE - std::pair +// The size of image from which the landmarks detected from. +// HAND_TRACKING_IDS - std::vector +// A vector of the tracking ids of the hands. The tracking id is the vector +// index corresponding to the same hand if the graph runs multiple times. +// +// Outputs: +// HAND_GESTURES - std::vector +// A vector of recognized hand gestures. Each vector element is the +// ClassificationResult of the hand in input vector. 
+// +// +// Example: +// node { +// calculator: "mediapipe.tasks.vision.HandGestureRecognizerSubgraph" +// input_stream: "HANDEDNESS:handedness" +// input_stream: "LANDMARKS:landmarks" +// input_stream: "WORLD_LANDMARKS:world_landmarks" +// input_stream: "IMAGE_SIZE:image_size" +// input_stream: "HAND_TRACKING_IDS:hand_tracking_ids" +// output_stream: "HAND_GESTURES:hand_gestures" +// options { +// [mediapipe.tasks.vision.hand_gesture_recognizer.proto.HandGestureRecognizerSubgraph.ext] +// { +// base_options { +// model_asset { +// file_name: "hand_gesture.tflite" +// } +// } +// } +// } +// } +class HandGestureRecognizerSubgraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + Graph graph; + ASSIGN_OR_RETURN( + auto multi_hand_gestures, + BuildMultiHandGestureRecognizerSubraph( + sc->Options(), + graph[Input>(kHandednessTag)], + graph[Input>(kLandmarksTag)], + graph[Input>(kWorldLandmarksTag)], + graph[Input>(kImageSizeTag)], + graph[Input>(kHandTrackingIdsTag)], graph)); + multi_hand_gestures >> + graph[Output>(kHandGesturesTag)]; + return graph.GetConfig(); + } + + private: + absl::StatusOr>> + BuildMultiHandGestureRecognizerSubraph( + const HandGestureRecognizerSubgraphOptions& graph_options, + Source> multi_handedness, + Source> multi_hand_landmarks, + Source> multi_hand_world_landmarks, + Source> image_size, + Source> multi_hand_tracking_ids, Graph& graph) { + auto& begin_loop_int = graph.AddNode("BeginLoopIntCalculator"); + image_size >> begin_loop_int.In(kCloneTag)[0]; + multi_handedness >> begin_loop_int.In(kCloneTag)[1]; + multi_hand_landmarks >> begin_loop_int.In(kCloneTag)[2]; + multi_hand_world_landmarks >> begin_loop_int.In(kCloneTag)[3]; + multi_hand_tracking_ids >> begin_loop_int.In(kIterableTag); + auto image_size_clone = begin_loop_int.Out(kCloneTag)[0]; + auto multi_handedness_clone = begin_loop_int.Out(kCloneTag)[1]; + auto multi_hand_landmarks_clone = begin_loop_int.Out(kCloneTag)[2]; + auto multi_hand_world_landmarks_clone = begin_loop_int.Out(kCloneTag)[3]; + auto hand_tracking_id = begin_loop_int.Out(kItemTag); + auto batch_end = begin_loop_int.Out(kBatchEndTag); + + auto& get_handedness_at_index = + graph.AddNode("GetClassificationListVectorItemCalculator"); + multi_handedness_clone >> get_handedness_at_index.In(kVectorTag); + hand_tracking_id >> get_handedness_at_index.In(kIndexTag); + auto handedness = get_handedness_at_index.Out(kItemTag); + + auto& get_landmarks_at_index = + graph.AddNode("GetNormalizedLandmarkListVectorItemCalculator"); + multi_hand_landmarks_clone >> get_landmarks_at_index.In(kVectorTag); + hand_tracking_id >> get_landmarks_at_index.In(kIndexTag); + auto hand_landmarks = get_landmarks_at_index.Out(kItemTag); + + auto& get_world_landmarks_at_index = + graph.AddNode("GetLandmarkListVectorItemCalculator"); + multi_hand_world_landmarks_clone >> + get_world_landmarks_at_index.In(kVectorTag); + hand_tracking_id >> get_world_landmarks_at_index.In(kIndexTag); + auto hand_world_landmarks = get_world_landmarks_at_index.Out(kItemTag); + + auto& hand_gesture_recognizer_subgraph = graph.AddNode( + "mediapipe.tasks.vision.SingleHandGestureRecognizerSubgraph"); + hand_gesture_recognizer_subgraph + .GetOptions() + .CopyFrom(graph_options); + handedness >> hand_gesture_recognizer_subgraph.In(kHandednessTag); + hand_landmarks >> hand_gesture_recognizer_subgraph.In(kLandmarksTag); + hand_world_landmarks >> + hand_gesture_recognizer_subgraph.In(kWorldLandmarksTag); + image_size_clone >> 
hand_gesture_recognizer_subgraph.In(kImageSizeTag); + auto hand_gestures = hand_gesture_recognizer_subgraph.Out(kHandGesturesTag); + + auto& end_loop_classification_results = + graph.AddNode("mediapipe.tasks.EndLoopClassificationResultCalculator"); + batch_end >> end_loop_classification_results.In(kBatchEndTag); + hand_gestures >> end_loop_classification_results.In(kItemTag); + auto multi_hand_gestures = end_loop_classification_results + [Output>(kIterableTag)]; + + return multi_hand_gestures; } }; diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD index 47e220ac8..f3927727e 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD +++ b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD @@ -14,7 +14,9 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = [ + "//mediapipe/tasks:internal", +]) licenses(["notice"]) @@ -24,7 +26,18 @@ mediapipe_proto_library( deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components:classifier_options_proto", + "//mediapipe/tasks/cc/components/proto:classifier_options_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + ], +) + +mediapipe_proto_library( + name = "landmarks_to_matrix_calculator_proto", + srcs = ["landmarks_to_matrix_calculator.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/components/proto:classifier_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto index 42f2bbc85..f73443eaf 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto @@ -18,7 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.hand_gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/classifier_options.proto"; +import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; message HandGestureRecognizerSubgraphOptions { @@ -31,7 +31,7 @@ message HandGestureRecognizerSubgraphOptions { // Options for configuring the gesture classifier behavior, such as score // threshold, number of results, etc. - optional ClassifierOptions classifier_options = 2; + optional components.proto.ClassifierOptions classifier_options = 2; // Minimum confidence value ([0.0, 1.0]) for the hand landmarks to be // considered tracked successfully diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.proto b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.proto new file mode 100644 index 000000000..6b004e203 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.proto @@ -0,0 +1,39 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.vision.proto; + +import "mediapipe/framework/calculator.proto"; + +// Options for LandmarksToMatrixCalculator. +message LandmarksToMatrixCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional LandmarksToMatrixCalculatorOptions ext = 473345607; + } + + // Determines whether to perform object normalization. If enabled, the + // normalizes the object so that: + // - max(height, width) of the object is 1 + // - the aspect ratio is preserved + // - the landmark at offset object_normalization_origin_offset within + // the landmarks array is at the origin. + // It is required to set object_normalization to true or false. + optional bool object_normalization = 1; + // The offset within the landmarks list of the landmark to use as origin + // for object normalization. + optional int32 object_normalization_origin_offset = 2 [default = 0]; +} diff --git a/mediapipe/tasks/cc/vision/hand_landmark/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD similarity index 85% rename from mediapipe/tasks/cc/vision/hand_landmark/BUILD rename to mediapipe/tasks/cc/vision/hand_landmarker/BUILD index 5b490124a..653976b96 100644 --- a/mediapipe/tasks/cc/vision/hand_landmark/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -12,29 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") - package(default_visibility = [ "//mediapipe/tasks:internal", ]) licenses(["notice"]) -mediapipe_proto_library( - name = "hand_landmark_detector_options_proto", - srcs = ["hand_landmark_detector_options.proto"], - deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/core/proto:base_options_proto", - ], -) - cc_library( - name = "hand_landmark_detector_graph", - srcs = ["hand_landmark_detector_graph.cc"], + name = "hand_landmarker_subgraph", + srcs = ["hand_landmarker_subgraph.cc"], deps = [ - ":hand_landmark_detector_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_subgraph_options_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "//mediapipe/calculators/core:split_vector_calculator", diff --git a/mediapipe/tasks/cc/vision/hand_landmark/hand_landmark_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph.cc similarity index 63% rename from mediapipe/tasks/cc/vision/hand_landmark/hand_landmark_detector_graph.cc rename to mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph.cc index c5677cd98..fff4ae0d4 100644 --- a/mediapipe/tasks/cc/vision/hand_landmark/hand_landmark_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph.cc @@ -39,7 +39,7 @@ limitations under the License. 
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" -#include "mediapipe/tasks/cc/vision/hand_landmark/hand_landmark_detector_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/util/label_map.pb.h" @@ -56,6 +56,8 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::core::ModelResources; +using ::mediapipe::tasks::vision::hand_landmarker::proto:: + HandLandmarkerSubgraphOptions; using LabelItems = mediapipe::proto_ns::Map; constexpr char kImageTag[] = "IMAGE"; @@ -73,7 +75,7 @@ constexpr int kLandmarksNum = 21; constexpr float kLandmarksNormalizeZ = 0.4; constexpr int kModelOutputTensorSplitNum = 4; -struct HandLandmarkDetectionOuts { +struct SingleHandLandmarkerOutputs { Source hand_landmarks; Source world_hand_landmarks; Source hand_rect_next_frame; @@ -83,7 +85,17 @@ struct HandLandmarkDetectionOuts { Source> image_size; }; -absl::Status SanityCheckOptions(const HandLandmarkDetectorOptions& options) { +struct HandLandmarkerOutputs { + Source> landmark_lists; + Source> world_landmark_lists; + Source> hand_rects_next_frame; + Source> presences; + Source> presence_scores; + Source> handednesses; + Source> image_size; +}; + +absl::Status SanityCheckOptions(const HandLandmarkerSubgraphOptions& options) { if (options.min_detection_confidence() < 0 || options.min_detection_confidence() > 1) { return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, @@ -170,8 +182,8 @@ void ConfigureHandRectTransformationCalculator( } // namespace -// A "mediapipe.tasks.vision.HandLandmarkDetectorGraph" performs hand landmark -// detection. +// A "mediapipe.tasks.vision.SingleHandLandmarkerSubgraph" performs hand +// landmark detection. // - Accepts CPU input images and outputs Landmark on CPU. // // Inputs: @@ -196,11 +208,13 @@ void ConfigureHandRectTransformationCalculator( // Float value indicates the probability that the hand is present. // HANDEDNESS - ClassificationList // Classification of handedness. +// IMAGE_SIZE - std::vector +// The size of input image. 
// // Example: // node { -// calculator: "mediapipe.tasks.vision.HandLandmarkDetectorGraph" -// input_stream: "IMAGE:input_video" +// calculator: "mediapipe.tasks.vision.SingleHandLandmarkerSubgraph" +// input_stream: "IMAGE:input_image" // input_stream: "HAND_RECT:hand_rect" // output_stream: "LANDMARKS:hand_landmarks" // output_stream: "WORLD_LANDMARKS:world_hand_landmarks" @@ -208,10 +222,12 @@ void ConfigureHandRectTransformationCalculator( // output_stream: "PRESENCE:hand_presence" // output_stream: "PRESENCE_SCORE:hand_presence_score" // output_stream: "HANDEDNESS:handedness" +// output_stream: "IMAGE_SIZE:image_size" // options { -// [mediapipe.tasks.HandLandmarkDetectorGraph.ext] { +// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarkerSubgraphOptions.ext] +// { // base_options { -// model_file { +// model_asset { // file_name: "hand_landmark_lite.tflite" // } // } @@ -219,16 +235,16 @@ void ConfigureHandRectTransformationCalculator( // } // } // } -class HandLandmarkDetectorGraph : public core::ModelTaskGraph { +class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { ASSIGN_OR_RETURN(const auto* model_resources, - CreateModelResources(sc)); + CreateModelResources(sc)); Graph graph; ASSIGN_OR_RETURN(auto hand_landmark_detection_outs, - BuildHandLandmarkDetectionSubgraph( - sc->Options(), + BuildSingleHandLandmarkerSubgraph( + sc->Options(), *model_resources, graph[Input(kImageTag)], graph[Input(kHandRectTag)], graph)); hand_landmark_detection_outs.hand_landmarks >> @@ -253,23 +269,24 @@ class HandLandmarkDetectorGraph : public core::ModelTaskGraph { // Adds a mediapipe hand landmark detection graph into the provided // builder::Graph instance. // - // subgraph_options: the mediapipe tasks module HandLandmarkDetectorOptions. - // model_resources: the ModelSources object initialized from an hand landmark - // detection model file with model metadata. + // subgraph_options: the mediapipe tasks module HandLandmarkerSubgraphOptions. + // model_resources: the ModelSources object initialized from a hand landmark + // detection model file with model metadata. // image_in: (mediapipe::Image) stream to run hand landmark detection on. // rect: (NormalizedRect) stream to run on the RoI of image. - // graph: the mediapipe builder::Graph instance to be updated. - absl::StatusOr BuildHandLandmarkDetectionSubgraph( - const HandLandmarkDetectorOptions& subgraph_options, + // graph: the mediapipe graph instance to be updated. 
+ absl::StatusOr BuildSingleHandLandmarkerSubgraph( + const HandLandmarkerSubgraphOptions& subgraph_options, const core::ModelResources& model_resources, Source image_in, Source hand_rect, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options)); auto& preprocessing = - graph.AddNode("mediapipe.tasks.ImagePreprocessingSubgraph"); + graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( model_resources, - &preprocessing.GetOptions())); + &preprocessing + .GetOptions())); image_in >> preprocessing.In("IMAGE"); hand_rect >> preprocessing.In("NORM_RECT"); auto image_size = preprocessing[Output>("IMAGE_SIZE")]; @@ -277,7 +294,8 @@ class HandLandmarkDetectorGraph : public core::ModelTaskGraph { ASSIGN_OR_RETURN(auto image_tensor_specs, BuildImageTensorSpecs(model_resources)); - auto& inference = AddInference(model_resources, graph); + auto& inference = AddInference( + model_resources, subgraph_options.base_options().acceleration(), graph); preprocessing.Out("TENSORS") >> inference.In("TENSORS"); // Split model output tensors to multiple streams. @@ -356,7 +374,7 @@ class HandLandmarkDetectorGraph : public core::ModelTaskGraph { auto projected_landmarks = landmark_projection[Output("NORM_LANDMARKS")]; - // Projects the world landmarks from the cropped pose image to the + // Projects the world landmarks from the cropped hand image to the // corresponding locations on the full image before cropping (input to the // graph). auto& world_landmark_projection = @@ -399,7 +417,180 @@ class HandLandmarkDetectorGraph : public core::ModelTaskGraph { } }; -REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::HandLandmarkDetectorGraph); +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::SingleHandLandmarkerSubgraph); + +// A "mediapipe.tasks.vision.HandLandmarkerSubgraph" performs multi +// hand landmark detection. +// - Accepts CPU input image and a vector of hand rect RoIs to detect the +// multiple hands landmarks enclosed by the RoIs. Output vectors of +// hand landmarks related results, where each element in the vectors +// corrresponds to the result of the same hand. +// +// Inputs: +// IMAGE - Image +// Image to perform detection on. +// HAND_RECTS - std::vector +// A vector of multiple hand rects enclosing the hand RoI to perform +// landmarks detection on. +// +// +// Outputs: +// LANDMARKS: - std::vector +// Vector of detected hand landmarks. +// WORLD_LANDMARKS - std::vector +// Vector of detected hand landmarks in world coordinates. +// HAND_RECT_NEXT_FRAME - std::vector +// Vector of the predicted rects enclosing the same hand RoI for landmark +// detection on the next frame. +// PRESENCE - std::vector +// Vector of boolean value indicates whether the hand is present. +// PRESENCE_SCORE - std::vector +// Vector of float value indicates the probability that the hand is present. +// HANDEDNESS - std::vector +// Vector of classification of handedness. +// IMAGE_SIZE - std::vector +// The size of input image. 
+// +// Example: +// node { +// calculator: "mediapipe.tasks.vision.HandLandmarkerSubgraph" +// input_stream: "IMAGE:input_image" +// input_stream: "HAND_RECT:hand_rect" +// output_stream: "LANDMARKS:hand_landmarks" +// output_stream: "WORLD_LANDMARKS:world_hand_landmarks" +// output_stream: "HAND_RECT_NEXT_FRAME:hand_rect_next_frame" +// output_stream: "PRESENCE:hand_presence" +// output_stream: "PRESENCE_SCORE:hand_presence_score" +// output_stream: "HANDEDNESS:handedness" +// output_stream: "IMAGE_SIZE:image_size" +// options { +// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarkerSubgraphOptions.ext] +// { +// base_options { +// model_asset { +// file_name: "hand_landmark_lite.tflite" +// } +// } +// min_detection_confidence: 0.5 +// } +// } +// } +class HandLandmarkerSubgraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + Graph graph; + ASSIGN_OR_RETURN( + auto hand_landmark_detection_outputs, + BuildHandLandmarkerSubgraph( + sc->Options(), + graph[Input(kImageTag)], + graph[Input>(kHandRectTag)], graph)); + hand_landmark_detection_outputs.landmark_lists >> + graph[Output>(kLandmarksTag)]; + hand_landmark_detection_outputs.world_landmark_lists >> + graph[Output>(kWorldLandmarksTag)]; + hand_landmark_detection_outputs.hand_rects_next_frame >> + graph[Output>(kHandRectNextFrameTag)]; + hand_landmark_detection_outputs.presences >> + graph[Output>(kPresenceTag)]; + hand_landmark_detection_outputs.presence_scores >> + graph[Output>(kPresenceScoreTag)]; + hand_landmark_detection_outputs.handednesses >> + graph[Output>(kHandednessTag)]; + hand_landmark_detection_outputs.image_size >> + graph[Output>(kImageSizeTag)]; + + return graph.GetConfig(); + } + + private: + absl::StatusOr BuildHandLandmarkerSubgraph( + const HandLandmarkerSubgraphOptions& subgraph_options, + Source image_in, + Source> multi_hand_rects, Graph& graph) { + auto& hand_landmark_subgraph = + graph.AddNode("mediapipe.tasks.vision.SingleHandLandmarkerSubgraph"); + hand_landmark_subgraph.GetOptions().CopyFrom( + subgraph_options); + + auto& begin_loop_multi_hand_rects = + graph.AddNode("BeginLoopNormalizedRectCalculator"); + + image_in >> begin_loop_multi_hand_rects.In("CLONE"); + multi_hand_rects >> begin_loop_multi_hand_rects.In("ITERABLE"); + auto batch_end = begin_loop_multi_hand_rects.Out("BATCH_END"); + auto image = begin_loop_multi_hand_rects.Out("CLONE"); + auto hand_rect = begin_loop_multi_hand_rects.Out("ITEM"); + + image >> hand_landmark_subgraph.In("IMAGE"); + hand_rect >> hand_landmark_subgraph.In("HAND_RECT"); + auto handedness = hand_landmark_subgraph.Out("HANDEDNESS"); + auto presence = hand_landmark_subgraph.Out("PRESENCE"); + auto presence_score = hand_landmark_subgraph.Out("PRESENCE_SCORE"); + auto hand_rect_next_frame = + hand_landmark_subgraph.Out("HAND_RECT_NEXT_FRAME"); + auto landmarks = hand_landmark_subgraph.Out("LANDMARKS"); + auto world_landmarks = hand_landmark_subgraph.Out("WORLD_LANDMARKS"); + auto image_size = + hand_landmark_subgraph[Output>("IMAGE_SIZE")]; + + auto& end_loop_handedness = + graph.AddNode("EndLoopClassificationListCalculator"); + batch_end >> end_loop_handedness.In("BATCH_END"); + handedness >> end_loop_handedness.In("ITEM"); + auto handednesses = + end_loop_handedness[Output>( + "ITERABLE")]; + + auto& end_loop_presence = graph.AddNode("EndLoopBooleanCalculator"); + batch_end >> end_loop_presence.In("BATCH_END"); + presence >> end_loop_presence.In("ITEM"); + auto presences = 
end_loop_presence[Output>("ITERABLE")]; + + auto& end_loop_presence_score = graph.AddNode("EndLoopFloatCalculator"); + batch_end >> end_loop_presence_score.In("BATCH_END"); + presence_score >> end_loop_presence_score.In("ITEM"); + auto presence_scores = + end_loop_presence_score[Output>("ITERABLE")]; + + auto& end_loop_landmarks = + graph.AddNode("EndLoopNormalizedLandmarkListVectorCalculator"); + batch_end >> end_loop_landmarks.In("BATCH_END"); + landmarks >> end_loop_landmarks.In("ITEM"); + auto landmark_lists = + end_loop_landmarks[Output>( + "ITERABLE")]; + + auto& end_loop_world_landmarks = + graph.AddNode("EndLoopLandmarkListVectorCalculator"); + batch_end >> end_loop_world_landmarks.In("BATCH_END"); + world_landmarks >> end_loop_world_landmarks.In("ITEM"); + auto world_landmark_lists = + end_loop_world_landmarks[Output>("ITERABLE")]; + + auto& end_loop_rects_next_frame = + graph.AddNode("EndLoopNormalizedRectCalculator"); + batch_end >> end_loop_rects_next_frame.In("BATCH_END"); + hand_rect_next_frame >> end_loop_rects_next_frame.In("ITEM"); + auto hand_rects_next_frame = + end_loop_rects_next_frame[Output>( + "ITERABLE")]; + + return {{ + /* landmark_lists= */ landmark_lists, + /* world_landmark_lists= */ world_landmark_lists, + /* hand_rects_next_frame= */ hand_rects_next_frame, + /* presences= */ presences, + /* presence_scores= */ presence_scores, + /* handednesses= */ handednesses, + /* image_size= */ image_size, + }}; + } +}; + +REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::HandLandmarkerSubgraph); } // namespace vision } // namespace tasks diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph_test.cc new file mode 100644 index 000000000..1c2bc6da7 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph_test.cc @@ -0,0 +1,467 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.pb.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace { + +using ::file::Defaults; +using ::file::GetTextProto; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::core::TaskRunner; +using ::mediapipe::tasks::vision::DecodeImageFromFile; +using ::mediapipe::tasks::vision::hand_landmarker::proto:: + HandLandmarkerSubgraphOptions; +using ::testing::ElementsAreArray; +using ::testing::EqualsProto; +using ::testing::Pointwise; +using ::testing::TestParamInfo; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::proto::Approximately; +using ::testing::proto::Partially; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kHandLandmarkerLiteModel[] = "hand_landmark_lite.tflite"; +constexpr char kHandLandmarkerFullModel[] = "hand_landmark_full.tflite"; +constexpr char kRightHandsImage[] = "right_hands.jpg"; +constexpr char kLeftHandsImage[] = "left_hands.jpg"; + +constexpr char kTestModelResourcesTag[] = "test_model_resources"; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageName[] = "image_in"; +constexpr char kHandRectTag[] = "HAND_RECT"; +constexpr char kHandRectName[] = "hand_rect_in"; + +constexpr char kLandmarksTag[] = "LANDMARKS"; +constexpr char kLandmarksName[] = "landmarks"; +constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; +constexpr char kWorldLandmarksName[] = "world_landmarks"; +constexpr char kHandRectNextFrameTag[] = "HAND_RECT_NEXT_FRAME"; +constexpr char kHandRectNextFrameName[] = "hand_rect_next_frame"; +constexpr char kPresenceTag[] = "PRESENCE"; +constexpr char kPresenceName[] = "presence"; +constexpr char kPresenceScoreTag[] = "PRESENCE_SCORE"; +constexpr char kPresenceScoreName[] = "presence_score"; +constexpr char kHandednessTag[] = "HANDEDNESS"; +constexpr char kHandednessName[] = "handedness"; + +// Expected hand landmarks positions, in text proto format. 
+constexpr char kExpectedRightUpHandLandmarksFilename[] = + "expected_right_up_hand_landmarks.prototxt"; +constexpr char kExpectedRightDownHandLandmarksFilename[] = + "expected_right_down_hand_landmarks.prototxt"; +constexpr char kExpectedLeftUpHandLandmarksFilename[] = + "expected_left_up_hand_landmarks.prototxt"; +constexpr char kExpectedLeftDownHandLandmarksFilename[] = + "expected_left_down_hand_landmarks.prototxt"; + +constexpr float kLiteModelFractionDiff = 0.05; // percentage +constexpr float kFullModelFractionDiff = 0.03; // percentage +constexpr float kAbsMargin = 0.03; + +// Helper function to create a Single Hand Landmark TaskRunner. +absl::StatusOr> CreateSingleHandTaskRunner( + absl::string_view model_name) { + Graph graph; + + auto& hand_landmark_detection = + graph.AddNode("mediapipe.tasks.vision.SingleHandLandmarkerSubgraph"); + + auto options = std::make_unique(); + options->mutable_base_options()->mutable_model_asset()->set_file_name( + JoinPath("./", kTestDataDirectory, model_name)); + hand_landmark_detection.GetOptions().Swap( + options.get()); + + graph[Input(kImageTag)].SetName(kImageName) >> + hand_landmark_detection.In(kImageTag); + graph[Input(kHandRectTag)].SetName(kHandRectName) >> + hand_landmark_detection.In(kHandRectTag); + + hand_landmark_detection.Out(kLandmarksTag).SetName(kLandmarksName) >> + graph[Output(kLandmarksTag)]; + hand_landmark_detection.Out(kWorldLandmarksTag) + .SetName(kWorldLandmarksName) >> + graph[Output(kWorldLandmarksTag)]; + hand_landmark_detection.Out(kPresenceTag).SetName(kPresenceName) >> + graph[Output(kPresenceTag)]; + hand_landmark_detection.Out(kPresenceScoreTag).SetName(kPresenceScoreName) >> + graph[Output(kPresenceScoreTag)]; + hand_landmark_detection.Out(kHandednessTag).SetName(kHandednessName) >> + graph[Output(kHandednessTag)]; + hand_landmark_detection.Out(kHandRectNextFrameTag) + .SetName(kHandRectNextFrameName) >> + graph[Output(kHandRectNextFrameTag)]; + + return TaskRunner::Create( + graph.GetConfig(), + absl::make_unique()); +} + +// Helper function to create a Multi Hand Landmark TaskRunner. 
+absl::StatusOr> CreateMultiHandTaskRunner( + absl::string_view model_name) { + Graph graph; + + auto& multi_hand_landmark_detection = + graph.AddNode("mediapipe.tasks.vision.HandLandmarkerSubgraph"); + + auto options = std::make_unique(); + options->mutable_base_options()->mutable_model_asset()->set_file_name( + JoinPath("./", kTestDataDirectory, model_name)); + multi_hand_landmark_detection.GetOptions() + .Swap(options.get()); + + graph[Input(kImageTag)].SetName(kImageName) >> + multi_hand_landmark_detection.In(kImageTag); + graph[Input>(kHandRectTag)].SetName( + kHandRectName) >> + multi_hand_landmark_detection.In(kHandRectTag); + + multi_hand_landmark_detection.Out(kLandmarksTag).SetName(kLandmarksName) >> + graph[Output>(kLandmarksTag)]; + multi_hand_landmark_detection.Out(kWorldLandmarksTag) + .SetName(kWorldLandmarksName) >> + graph[Output>(kWorldLandmarksTag)]; + multi_hand_landmark_detection.Out(kPresenceTag).SetName(kPresenceName) >> + graph[Output>(kPresenceTag)]; + multi_hand_landmark_detection.Out(kPresenceScoreTag) + .SetName(kPresenceScoreName) >> + graph[Output>(kPresenceScoreTag)]; + multi_hand_landmark_detection.Out(kHandednessTag).SetName(kHandednessName) >> + graph[Output>(kHandednessTag)]; + multi_hand_landmark_detection.Out(kHandRectNextFrameTag) + .SetName(kHandRectNextFrameName) >> + graph[Output>(kHandRectNextFrameTag)]; + + return TaskRunner::Create( + graph.GetConfig(), + absl::make_unique()); +} + +NormalizedLandmarkList GetExpectedLandmarkList(absl::string_view filename) { + NormalizedLandmarkList expected_landmark_list; + MP_EXPECT_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, filename), + &expected_landmark_list, Defaults())); + return expected_landmark_list; +} + +ClassificationList GetExpectedHandedness( + const std::vector& handedness_labels) { + ClassificationList expected_handedness; + for (const auto& handedness_label : handedness_labels) { + auto& classification = *expected_handedness.add_classification(); + classification.set_label(handedness_label); + classification.set_display_name(handedness_label); + } + return expected_handedness; +} + +// Struct holding the parameters for parameterized HandLandmarkerTest +// class. +struct SingeHandTestParams { + // The name of this test, for convenience when displaying test results. + std::string test_name; + // The filename of the model to test. + std::string input_model_name; + // The filename of the test image. + std::string test_image_name; + // RoI on image to detect hands. + NormalizedRect hand_rect; + // Expected hand presence value. + bool expected_presence; + // The expected output landmarks positions in pixels coornidates. + NormalizedLandmarkList expected_landmarks; + // The expected handedness ClassificationList. + ClassificationList expected_handedness; + // The max value difference between expected_positions and detected positions. + float landmarks_diff_threshold; +}; + +struct MultiHandTestParams { + // The name of this test, for convenience when displaying test results. + std::string test_name; + // The filename of the model to test. + std::string input_model_name; + // The filename of the test image. + std::string test_image_name; + // RoIs on image to detect hands. + std::vector hand_rects; + // Expected hand presence values. + std::vector expected_presences; + // The expected output landmarks positions in pixels coornidates. + std::vector expected_landmark_lists; + // The expected handedness ClassificationList. 
+ std::vector expected_handedness; + // The max value difference between expected_positions and detected positions. + float landmarks_diff_threshold; +}; + +// Helper function to construct NormalizeRect proto. +NormalizedRect MakeHandRect(float x_center, float y_center, float width, + float height, float rotation) { + NormalizedRect hand_rect; + hand_rect.set_x_center(x_center); + hand_rect.set_y_center(y_center); + hand_rect.set_width(width); + hand_rect.set_height(height); + hand_rect.set_rotation(rotation); + return hand_rect; +} + +class HandLandmarkerTest : public testing::TestWithParam { +}; + +TEST_P(HandLandmarkerTest, Succeeds) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + GetParam().test_image_name))); + MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateSingleHandTaskRunner( + GetParam().input_model_name)); + + auto output_packets = task_runner->Process( + {{kImageName, MakePacket(std::move(image))}, + {kHandRectName, + MakePacket(std::move(GetParam().hand_rect))}}); + MP_ASSERT_OK(output_packets); + + const bool presence = (*output_packets)[kPresenceName].Get(); + ASSERT_EQ(presence, GetParam().expected_presence); + + if (presence) { + const ClassificationList& handedness = + (*output_packets)[kHandednessName].Get(); + const ClassificationList expected_handedness = + GetParam().expected_handedness; + EXPECT_THAT(handedness, Partially(EqualsProto(expected_handedness))); + + const NormalizedLandmarkList landmarks = + (*output_packets)[kLandmarksName].Get(); + + const NormalizedLandmarkList& expected_landmarks = + GetParam().expected_landmarks; + + EXPECT_THAT( + landmarks, + Approximately(Partially(EqualsProto(expected_landmarks)), + /*margin=*/kAbsMargin, + /*fraction=*/GetParam().landmarks_diff_threshold)); + } +} + +class MultiHandLandmarkerTest + : public testing::TestWithParam {}; + +TEST_P(MultiHandLandmarkerTest, Succeeds) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + GetParam().test_image_name))); + MP_ASSERT_OK_AND_ASSIGN( + auto task_runner, CreateMultiHandTaskRunner(GetParam().input_model_name)); + + auto output_packets = task_runner->Process( + {{kImageName, MakePacket(std::move(image))}, + {kHandRectName, MakePacket>( + std::move(GetParam().hand_rects))}}); + MP_ASSERT_OK(output_packets); + + const std::vector& presences = + (*output_packets)[kPresenceName].Get>(); + const std::vector& handednesses = + (*output_packets)[kHandednessName].Get>(); + const std::vector& landmark_lists = + (*output_packets)[kLandmarksName] + .Get>(); + + EXPECT_THAT(presences, ElementsAreArray(GetParam().expected_presences)); + EXPECT_THAT(handednesses, Pointwise(Partially(EqualsProto()), + GetParam().expected_handedness)); + EXPECT_THAT( + landmark_lists, + Pointwise(Approximately(Partially(EqualsProto()), /*margin=*/kAbsMargin, + /*fraction=*/GetParam().landmarks_diff_threshold), + GetParam().expected_landmark_lists)); +} + +INSTANTIATE_TEST_SUITE_P( + HandLandmarkerTest, HandLandmarkerTest, + Values( + SingeHandTestParams{ + .test_name = "HandLandmarkerLiteModelRightUpHand", + .input_model_name = kHandLandmarkerLiteModel, + .test_image_name = kRightHandsImage, + .hand_rect = MakeHandRect(0.25, 0.5, 0.5, 1.0, 0), + .expected_presence = true, + .expected_landmarks = + GetExpectedLandmarkList(kExpectedRightUpHandLandmarksFilename), + .expected_handedness = GetExpectedHandedness({"Right"}), + .landmarks_diff_threshold = kLiteModelFractionDiff}, + SingeHandTestParams{ + .test_name = 
"HandLandmarkerLiteModelRightDownHand", + .input_model_name = kHandLandmarkerLiteModel, + .test_image_name = kRightHandsImage, + .hand_rect = MakeHandRect(0.75, 0.5, 0.5, 1.0, M_PI), + .expected_presence = true, + .expected_landmarks = GetExpectedLandmarkList( + kExpectedRightDownHandLandmarksFilename), + .expected_handedness = GetExpectedHandedness({"Right"}), + .landmarks_diff_threshold = kLiteModelFractionDiff}, + SingeHandTestParams{ + .test_name = "HandLandmarkerFullModelRightUpHand", + .input_model_name = kHandLandmarkerFullModel, + .test_image_name = kRightHandsImage, + .hand_rect = MakeHandRect(0.25, 0.5, 0.5, 1.0, 0), + .expected_presence = true, + .expected_landmarks = + GetExpectedLandmarkList(kExpectedRightUpHandLandmarksFilename), + .expected_handedness = GetExpectedHandedness({"Right"}), + .landmarks_diff_threshold = kFullModelFractionDiff}, + SingeHandTestParams{ + .test_name = "HandLandmarkerFullModelRightDownHand", + .input_model_name = kHandLandmarkerFullModel, + .test_image_name = kRightHandsImage, + .hand_rect = MakeHandRect(0.75, 0.5, 0.5, 1.0, M_PI), + .expected_presence = true, + .expected_landmarks = GetExpectedLandmarkList( + kExpectedRightDownHandLandmarksFilename), + .expected_handedness = GetExpectedHandedness({"Right"}), + .landmarks_diff_threshold = kFullModelFractionDiff}, + SingeHandTestParams{ + .test_name = "HandLandmarkerLiteModelLeftUpHand", + .input_model_name = kHandLandmarkerLiteModel, + .test_image_name = kLeftHandsImage, + .hand_rect = MakeHandRect(0.75, 0.5, 0.5, 1.0, 0), + .expected_presence = true, + .expected_landmarks = + GetExpectedLandmarkList(kExpectedLeftUpHandLandmarksFilename), + .expected_handedness = GetExpectedHandedness({"Left"}), + .landmarks_diff_threshold = kLiteModelFractionDiff}, + SingeHandTestParams{ + .test_name = "HandLandmarkerLiteModelLeftDownHand", + .input_model_name = kHandLandmarkerLiteModel, + .test_image_name = kLeftHandsImage, + .hand_rect = MakeHandRect(0.25, 0.5, 0.5, 1.0, M_PI), + .expected_presence = true, + .expected_landmarks = + GetExpectedLandmarkList(kExpectedLeftDownHandLandmarksFilename), + .expected_handedness = GetExpectedHandedness({"Left"}), + .landmarks_diff_threshold = kLiteModelFractionDiff}, + SingeHandTestParams{ + .test_name = "HandLandmarkerFullModelLeftUpHand", + .input_model_name = kHandLandmarkerFullModel, + .test_image_name = kLeftHandsImage, + .hand_rect = MakeHandRect(0.75, 0.5, 0.5, 1.0, 0), + .expected_presence = true, + .expected_landmarks = + GetExpectedLandmarkList(kExpectedLeftUpHandLandmarksFilename), + .expected_handedness = GetExpectedHandedness({"Left"}), + .landmarks_diff_threshold = kFullModelFractionDiff}, + SingeHandTestParams{ + .test_name = "HandLandmarkerFullModelLeftDownHand", + .input_model_name = kHandLandmarkerFullModel, + .test_image_name = kLeftHandsImage, + .hand_rect = MakeHandRect(0.25, 0.5, 0.5, 1.0, M_PI), + .expected_presence = true, + .expected_landmarks = + GetExpectedLandmarkList(kExpectedLeftDownHandLandmarksFilename), + .expected_handedness = GetExpectedHandedness({"Left"}), + .landmarks_diff_threshold = kFullModelFractionDiff}), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +INSTANTIATE_TEST_SUITE_P( + MultiHandLandmarkerTest, MultiHandLandmarkerTest, + Values( + MultiHandTestParams{ + .test_name = "MultiHandLandmarkerRightHands", + .input_model_name = kHandLandmarkerLiteModel, + .test_image_name = kRightHandsImage, + .hand_rects = + { + MakeHandRect(0.25, 0.5, 0.5, 1.0, 0), + MakeHandRect(0.75, 0.5, 0.5, 1.0, M_PI), + }, + 
.expected_presences = {true, true}, + .expected_landmark_lists = + {GetExpectedLandmarkList(kExpectedRightUpHandLandmarksFilename), + GetExpectedLandmarkList( + kExpectedRightDownHandLandmarksFilename)}, + .expected_handedness = {GetExpectedHandedness({"Right"}), + GetExpectedHandedness({"Right"})}, + .landmarks_diff_threshold = kLiteModelFractionDiff, + }, + MultiHandTestParams{ + .test_name = "MultiHandLandmarkerLeftHands", + .input_model_name = kHandLandmarkerLiteModel, + .test_image_name = kLeftHandsImage, + .hand_rects = + { + MakeHandRect(0.75, 0.5, 0.5, 1.0, 0), + MakeHandRect(0.25, 0.5, 0.5, 1.0, M_PI), + }, + .expected_presences = {true, true}, + .expected_landmark_lists = + {GetExpectedLandmarkList(kExpectedLeftUpHandLandmarksFilename), + GetExpectedLandmarkList( + kExpectedLeftDownHandLandmarksFilename)}, + .expected_handedness = {GetExpectedHandedness({"Left"}), + GetExpectedHandedness({"Left"})}, + .landmarks_diff_threshold = kLiteModelFractionDiff, + }), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +} // namespace +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD new file mode 100644 index 000000000..8cc984c47 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD @@ -0,0 +1,43 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +package(default_visibility = [ + "//mediapipe/tasks:internal", +]) + +licenses(["notice"]) + +mediapipe_proto_library( + name = "hand_landmarker_subgraph_options_proto", + srcs = ["hand_landmarker_subgraph_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + ], +) + +mediapipe_proto_library( + name = "hand_landmarker_options_proto", + srcs = ["hand_landmarker_options.proto"], + deps = [ + ":hand_landmarker_subgraph_options_proto", + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_options_proto", + ], +) diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_options.proto new file mode 100644 index 000000000..b3d82eda4 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_options.proto @@ -0,0 +1,40 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.vision.hand_landmarker.proto; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/core/proto/base_options.proto"; +import "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.proto"; +import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto"; + +message HandLandmarkerOptions { + extend mediapipe.CalculatorOptions { + optional HandLandmarkerOptions ext = 462713202; + } + // Base options for configuring MediaPipe Tasks, such as specifying the TfLite + // model file with metadata, accelerator options, etc. + optional core.proto.BaseOptions base_options = 1; + + // The locale to use for display names specified through the TFLite Model + // Metadata, if any. Defaults to English. + optional string display_names_locale = 2 [default = "en"]; + + optional hand_detector.proto.HandDetectorOptions hand_detector_options = 3; + + optional HandLandmarkerSubgraphOptions hand_landmarker_subgraph_options = 4; +} diff --git a/mediapipe/tasks/cc/vision/hand_landmark/hand_landmark_detector_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto similarity index 90% rename from mediapipe/tasks/cc/vision/hand_landmark/hand_landmark_detector_options.proto rename to mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto index a2cfc7eaf..9e93384d6 100644 --- a/mediapipe/tasks/cc/vision/hand_landmark/hand_landmark_detector_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto @@ -15,14 +15,14 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks; +package mediapipe.tasks.vision.hand_landmarker.proto; import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; -message HandLandmarkDetectorOptions { +message HandLandmarkerSubgraphOptions { extend mediapipe.CalculatorOptions { - optional HandLandmarkDetectorOptions ext = 462713202; + optional HandLandmarkerSubgraphOptions ext = 474472470; } // Base options for configuring MediaPipe Tasks, such as specifying the TfLite // model file with metadata, accelerator options, etc. diff --git a/mediapipe/tasks/cc/vision/image_classification/image_classifier.cc b/mediapipe/tasks/cc/vision/image_classification/image_classifier.cc deleted file mode 100644 index 4c70262e2..000000000 --- a/mediapipe/tasks/cc/vision/image_classification/image_classifier.cc +++ /dev/null @@ -1,89 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mediapipe/tasks/cc/vision/image_classification/image_classifier.h" - -#include -#include - -#include "absl/memory/memory.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "mediapipe/framework/api2/builder.h" -#include "mediapipe/framework/formats/image.h" -#include "mediapipe/framework/packet.h" -#include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" -#include "mediapipe/tasks/cc/core/base_task_api.h" -#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" -#include "mediapipe/tasks/cc/core/task_api_factory.h" -#include "mediapipe/tasks/cc/core/task_runner.h" -#include "mediapipe/tasks/cc/vision/image_classification/image_classifier_options.pb.h" -#include "tensorflow/lite/core/api/op_resolver.h" -#include "tensorflow/lite/kernels/register.h" - -namespace mediapipe { -namespace tasks { -namespace vision { - -namespace { - -constexpr char kImageStreamName[] = "image_in"; -constexpr char kImageTag[] = "IMAGE"; -constexpr char kClassificationResultStreamName[] = "classification_result_out"; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; -constexpr char kSubgraphTypeName[] = - "mediapipe.tasks.vision.ImageClassifierGraph"; - -// Creates a MediaPipe graph config that only contains a single subgraph node of -// "mediapipe.tasks.vision.ImageClassifierGraph". -CalculatorGraphConfig CreateGraphConfig( - std::unique_ptr options) { - api2::builder::Graph graph; - auto& subgraph = graph.AddNode(kSubgraphTypeName); - subgraph.GetOptions().Swap(options.get()); - graph.In(kImageTag).SetName(kImageStreamName) >> subgraph.In(kImageTag); - subgraph.Out(kClassificationResultTag) - .SetName(kClassificationResultStreamName) >> - graph.Out(kClassificationResultTag); - return graph.GetConfig(); -} - -} // namespace - -absl::StatusOr> ImageClassifier::Create( - std::unique_ptr options, - std::unique_ptr resolver) { - return core::TaskApiFactory::Create( - CreateGraphConfig(std::move(options)), std::move(resolver)); -} - -absl::StatusOr ImageClassifier::Classify(Image image) { - if (image.UsesGpu()) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "GPU input images are currently not supported.", - MediaPipeTasksStatus::kRunnerUnexpectedInputError); - } - ASSIGN_OR_RETURN(auto output_packets, - runner_->Process({{kImageStreamName, - MakePacket(std::move(image))}})); - return output_packets[kClassificationResultStreamName] - .Get(); -} - -} // namespace vision -} // namespace tasks -} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_classification/image_classifier.h b/mediapipe/tasks/cc/vision/image_classification/image_classifier.h deleted file mode 100644 index 452d9e8c4..000000000 --- a/mediapipe/tasks/cc/vision/image_classification/image_classifier.h +++ /dev/null @@ -1,89 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_CLASSIFICATION_IMAGE_CLASSIFIER_H_ -#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_CLASSIFICATION_IMAGE_CLASSIFIER_H_ - -#include - -#include "absl/memory/memory.h" -#include "absl/status/statusor.h" -#include "mediapipe/framework/formats/image.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" -#include "mediapipe/tasks/cc/core/base_task_api.h" -#include "mediapipe/tasks/cc/vision/image_classification/image_classifier_options.pb.h" -#include "tensorflow/lite/core/api/op_resolver.h" -#include "tensorflow/lite/kernels/register.h" - -namespace mediapipe { -namespace tasks { -namespace vision { - -// Performs classification on images. -// -// The API expects a TFLite model with optional, but strongly recommended, -// TFLite Model Metadata. -// -// Input tensor: -// (kTfLiteUInt8/kTfLiteFloat32) -// - image input of size `[batch x height x width x channels]`. -// - batch inference is not supported (`batch` is required to be 1). -// - only RGB inputs are supported (`channels` is required to be 3). -// - if type is kTfLiteFloat32, NormalizationOptions are required to be -// attached to the metadata for input normalization. -// At least one output tensor with: -// (kTfLiteUInt8/kTfLiteFloat32) -// - `N `classes and either 2 or 4 dimensions, i.e. `[1 x N]` or -// `[1 x 1 x 1 x N]` -// - optional (but recommended) label map(s) as AssociatedFile-s with type -// TENSOR_AXIS_LABELS, containing one label per line. The first such -// AssociatedFile (if any) is used to fill the `class_name` field of the -// results. The `display_name` field is filled from the AssociatedFile (if -// any) whose locale matches the `display_names_locale` field of the -// `ImageClassifierOptions` used at creation time ("en" by default, i.e. -// English). If none of these are available, only the `index` field of the -// results will be filled. -// -// An example of such model can be found at: -// https://tfhub.dev/bohemian-visual-recognition-alliance/lite-model/models/mushroom-identification_v1/1 -class ImageClassifier : core::BaseTaskApi { - public: - using BaseTaskApi::BaseTaskApi; - - // Creates an ImageClassifier from the provided options. A non-default - // OpResolver can be specified in order to support custom Ops or specify a - // subset of built-in Ops. - static absl::StatusOr> Create( - std::unique_ptr options, - std::unique_ptr resolver = - absl::make_unique()); - - // Performs actual classification on the provided Image. - // - // TODO: describe exact preprocessing steps once - // YUVToImageCalculator is integrated. - absl::StatusOr Classify(mediapipe::Image image); - - // TODO: add Classify() variant taking a region of interest as - // additional argument. - - // TODO: add ClassifyAsync() method for the streaming use case. 
-}; - -} // namespace vision -} // namespace tasks -} // namespace mediapipe - -#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_CLASSIFICATION_IMAGE_CLASSIFIER_H_ diff --git a/mediapipe/tasks/cc/vision/image_classification/image_classifier_test.cc b/mediapipe/tasks/cc/vision/image_classification/image_classifier_test.cc deleted file mode 100644 index 014f11352..000000000 --- a/mediapipe/tasks/cc/vision/image_classification/image_classifier_test.cc +++ /dev/null @@ -1,411 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mediapipe/tasks/cc/vision/image_classification/image_classifier.h" - -#include -#include -#include - -#include "absl/flags/flag.h" -#include "absl/memory/memory.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "mediapipe/framework/deps/file_path.h" -#include "mediapipe/framework/formats/image.h" -#include "mediapipe/framework/port/gmock.h" -#include "mediapipe/framework/port/gtest.h" -#include "mediapipe/framework/port/parse_text_proto.h" -#include "mediapipe/framework/port/status_matchers.h" -#include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/classifier_options.pb.h" -#include "mediapipe/tasks/cc/components/containers/category.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" -#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" -#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" -#include "mediapipe/tasks/cc/vision/image_classification/image_classifier_options.pb.h" -#include "mediapipe/tasks/cc/vision/utils/image_utils.h" -#include "tensorflow/lite/core/api/op_resolver.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" -#include "tensorflow/lite/kernels/builtin_op_kernels.h" -#include "tensorflow/lite/mutable_op_resolver.h" - -namespace mediapipe { -namespace tasks { -namespace vision { -namespace { - -using ::mediapipe::file::JoinPath; -using ::testing::HasSubstr; -using ::testing::Optional; - -constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; -constexpr char kMobileNetFloatWithMetadata[] = "mobilenet_v2_1.0_224.tflite"; -constexpr char kMobileNetQuantizedWithMetadata[] = - "mobilenet_v1_0.25_224_quant.tflite"; - -// Checks that the two provided `ClassificationResult` are equal, with a -// tolerancy on floating-point score to account for numerical instabilities. 
-void ExpectApproximatelyEqual(const ClassificationResult& actual, - const ClassificationResult& expected) { - const float kPrecision = 1e-6; - ASSERT_EQ(actual.classifications_size(), expected.classifications_size()); - for (int i = 0; i < actual.classifications_size(); ++i) { - const Classifications& a = actual.classifications(i); - const Classifications& b = expected.classifications(i); - EXPECT_EQ(a.head_index(), b.head_index()); - EXPECT_EQ(a.head_name(), b.head_name()); - EXPECT_EQ(a.entries_size(), b.entries_size()); - for (int j = 0; j < a.entries_size(); ++j) { - const ClassificationEntry& x = a.entries(j); - const ClassificationEntry& y = b.entries(j); - EXPECT_EQ(x.timestamp_ms(), y.timestamp_ms()); - EXPECT_EQ(x.categories_size(), y.categories_size()); - for (int k = 0; k < x.categories_size(); ++k) { - EXPECT_EQ(x.categories(k).index(), y.categories(k).index()); - EXPECT_EQ(x.categories(k).category_name(), - y.categories(k).category_name()); - EXPECT_EQ(x.categories(k).display_name(), - y.categories(k).display_name()); - EXPECT_NEAR(x.categories(k).score(), y.categories(k).score(), - kPrecision); - } - } - } -} - -// A custom OpResolver only containing the Ops required by the test model. -class MobileNetQuantizedOpResolver : public ::tflite::MutableOpResolver { - public: - MobileNetQuantizedOpResolver() { - AddBuiltin(::tflite::BuiltinOperator_AVERAGE_POOL_2D, - ::tflite::ops::builtin::Register_AVERAGE_POOL_2D()); - AddBuiltin(::tflite::BuiltinOperator_CONV_2D, - ::tflite::ops::builtin::Register_CONV_2D()); - AddBuiltin(::tflite::BuiltinOperator_DEPTHWISE_CONV_2D, - ::tflite::ops::builtin::Register_DEPTHWISE_CONV_2D()); - AddBuiltin(::tflite::BuiltinOperator_RESHAPE, - ::tflite::ops::builtin::Register_RESHAPE()); - AddBuiltin(::tflite::BuiltinOperator_SOFTMAX, - ::tflite::ops::builtin::Register_SOFTMAX()); - } - - MobileNetQuantizedOpResolver(const MobileNetQuantizedOpResolver& r) = delete; -}; - -// A custom OpResolver missing Ops required by the test model. -class MobileNetQuantizedOpResolverMissingOps - : public ::tflite::MutableOpResolver { - public: - MobileNetQuantizedOpResolverMissingOps() { - AddBuiltin(::tflite::BuiltinOperator_SOFTMAX, - ::tflite::ops::builtin::Register_SOFTMAX()); - } - - MobileNetQuantizedOpResolverMissingOps( - const MobileNetQuantizedOpResolverMissingOps& r) = delete; -}; - -class CreateTest : public tflite_shims::testing::Test {}; - -TEST_F(CreateTest, SucceedsWithSelectiveOpResolver) { - auto options = std::make_unique(); - options->mutable_base_options()->mutable_model_file()->set_file_name( - JoinPath("./", kTestDataDirectory, kMobileNetQuantizedWithMetadata)); - - MP_ASSERT_OK(ImageClassifier::Create( - std::move(options), absl::make_unique())); -} - -TEST_F(CreateTest, FailsWithSelectiveOpResolverMissingOps) { - auto options = std::make_unique(); - options->mutable_base_options()->mutable_model_file()->set_file_name( - JoinPath("./", kTestDataDirectory, kMobileNetQuantizedWithMetadata)); - - auto image_classifier_or = ImageClassifier::Create( - std::move(options), - absl::make_unique()); - - // TODO: Make MediaPipe InferenceCalculator report the detailed - // interpreter errors (e.g., "Encountered unresolved custom op"). 
- EXPECT_EQ(image_classifier_or.status().code(), absl::StatusCode::kInternal); - EXPECT_THAT(image_classifier_or.status().message(), - HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk")); -} -TEST_F(CreateTest, FailsWithMissingModel) { - auto image_classifier_or = - ImageClassifier::Create(std::make_unique()); - - EXPECT_EQ(image_classifier_or.status().code(), - absl::StatusCode::kInvalidArgument); - EXPECT_THAT( - image_classifier_or.status().message(), - HasSubstr("ExternalFile must specify at least one of 'file_content', " - "'file_name' or 'file_descriptor_meta'.")); - EXPECT_THAT(image_classifier_or.status().GetPayload(kMediaPipeTasksPayload), - Optional(absl::Cord(absl::StrCat( - MediaPipeTasksStatus::kRunnerInitializationError)))); -} - -TEST_F(CreateTest, FailsWithInvalidMaxResults) { - auto options = std::make_unique(); - options->mutable_base_options()->mutable_model_file()->set_file_name( - JoinPath("./", kTestDataDirectory, kMobileNetQuantizedWithMetadata)); - options->mutable_classifier_options()->set_max_results(0); - - auto image_classifier_or = ImageClassifier::Create(std::move(options)); - - EXPECT_EQ(image_classifier_or.status().code(), - absl::StatusCode::kInvalidArgument); - EXPECT_THAT(image_classifier_or.status().message(), - HasSubstr("Invalid `max_results` option")); - EXPECT_THAT(image_classifier_or.status().GetPayload(kMediaPipeTasksPayload), - Optional(absl::Cord(absl::StrCat( - MediaPipeTasksStatus::kRunnerInitializationError)))); -} - -TEST_F(CreateTest, FailsWithCombinedAllowlistAndDenylist) { - auto options = std::make_unique(); - options->mutable_base_options()->mutable_model_file()->set_file_name( - JoinPath("./", kTestDataDirectory, kMobileNetQuantizedWithMetadata)); - options->mutable_classifier_options()->add_category_allowlist("foo"); - options->mutable_classifier_options()->add_category_denylist("bar"); - - auto image_classifier_or = ImageClassifier::Create(std::move(options)); - - EXPECT_EQ(image_classifier_or.status().code(), - absl::StatusCode::kInvalidArgument); - EXPECT_THAT(image_classifier_or.status().message(), - HasSubstr("mutually exclusive options")); - EXPECT_THAT(image_classifier_or.status().GetPayload(kMediaPipeTasksPayload), - Optional(absl::Cord(absl::StrCat( - MediaPipeTasksStatus::kRunnerInitializationError)))); -} - -class ClassifyTest : public tflite_shims::testing::Test {}; - -TEST_F(ClassifyTest, SucceedsWithFloatModel) { - MP_ASSERT_OK_AND_ASSIGN( - Image image, - DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); - auto options = std::make_unique(); - options->mutable_base_options()->mutable_model_file()->set_file_name( - JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata)); - options->mutable_classifier_options()->set_max_results(3); - MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, - ImageClassifier::Create(std::move(options))); - - MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); - - ExpectApproximatelyEqual(results, ParseTextProtoOrDie( - R"pb(classifications { - entries { - categories { - index: 934 - score: 0.7939592 - category_name: "cheeseburger" - } - categories { - index: 932 - score: 0.027392805 - category_name: "bagel" - } - categories { - index: 925 - score: 0.019340655 - category_name: "guacamole" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); -} - -TEST_F(ClassifyTest, SucceedsWithQuantizedModel) { - MP_ASSERT_OK_AND_ASSIGN( - Image image, - DecodeImageFromFile(JoinPath("./", kTestDataDirectory, 
"burger.jpg"))); - auto options = std::make_unique(); - options->mutable_base_options()->mutable_model_file()->set_file_name( - JoinPath("./", kTestDataDirectory, kMobileNetQuantizedWithMetadata)); - // Due to quantization, multiple results beyond top-1 have the exact same - // score. This leads to unstability in results ordering, so we only ask for - // top-1 here. - options->mutable_classifier_options()->set_max_results(1); - MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, - ImageClassifier::Create(std::move(options))); - - MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); - - ExpectApproximatelyEqual(results, ParseTextProtoOrDie( - R"pb(classifications { - entries { - categories { - index: 934 - score: 0.97265625 - category_name: "cheeseburger" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); -} - -TEST_F(ClassifyTest, SucceedsWithMaxResultsOption) { - MP_ASSERT_OK_AND_ASSIGN( - Image image, - DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); - auto options = std::make_unique(); - options->mutable_base_options()->mutable_model_file()->set_file_name( - JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata)); - options->mutable_classifier_options()->set_max_results(1); - MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, - ImageClassifier::Create(std::move(options))); - - MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); - - ExpectApproximatelyEqual(results, ParseTextProtoOrDie( - R"pb(classifications { - entries { - categories { - index: 934 - score: 0.7939592 - category_name: "cheeseburger" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); -} - -TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) { - MP_ASSERT_OK_AND_ASSIGN( - Image image, - DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); - auto options = std::make_unique(); - options->mutable_base_options()->mutable_model_file()->set_file_name( - JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata)); - options->mutable_classifier_options()->set_score_threshold(0.02); - MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, - ImageClassifier::Create(std::move(options))); - - MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); - - ExpectApproximatelyEqual(results, ParseTextProtoOrDie( - R"pb(classifications { - entries { - categories { - index: 934 - score: 0.7939592 - category_name: "cheeseburger" - } - categories { - index: 932 - score: 0.027392805 - category_name: "bagel" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); -} - -TEST_F(ClassifyTest, SucceedsWithAllowlistOption) { - MP_ASSERT_OK_AND_ASSIGN( - Image image, - DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); - auto options = std::make_unique(); - options->mutable_base_options()->mutable_model_file()->set_file_name( - JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata)); - options->mutable_classifier_options()->add_category_allowlist("cheeseburger"); - options->mutable_classifier_options()->add_category_allowlist("guacamole"); - options->mutable_classifier_options()->add_category_allowlist("meat loaf"); - MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, - ImageClassifier::Create(std::move(options))); - - MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); - - ExpectApproximatelyEqual(results, ParseTextProtoOrDie( - R"pb(classifications { - entries { - categories { 
- index: 934 - score: 0.7939592 - category_name: "cheeseburger" - } - categories { - index: 925 - score: 0.019340655 - category_name: "guacamole" - } - categories { - index: 963 - score: 0.0063278517 - category_name: "meat loaf" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); -} - -TEST_F(ClassifyTest, SucceedsWithDenylistOption) { - MP_ASSERT_OK_AND_ASSIGN( - Image image, - DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); - auto options = std::make_unique(); - options->mutable_base_options()->mutable_model_file()->set_file_name( - JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata)); - options->mutable_classifier_options()->set_max_results(3); - options->mutable_classifier_options()->add_category_denylist("bagel"); - MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, - ImageClassifier::Create(std::move(options))); - - MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); - - ExpectApproximatelyEqual(results, ParseTextProtoOrDie( - R"pb(classifications { - entries { - categories { - index: 934 - score: 0.7939592 - category_name: "cheeseburger" - } - categories { - index: 925 - score: 0.019340655 - category_name: "guacamole" - } - categories { - index: 963 - score: 0.0063278517 - category_name: "meat loaf" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); -} - -} // namespace -} // namespace vision -} // namespace tasks -} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_classification/BUILD b/mediapipe/tasks/cc/vision/image_classifier/BUILD similarity index 72% rename from mediapipe/tasks/cc/vision/image_classification/BUILD rename to mediapipe/tasks/cc/vision/image_classifier/BUILD index 6e1119d21..4dcecdbbe 100644 --- a/mediapipe/tasks/cc/vision/image_classification/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/BUILD @@ -12,33 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") - package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -mediapipe_proto_library( - name = "image_classifier_options_proto", - srcs = ["image_classifier_options.proto"], - deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components:classifier_options_proto", - "//mediapipe/tasks/cc/core/proto:base_options_proto", - ], -) - cc_library( name = "image_classifier_graph", srcs = ["image_classifier_graph.cc"], deps = [ - ":image_classifier_options_cc_proto", "//mediapipe/calculators/tensor:inference_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/tasks/cc/components:classification_postprocessing", "//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto", "//mediapipe/tasks/cc/components:image_preprocessing", @@ -46,6 +33,7 @@ cc_library( "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_cc_proto", "@com_google_absl//absl/status:statusor", ], alwayslink = 1, @@ -57,21 +45,25 @@ cc_library( hdrs = ["image_classifier.h"], deps = [ ":image_classifier_graph", - ":image_classifier_options_cc_proto", "//mediapipe/framework:packet", + "//mediapipe/framework:timestamp", "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:image", - "//mediapipe/tasks/cc:common", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc/components:classifier_options", "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", - "//mediapipe/tasks/cc/core:base_task_api", - "//mediapipe/tasks/cc/core:task_api_factory", + "//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "@com_google_absl//absl/memory", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", ], ) diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc new file mode 100644 index 000000000..eb74c3d98 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc @@ -0,0 +1,219 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/image_classifier/image_classifier.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/timestamp.h" +#include "mediapipe/tasks/cc/components/classifier_options.h" +#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/core/running_mode.h" +#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" +#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { + +namespace { + +constexpr char kClassificationResultStreamName[] = "classification_result_out"; +constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; +constexpr char kImageInStreamName[] = "image_in"; +constexpr char kImageOutStreamName[] = "image_out"; +constexpr char kImageTag[] = "IMAGE"; +constexpr char kNormRectName[] = "norm_rect_in"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kSubgraphTypeName[] = + "mediapipe.tasks.vision.ImageClassifierGraph"; +constexpr int kMicroSecondsPerMilliSecond = 1000; + +using ::mediapipe::tasks::core::PacketMap; +using ImageClassifierOptionsProto = + image_classifier::proto::ImageClassifierOptions; + +// Builds a NormalizedRect covering the entire image. +NormalizedRect BuildFullImageNormRect() { + NormalizedRect norm_rect; + norm_rect.set_x_center(0.5); + norm_rect.set_y_center(0.5); + norm_rect.set_width(1); + norm_rect.set_height(1); + return norm_rect; +} + +// Creates a MediaPipe graph config that contains a subgraph node of +// "mediapipe.tasks.vision.ImageClassifierGraph". If the task is running in the +// live stream mode, a "FlowLimiterCalculator" will be added to limit the number +// of frames in flight. 
+CalculatorGraphConfig CreateGraphConfig( + std::unique_ptr options_proto, + bool enable_flow_limiting) { + api2::builder::Graph graph; + graph.In(kImageTag).SetName(kImageInStreamName); + graph.In(kNormRectTag).SetName(kNormRectName); + auto& task_subgraph = graph.AddNode(kSubgraphTypeName); + task_subgraph.GetOptions().Swap( + options_proto.get()); + task_subgraph.Out(kClassificationResultTag) + .SetName(kClassificationResultStreamName) >> + graph.Out(kClassificationResultTag); + task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> + graph.Out(kImageTag); + if (enable_flow_limiting) { + return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph, + {kImageTag, kNormRectTag}, + kClassificationResultTag); + } + graph.In(kImageTag) >> task_subgraph.In(kImageTag); + graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag); + return graph.GetConfig(); +} + +// Converts the user-facing ImageClassifierOptions struct to the internal +// ImageClassifierOptions proto. +std::unique_ptr +ConvertImageClassifierOptionsToProto(ImageClassifierOptions* options) { + auto options_proto = std::make_unique(); + auto base_options_proto = std::make_unique( + tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); + options_proto->mutable_base_options()->Swap(base_options_proto.get()); + options_proto->mutable_base_options()->set_use_stream_mode( + options->running_mode != core::RunningMode::IMAGE); + auto classifier_options_proto = + std::make_unique( + components::ConvertClassifierOptionsToProto( + &(options->classifier_options))); + options_proto->mutable_classifier_options()->Swap( + classifier_options_proto.get()); + return options_proto; +} + +} // namespace + +absl::StatusOr> ImageClassifier::Create( + std::unique_ptr options) { + auto options_proto = ConvertImageClassifierOptionsToProto(options.get()); + tasks::core::PacketsCallback packets_callback = nullptr; + if (options->result_callback) { + auto result_callback = options->result_callback; + packets_callback = + [=](absl::StatusOr status_or_packets) { + if (!status_or_packets.ok()) { + Image image; + result_callback(status_or_packets.status(), image, + Timestamp::Unset().Value()); + } + if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { + return; + } + Packet classification_result_packet = + status_or_packets.value()[kClassificationResultStreamName]; + Packet image_packet = status_or_packets.value()[kImageOutStreamName]; + result_callback( + classification_result_packet.Get(), + image_packet.Get(), + classification_result_packet.Timestamp().Value() / + kMicroSecondsPerMilliSecond); + }; + } + return core::VisionTaskApiFactory::Create( + CreateGraphConfig( + std::move(options_proto), + options->running_mode == core::RunningMode::LIVE_STREAM), + std::move(options->base_options.op_resolver), options->running_mode, + std::move(packets_callback)); +} + +absl::StatusOr ImageClassifier::Classify( + Image image, std::optional roi) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "GPU input images are currently not supported.", + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + NormalizedRect norm_rect = + roi.has_value() ? 
roi.value() : BuildFullImageNormRect(); + ASSIGN_OR_RETURN( + auto output_packets, + ProcessImageData( + {{kImageInStreamName, MakePacket(std::move(image))}, + {kNormRectName, MakePacket(std::move(norm_rect))}})); + return output_packets[kClassificationResultStreamName] + .Get(); +} + +absl::StatusOr ImageClassifier::ClassifyForVideo( + Image image, int64 timestamp_ms, std::optional roi) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "GPU input images are currently not supported.", + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + NormalizedRect norm_rect = + roi.has_value() ? roi.value() : BuildFullImageNormRect(); + ASSIGN_OR_RETURN( + auto output_packets, + ProcessVideoData( + {{kImageInStreamName, + MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kNormRectName, + MakePacket(std::move(norm_rect)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); + return output_packets[kClassificationResultStreamName] + .Get(); +} + +absl::Status ImageClassifier::ClassifyAsync(Image image, int64 timestamp_ms, + std::optional roi) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "GPU input images are currently not supported.", + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + NormalizedRect norm_rect = + roi.has_value() ? roi.value() : BuildFullImageNormRect(); + return SendLiveStreamData( + {{kImageInStreamName, + MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kNormRectName, + MakePacket(std::move(norm_rect)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); +} + +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h new file mode 100644 index 000000000..2fbac71b2 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h @@ -0,0 +1,168 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_CLASSIFIER_IMAGE_CLASSIFIER_H_ +#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_CLASSIFIER_IMAGE_CLASSIFIER_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/components/classifier_options.h" +#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/running_mode.h" + +namespace mediapipe { +namespace tasks { +namespace vision { + +// The options for configuring a Mediapipe image classifier task. 
+struct ImageClassifierOptions {
+  // Base options for configuring MediaPipe Tasks, such as specifying the model
+  // file with metadata, accelerator options, op resolver, etc.
+  tasks::core::BaseOptions base_options;
+
+  // The running mode of the task. Defaults to the image mode.
+  // Image classifier has three running modes:
+  // 1) The image mode for classifying images on single image inputs.
+  // 2) The video mode for classifying images on the decoded frames of a video.
+  // 3) The live stream mode for classifying images on a live stream of input
+  // data, such as from a camera. In this mode, the "result_callback" below must
+  // be specified to receive the classification results asynchronously.
+  core::RunningMode running_mode = core::RunningMode::IMAGE;
+
+  // Options for configuring the classifier behavior, such as score threshold,
+  // number of results, etc.
+  components::ClassifierOptions classifier_options;
+
+  // The user-defined result callback for processing live stream data.
+  // The result callback should only be specified when the running mode is set
+  // to RunningMode::LIVE_STREAM.
+  std::function<void(absl::StatusOr<ClassificationResult>, const Image&, int64)>
+      result_callback = nullptr;
+};
+
+// Performs classification on images.
+//
+// The API expects a TFLite model with optional, but strongly recommended,
+// TFLite Model Metadata.
+//
+// Input tensor:
+//   (kTfLiteUInt8/kTfLiteFloat32)
+//   - image input of size `[batch x height x width x channels]`.
+//   - batch inference is not supported (`batch` is required to be 1).
+//   - only RGB inputs are supported (`channels` is required to be 3).
+//   - if type is kTfLiteFloat32, NormalizationOptions are required to be
+//     attached to the metadata for input normalization.
+// At least one output tensor with:
+//   (kTfLiteUInt8/kTfLiteFloat32)
+//   - `N` classes and either 2 or 4 dimensions, i.e. `[1 x N]` or
+//     `[1 x 1 x 1 x N]`
+//   - optional (but recommended) label map(s) as AssociatedFile-s with type
+//     TENSOR_AXIS_LABELS, containing one label per line. The first such
+//     AssociatedFile (if any) is used to fill the `class_name` field of the
+//     results. The `display_name` field is filled from the AssociatedFile (if
+//     any) whose locale matches the `display_names_locale` field of the
+//     `ImageClassifierOptions` used at creation time ("en" by default, i.e.
+//     English). If none of these are available, only the `index` field of the
+//     results will be filled.
+//   - optional score calibration can be attached using ScoreCalibrationOptions
+//     and an AssociatedFile with type TENSOR_AXIS_SCORE_CALIBRATION. See
+//     metadata_schema.fbs [1] for more details.
+//
+// An example of such model can be found at:
+// https://tfhub.dev/bohemian-visual-recognition-alliance/lite-model/models/mushroom-identification_v1/1
+//
+// [1]:
+// https://github.com/google/mediapipe/blob/6cdc6443b6a7ed662744e2a2ce2d58d9c83e6d6f/mediapipe/tasks/metadata/metadata_schema.fbs#L456
+class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
+ public:
+  using BaseVisionTaskApi::BaseVisionTaskApi;
+
+  // Creates an ImageClassifier from the provided options. A non-default
+  // OpResolver can be specified in the BaseOptions in order to support custom
+  // Ops or specify a subset of built-in Ops.
+  static absl::StatusOr<std::unique_ptr<ImageClassifier>> Create(
+      std::unique_ptr<ImageClassifierOptions> options);
+
+  // Performs image classification on the provided single image. Classification
+  // is performed on the region of interest specified by the `roi` argument if
+  // provided, or on the entire image otherwise.
+  //
+  // Only use this method when the ImageClassifier is created with the image
+  // running mode.
+  //
+  // The image can be of any size with format RGB or RGBA.
+  // TODO: describe exact preprocessing steps once
+  // YUVToImageCalculator is integrated.
+  absl::StatusOr<ClassificationResult> Classify(
+      mediapipe::Image image,
+      std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
+
+  // Performs image classification on the provided video frame. Classification
+  // is performed on the region of interest specified by the `roi` argument if
+  // provided, or on the entire image otherwise.
+  //
+  // Only use this method when the ImageClassifier is created with the video
+  // running mode.
+  //
+  // The image can be of any size with format RGB or RGBA. It's required to
+  // provide the video frame's timestamp (in milliseconds). The input timestamps
+  // must be monotonically increasing.
+  absl::StatusOr<ClassificationResult> ClassifyForVideo(
+      mediapipe::Image image, int64 timestamp_ms,
+      std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
+
+  // Sends live image data to image classification, and the results will be
+  // available via the "result_callback" provided in the ImageClassifierOptions.
+  // Classification is performed on the region of interest specified by the
+  // `roi` argument if provided, or on the entire image otherwise.
+  //
+  // Only use this method when the ImageClassifier is created with the live
+  // stream running mode.
+  //
+  // The image can be of any size with format RGB or RGBA. It's required to
+  // provide a timestamp (in milliseconds) to indicate when the input image is
+  // sent to the image classifier. The input timestamps must be monotonically
+  // increasing.
+  //
+  // The "result_callback" provides
+  //   - The classification results as a ClassificationResult object.
+  //   - The const reference to the corresponding input image that the image
+  //     classifier runs on. Note that the const reference to the image will no
+  //     longer be valid when the callback returns. To access the image data
+  //     outside of the callback, callers need to make a copy of the image.
+  //   - The input timestamp in milliseconds.
+  absl::Status ClassifyAsync(
+      mediapipe::Image image, int64 timestamp_ms,
+      std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
+
+  // TODO: add Classify() variants taking a region of interest as
+  // additional argument.
+
+  // Shuts down the ImageClassifier when all the work is done.
+  absl::Status Close() { return runner_->Close(); }
+};
+
+}  // namespace vision
+}  // namespace tasks
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_TASKS_CC_VISION_IMAGE_CLASSIFIER_IMAGE_CLASSIFIER_H_
diff --git a/mediapipe/tasks/cc/vision/image_classification/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc
similarity index 66%
rename from mediapipe/tasks/cc/vision/image_classification/image_classifier_graph.cc
rename to mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc
index 8dd8fe530..532b7db45 100644
--- a/mediapipe/tasks/cc/vision/image_classification/image_classifier_graph.cc
+++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/components/classification_postprocessing.h" #include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/containers/classifications.pb.h" @@ -28,7 +29,7 @@ limitations under the License. #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" -#include "mediapipe/tasks/cc/vision/image_classification/image_classifier_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.pb.h" namespace mediapipe { namespace tasks { @@ -41,13 +42,23 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; +using ImageClassifierOptionsProto = + image_classifier::proto::ImageClassifierOptions; constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kImageTag[] = "IMAGE"; +constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kTensorsTag[] = "TENSORS"; +// Struct holding the different output streams produced by the image classifier +// subgraph. +struct ImageClassifierOutputStreams { + Source classification_result; + Source image; +}; + } // namespace // A "mediapipe.tasks.vision.ImageClassifierGraph" performs image @@ -57,19 +68,30 @@ constexpr char kTensorsTag[] = "TENSORS"; // Inputs: // IMAGE - Image // Image to perform classification on. -// +// NORM_RECT - NormalizedRect @Optional +// Describes region of image to perform classification on. +// @Optional: rect covering the whole image is used if not specified. // Outputs: // CLASSIFICATION_RESULT - ClassificationResult // The aggregated classification result object has two dimensions: // (classification head, classification category) +// IMAGE - Image +// The image that object detection runs on. 
// // Example: // node { // calculator: "mediapipe.tasks.vision.ImageClassifierGraph" // input_stream: "IMAGE:image_in" // output_stream: "CLASSIFICATION_RESULT:classification_result_out" +// output_stream: "IMAGE:image_out" // options { -// [mediapipe.tasks.vision.ImageClassifierOptions.ext] { +// [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierOptions.ext] +// { +// base_options { +// model_asset { +// file_name: "/path/to/model.tflite" +// } +// } // max_results: 3 // score_threshold: 0.5 // category_allowlist: "foo" @@ -83,15 +105,17 @@ class ImageClassifierGraph : public core::ModelTaskGraph { absl::StatusOr GetConfig( SubgraphContext* sc) override { ASSIGN_OR_RETURN(const auto* model_resources, - CreateModelResources(sc)); + CreateModelResources(sc)); Graph graph; ASSIGN_OR_RETURN( - auto classification_result_out, - BuildImageClassificationTask(sc->Options(), - *model_resources, - graph[Input(kImageTag)], graph)); - classification_result_out >> + auto output_streams, + BuildImageClassificationTask( + sc->Options(), *model_resources, + graph[Input(kImageTag)], + graph[Input::Optional(kNormRectTag)], graph)); + output_streams.classification_result >> graph[Output(kClassificationResultTag)]; + output_streams.image >> graph[Output(kImageTag)]; return graph.GetConfig(); } @@ -104,38 +128,44 @@ class ImageClassifierGraph : public core::ModelTaskGraph { // task_options: the mediapipe tasks ImageClassifierOptions. // model_resources: the ModelSources object initialized from an image // classification model file with model metadata. - // image_in: (mediapipe::Image) stream to run object detection on. + // image_in: (mediapipe::Image) stream to run classification on. // graph: the mediapipe builder::Graph instance to be updated. - absl::StatusOr> BuildImageClassificationTask( - const ImageClassifierOptions& task_options, + absl::StatusOr BuildImageClassificationTask( + const ImageClassifierOptionsProto& task_options, const core::ModelResources& model_resources, Source image_in, - Graph& graph) { + Source norm_rect_in, Graph& graph) { // Adds preprocessing calculators and connects them to the graph input image // stream. auto& preprocessing = - graph.AddNode("mediapipe.tasks.ImagePreprocessingSubgraph"); + graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( model_resources, - &preprocessing.GetOptions())); + &preprocessing + .GetOptions())); image_in >> preprocessing.In(kImageTag); + norm_rect_in >> preprocessing.In(kNormRectTag); // Adds inference subgraph and connects its input stream to the outoput // tensors produced by the ImageToTensorCalculator. - auto& inference = AddInference(model_resources, graph); + auto& inference = AddInference( + model_resources, task_options.base_options().acceleration(), graph); preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag); // Adds postprocessing calculators and connects them to the graph output. 
- auto& postprocessing = - graph.AddNode("mediapipe.tasks.ClassificationPostprocessingSubgraph"); + auto& postprocessing = graph.AddNode( + "mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( model_resources, task_options.classifier_options(), - &postprocessing.GetOptions())); + &postprocessing.GetOptions< + tasks::components::ClassificationPostprocessingOptions>())); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); // Outputs the aggregated classification result as the subgraph output // stream. - return postprocessing[Output( - kClassificationResultTag)]; + return ImageClassifierOutputStreams{ + /*classification_result=*/postprocessing[Output( + kClassificationResultTag)], + /*image=*/preprocessing[Output(kImageTag)]}; } }; REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::ImageClassifierGraph); diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc new file mode 100644 index 000000000..7cf6414bf --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc @@ -0,0 +1,819 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
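[Editor's note] The graph code above relies on the api2 builder's stream-connection operator. The following stripped-down sketch shows the pattern in isolation; "SomeCalculator" and the "FOO"/"BAR" tags are illustrative placeholders, not names introduced by this change.

// Illustrative sketch of the api2 builder pattern used by the task graphs in
// this change: add a node, then wire graph inputs/outputs with operator>>.
#include "mediapipe/framework/api2/builder.h"

mediapipe::CalculatorGraphConfig BuildToyGraph() {
  using ::mediapipe::api2::builder::Graph;
  Graph graph;
  auto& node = graph.AddNode("SomeCalculator");  // placeholder calculator name
  // Connect the graph-level "FOO" input to the node, and the node's "BAR"
  // output back out of the graph.
  graph.In("FOO") >> node.In("FOO");
  node.Out("BAR") >> graph.Out("BAR");
  return graph.GetConfig();
}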
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/image_classifier/image_classifier.h" + +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/category.pb.h" +#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/vision/core/running_mode.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/kernels/builtin_op_kernels.h" +#include "tensorflow/lite/mutable_op_resolver.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace { + +using ::mediapipe::file::JoinPath; +using ::testing::HasSubstr; +using ::testing::Optional; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kMobileNetFloatWithMetadata[] = "mobilenet_v2_1.0_224.tflite"; +constexpr char kMobileNetQuantizedWithMetadata[] = + "mobilenet_v1_0.25_224_quant.tflite"; +constexpr char kMobileNetQuantizedWithDummyScoreCalibration[] = + "mobilenet_v1_0.25_224_quant_with_dummy_score_calibration.tflite"; + +// Checks that the two provided `ClassificationResult` are equal, with a +// tolerancy on floating-point score to account for numerical instabilities. +void ExpectApproximatelyEqual(const ClassificationResult& actual, + const ClassificationResult& expected) { + const float kPrecision = 1e-6; + ASSERT_EQ(actual.classifications_size(), expected.classifications_size()); + for (int i = 0; i < actual.classifications_size(); ++i) { + const Classifications& a = actual.classifications(i); + const Classifications& b = expected.classifications(i); + EXPECT_EQ(a.head_index(), b.head_index()); + EXPECT_EQ(a.head_name(), b.head_name()); + EXPECT_EQ(a.entries_size(), b.entries_size()); + for (int j = 0; j < a.entries_size(); ++j) { + const ClassificationEntry& x = a.entries(j); + const ClassificationEntry& y = b.entries(j); + EXPECT_EQ(x.timestamp_ms(), y.timestamp_ms()); + EXPECT_EQ(x.categories_size(), y.categories_size()); + for (int k = 0; k < x.categories_size(); ++k) { + EXPECT_EQ(x.categories(k).index(), y.categories(k).index()); + EXPECT_EQ(x.categories(k).category_name(), + y.categories(k).category_name()); + EXPECT_EQ(x.categories(k).display_name(), + y.categories(k).display_name()); + EXPECT_NEAR(x.categories(k).score(), y.categories(k).score(), + kPrecision); + } + } + } +} + +// Generates expected results for "burger.jpg" using kMobileNetFloatWithMetadata +// with max_results set to 3. 
+ClassificationResult GenerateBurgerResults(int64 timestamp) { + return ParseTextProtoOrDie( + absl::StrFormat(R"pb(classifications { + entries { + categories { + index: 934 + score: 0.7939592 + category_name: "cheeseburger" + } + categories { + index: 932 + score: 0.027392805 + category_name: "bagel" + } + categories { + index: 925 + score: 0.019340655 + category_name: "guacamole" + } + timestamp_ms: %d + } + head_index: 0 + head_name: "probability" + })pb", + timestamp)); +} + +// Generates expected results for "multi_objects.jpg" using +// kMobileNetFloatWithMetadata with max_results set to 1 and the right bounding +// box set around the soccer ball. +ClassificationResult GenerateSoccerBallResults(int64 timestamp) { + return ParseTextProtoOrDie( + absl::StrFormat(R"pb(classifications { + entries { + categories { + index: 806 + score: 0.996527493 + category_name: "soccer ball" + } + timestamp_ms: %d + } + head_index: 0 + head_name: "probability" + })pb", + timestamp)); +} + +// A custom OpResolver only containing the Ops required by the test model. +class MobileNetQuantizedOpResolver : public ::tflite::MutableOpResolver { + public: + MobileNetQuantizedOpResolver() { + AddBuiltin(::tflite::BuiltinOperator_AVERAGE_POOL_2D, + ::tflite::ops::builtin::Register_AVERAGE_POOL_2D()); + AddBuiltin(::tflite::BuiltinOperator_CONV_2D, + ::tflite::ops::builtin::Register_CONV_2D()); + AddBuiltin(::tflite::BuiltinOperator_DEPTHWISE_CONV_2D, + ::tflite::ops::builtin::Register_DEPTHWISE_CONV_2D()); + AddBuiltin(::tflite::BuiltinOperator_RESHAPE, + ::tflite::ops::builtin::Register_RESHAPE()); + AddBuiltin(::tflite::BuiltinOperator_SOFTMAX, + ::tflite::ops::builtin::Register_SOFTMAX()); + } + + MobileNetQuantizedOpResolver(const MobileNetQuantizedOpResolver& r) = delete; +}; + +// A custom OpResolver missing Ops required by the test model. +class MobileNetQuantizedOpResolverMissingOps + : public ::tflite::MutableOpResolver { + public: + MobileNetQuantizedOpResolverMissingOps() { + AddBuiltin(::tflite::BuiltinOperator_SOFTMAX, + ::tflite::ops::builtin::Register_SOFTMAX()); + } + + MobileNetQuantizedOpResolverMissingOps( + const MobileNetQuantizedOpResolverMissingOps& r) = delete; +}; + +class CreateTest : public tflite_shims::testing::Test {}; + +TEST_F(CreateTest, SucceedsWithSelectiveOpResolver) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetQuantizedWithMetadata); + options->base_options.op_resolver = + std::make_unique(); + + MP_ASSERT_OK(ImageClassifier::Create(std::move(options))); +} + +TEST_F(CreateTest, FailsWithSelectiveOpResolverMissingOps) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetQuantizedWithMetadata); + options->base_options.op_resolver = + std::make_unique(); + + auto image_classifier = ImageClassifier::Create(std::move(options)); + + // TODO: Make MediaPipe InferenceCalculator report the detailed + // interpreter errors (e.g., "Encountered unresolved custom op"). 
+ EXPECT_EQ(image_classifier.status().code(), absl::StatusCode::kInternal); + EXPECT_THAT(image_classifier.status().message(), + HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk")); +} +TEST_F(CreateTest, FailsWithMissingModel) { + auto image_classifier = + ImageClassifier::Create(std::make_unique()); + + EXPECT_EQ(image_classifier.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT( + image_classifier.status().message(), + HasSubstr("ExternalFile must specify at least one of 'file_content', " + "'file_name' or 'file_descriptor_meta'.")); + EXPECT_THAT(image_classifier.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerInitializationError)))); +} + +TEST_F(CreateTest, FailsWithInvalidMaxResults) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetQuantizedWithMetadata); + options->classifier_options.max_results = 0; + + auto image_classifier = ImageClassifier::Create(std::move(options)); + + EXPECT_EQ(image_classifier.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(image_classifier.status().message(), + HasSubstr("Invalid `max_results` option")); + EXPECT_THAT(image_classifier.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerInitializationError)))); +} + +TEST_F(CreateTest, FailsWithCombinedAllowlistAndDenylist) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetQuantizedWithMetadata); + options->classifier_options.category_allowlist = {"foo"}; + options->classifier_options.category_denylist = {"bar"}; + + auto image_classifier = ImageClassifier::Create(std::move(options)); + + EXPECT_EQ(image_classifier.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(image_classifier.status().message(), + HasSubstr("mutually exclusive options")); + EXPECT_THAT(image_classifier.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerInitializationError)))); +} + +TEST_F(CreateTest, FailsWithIllegalCallbackInImageOrVideoMode) { + for (auto running_mode : + {core::RunningMode::IMAGE, core::RunningMode::VIDEO}) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetQuantizedWithMetadata); + options->running_mode = running_mode; + options->result_callback = [](absl::StatusOr, + const Image& image, int64 timestamp_ms) {}; + + auto image_classifier = ImageClassifier::Create(std::move(options)); + + EXPECT_EQ(image_classifier.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT( + image_classifier.status().message(), + HasSubstr("a user-defined result callback shouldn't be provided")); + EXPECT_THAT(image_classifier.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kInvalidTaskGraphConfigError)))); + } +} + +TEST_F(CreateTest, FailsWithMissingCallbackInLiveStreamMode) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetQuantizedWithMetadata); + options->running_mode = core::RunningMode::LIVE_STREAM; + + auto image_classifier = ImageClassifier::Create(std::move(options)); + + EXPECT_EQ(image_classifier.status().code(), + absl::StatusCode::kInvalidArgument); + 
EXPECT_THAT(image_classifier.status().message(), + HasSubstr("a user-defined result callback must be provided")); + EXPECT_THAT(image_classifier.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kInvalidTaskGraphConfigError)))); +} + +class ImageModeTest : public tflite_shims::testing::Test {}; + +TEST_F(ImageModeTest, FailsWithCallingWrongMethod) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + + auto results = image_classifier->ClassifyForVideo(image, 0); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the video mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + + results = image_classifier->ClassifyAsync(image, 0); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the live stream mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + MP_ASSERT_OK(image_classifier->Close()); +} + +TEST_F(ImageModeTest, SucceedsWithFloatModel) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + options->classifier_options.max_results = 3; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + + MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); + + ExpectApproximatelyEqual(results, GenerateBurgerResults(0)); +} + +TEST_F(ImageModeTest, SucceedsWithQuantizedModel) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetQuantizedWithMetadata); + // Due to quantization, multiple results beyond top-1 have the exact same + // score. This leads to unstability in results ordering, so we only ask for + // top-1 here. 
+ options->classifier_options.max_results = 1; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + + MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); + + ExpectApproximatelyEqual(results, ParseTextProtoOrDie( + R"pb(classifications { + entries { + categories { + index: 934 + score: 0.97265625 + category_name: "cheeseburger" + } + timestamp_ms: 0 + } + head_index: 0 + head_name: "probability" + })pb")); +} + +TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + options->classifier_options.max_results = 1; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + + MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); + + ExpectApproximatelyEqual(results, ParseTextProtoOrDie( + R"pb(classifications { + entries { + categories { + index: 934 + score: 0.7939592 + category_name: "cheeseburger" + } + timestamp_ms: 0 + } + head_index: 0 + head_name: "probability" + })pb")); +} + +TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + options->classifier_options.score_threshold = 0.02; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + + MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); + + ExpectApproximatelyEqual(results, ParseTextProtoOrDie( + R"pb(classifications { + entries { + categories { + index: 934 + score: 0.7939592 + category_name: "cheeseburger" + } + categories { + index: 932 + score: 0.027392805 + category_name: "bagel" + } + timestamp_ms: 0 + } + head_index: 0 + head_name: "probability" + })pb")); +} + +TEST_F(ImageModeTest, SucceedsWithAllowlistOption) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + options->classifier_options.category_allowlist = {"cheeseburger", "guacamole", + "meat loaf"}; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + + MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); + + ExpectApproximatelyEqual(results, ParseTextProtoOrDie( + R"pb(classifications { + entries { + categories { + index: 934 + score: 0.7939592 + category_name: "cheeseburger" + } + categories { + index: 925 + score: 0.019340655 + category_name: "guacamole" + } + categories { + index: 963 + score: 0.0063278517 + category_name: "meat loaf" + } + timestamp_ms: 0 + } + head_index: 0 + head_name: "probability" + })pb")); +} + +TEST_F(ImageModeTest, SucceedsWithDenylistOption) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + 
options->classifier_options.max_results = 3; + options->classifier_options.category_denylist = {"bagel"}; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + + MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); + + ExpectApproximatelyEqual(results, ParseTextProtoOrDie( + R"pb(classifications { + entries { + categories { + index: 934 + score: 0.7939592 + category_name: "cheeseburger" + } + categories { + index: 925 + score: 0.019340655 + category_name: "guacamole" + } + categories { + index: 963 + score: 0.0063278517 + category_name: "meat loaf" + } + timestamp_ms: 0 + } + head_index: 0 + head_name: "probability" + })pb")); +} + +TEST_F(ImageModeTest, SucceedsWithScoreCalibration) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = JoinPath( + "./", kTestDataDirectory, kMobileNetQuantizedWithDummyScoreCalibration); + // Due to quantization, multiple results beyond top-1 have the exact same + // score. This leads to unstability in results ordering, so we only ask for + // top-1 here. + options->classifier_options.max_results = 1; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + + MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); + + ExpectApproximatelyEqual(results, ParseTextProtoOrDie( + R"pb(classifications { + entries { + categories { + index: 934 + score: 0.725648628 + category_name: "cheeseburger" + } + timestamp_ms: 0 + } + head_index: 0 + head_name: "probability" + })pb")); +} + +TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { + MP_ASSERT_OK_AND_ASSIGN(Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "multi_objects.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + options->classifier_options.max_results = 1; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + // NormalizedRect around the soccer ball. 
+ NormalizedRect roi; + roi.set_x_center(0.532); + roi.set_y_center(0.521); + roi.set_width(0.164); + roi.set_height(0.427); + + MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image, roi)); + + ExpectApproximatelyEqual(results, GenerateSoccerBallResults(0)); +} + +class VideoModeTest : public tflite_shims::testing::Test {}; + +TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + options->running_mode = core::RunningMode::VIDEO; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + + auto results = image_classifier->Classify(image); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the image mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + + results = image_classifier->ClassifyAsync(image, 0); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the live stream mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + MP_ASSERT_OK(image_classifier->Close()); +} + +TEST_F(VideoModeTest, FailsWithOutOfOrderInputTimestamps) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + options->running_mode = core::RunningMode::VIDEO; + options->classifier_options.max_results = 3; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + + MP_ASSERT_OK(image_classifier->ClassifyForVideo(image, 1)); + auto results = image_classifier->ClassifyForVideo(image, 0); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("timestamp must be monotonically increasing")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerInvalidTimestampError)))); + MP_ASSERT_OK(image_classifier->ClassifyForVideo(image, 2)); + MP_ASSERT_OK(image_classifier->Close()); +} + +TEST_F(VideoModeTest, Succeeds) { + int iterations = 100; + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + options->running_mode = core::RunningMode::VIDEO; + options->classifier_options.max_results = 3; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + + for (int i = 0; i < iterations; ++i) { + MP_ASSERT_OK_AND_ASSIGN(auto results, + image_classifier->ClassifyForVideo(image, i)); + ExpectApproximatelyEqual(results, GenerateBurgerResults(i)); + } + MP_ASSERT_OK(image_classifier->Close()); +} + 
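[Editor's note] The region-of-interest tests here pass hand-tuned normalized coordinates. As a reference for reviewers, the hypothetical helper below shows how a pixel-space box would map to the center/size NormalizedRect format these tests use; it is not part of the change.

// Hypothetical helper: converts a pixel-space box into the NormalizedRect
// (center/size, normalized to [0, 1]) expected by the `roi` arguments.
mediapipe::NormalizedRect PixelBoxToNormRect(int left, int top, int right,
                                             int bottom, int image_width,
                                             int image_height) {
  mediapipe::NormalizedRect roi;
  roi.set_x_center((left + right) / 2.0f / image_width);
  roi.set_y_center((top + bottom) / 2.0f / image_height);
  roi.set_width((right - left) / static_cast<float>(image_width));
  roi.set_height((bottom - top) / static_cast<float>(image_height));
  return roi;
}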
+TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) { + int iterations = 100; + MP_ASSERT_OK_AND_ASSIGN(Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "multi_objects.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + options->running_mode = core::RunningMode::VIDEO; + options->classifier_options.max_results = 1; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + // NormalizedRect around the soccer ball. + NormalizedRect roi; + roi.set_x_center(0.532); + roi.set_y_center(0.521); + roi.set_width(0.164); + roi.set_height(0.427); + + for (int i = 0; i < iterations; ++i) { + MP_ASSERT_OK_AND_ASSIGN(auto results, + image_classifier->ClassifyForVideo(image, i, roi)); + ExpectApproximatelyEqual(results, GenerateSoccerBallResults(i)); + } + MP_ASSERT_OK(image_classifier->Close()); +} + +class LiveStreamModeTest : public tflite_shims::testing::Test {}; + +TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + options->running_mode = core::RunningMode::LIVE_STREAM; + options->result_callback = [](absl::StatusOr, + const Image& image, int64 timestamp_ms) {}; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + + auto results = image_classifier->Classify(image); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the image mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + + results = image_classifier->ClassifyForVideo(image, 0); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the video mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + MP_ASSERT_OK(image_classifier->Close()); +} + +TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + options->running_mode = core::RunningMode::LIVE_STREAM; + options->result_callback = [](absl::StatusOr, + const Image& image, int64 timestamp_ms) {}; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + + MP_ASSERT_OK(image_classifier->ClassifyAsync(image, 1)); + auto status = image_classifier->ClassifyAsync(image, 0); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), + HasSubstr("timestamp must be monotonically increasing")); + EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerInvalidTimestampError)))); + MP_ASSERT_OK(image_classifier->ClassifyAsync(image, 2)); + 
MP_ASSERT_OK(image_classifier->Close()); +} + +struct LiveStreamModeResults { + ClassificationResult classification_result; + std::pair image_size; + int64 timestamp_ms; +}; + +TEST_F(LiveStreamModeTest, Succeeds) { + int iterations = 100; + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + std::vector results; + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + options->running_mode = core::RunningMode::LIVE_STREAM; + options->classifier_options.max_results = 3; + options->result_callback = + [&results](absl::StatusOr classification_result, + const Image& image, int64 timestamp_ms) { + MP_ASSERT_OK(classification_result.status()); + results.push_back( + {.classification_result = std::move(classification_result).value(), + .image_size = {image.width(), image.height()}, + .timestamp_ms = timestamp_ms}); + }; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + + for (int i = 0; i < iterations; ++i) { + MP_ASSERT_OK(image_classifier->ClassifyAsync(image, i)); + } + MP_ASSERT_OK(image_classifier->Close()); + + // Due to the flow limiter, the total of outputs will be smaller than the + // number of iterations. + ASSERT_LE(results.size(), iterations); + ASSERT_GT(results.size(), 0); + int64 timestamp_ms = -1; + for (const auto& result : results) { + EXPECT_GT(result.timestamp_ms, timestamp_ms); + timestamp_ms = result.timestamp_ms; + EXPECT_EQ(result.image_size.first, image.width()); + EXPECT_EQ(result.image_size.second, image.height()); + ExpectApproximatelyEqual(result.classification_result, + GenerateBurgerResults(timestamp_ms)); + } +} + +TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) { + int iterations = 100; + MP_ASSERT_OK_AND_ASSIGN(Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "multi_objects.jpg"))); + std::vector results; + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + options->running_mode = core::RunningMode::LIVE_STREAM; + options->classifier_options.max_results = 1; + options->result_callback = + [&results](absl::StatusOr classification_result, + const Image& image, int64 timestamp_ms) { + MP_ASSERT_OK(classification_result.status()); + results.push_back( + {.classification_result = std::move(classification_result).value(), + .image_size = {image.width(), image.height()}, + .timestamp_ms = timestamp_ms}); + }; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + // NormalizedRect around the soccer ball. + NormalizedRect roi; + roi.set_x_center(0.532); + roi.set_y_center(0.521); + roi.set_width(0.164); + roi.set_height(0.427); + + for (int i = 0; i < iterations; ++i) { + MP_ASSERT_OK(image_classifier->ClassifyAsync(image, i, roi)); + } + MP_ASSERT_OK(image_classifier->Close()); + + // Due to the flow limiter, the total of outputs will be smaller than the + // number of iterations. 
+ ASSERT_LE(results.size(), iterations); + ASSERT_GT(results.size(), 0); + int64 timestamp_ms = -1; + for (const auto& result : results) { + EXPECT_GT(result.timestamp_ms, timestamp_ms); + timestamp_ms = result.timestamp_ms; + EXPECT_EQ(result.image_size.first, image.width()); + EXPECT_EQ(result.image_size.second, image.height()); + ExpectApproximatelyEqual(result.classification_result, + GenerateSoccerBallResults(timestamp_ms)); + } +} + +} // namespace +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD b/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD new file mode 100644 index 000000000..dc8241799 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD @@ -0,0 +1,30 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_proto_library( + name = "image_classifier_options_proto", + srcs = ["image_classifier_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/components/proto:classifier_options_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + ], +) diff --git a/mediapipe/tasks/cc/vision/image_classification/image_classifier_options.proto b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.proto similarity index 86% rename from mediapipe/tasks/cc/vision/image_classification/image_classifier_options.proto rename to mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.proto index 21fb3cd8c..8aa8b4615 100644 --- a/mediapipe/tasks/cc/vision/image_classification/image_classifier_options.proto +++ b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.proto @@ -15,10 +15,10 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks; +package mediapipe.tasks.vision.image_classifier.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/classifier_options.proto"; +import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; message ImageClassifierOptions { @@ -31,5 +31,5 @@ message ImageClassifierOptions { // Options for configuring the classifier behavior, such as score threshold, // number of results, etc. - optional ClassifierOptions classifier_options = 2; + optional components.proto.ClassifierOptions classifier_options = 2; } diff --git a/mediapipe/tasks/cc/vision/image_embedder/BUILD b/mediapipe/tasks/cc/vision/image_embedder/BUILD new file mode 100644 index 000000000..e619b8d1b --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_embedder/BUILD @@ -0,0 +1,68 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "image_embedder_graph", + srcs = ["image_embedder_graph.cc"], + deps = [ + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc/components:embedding_postprocessing_graph", + "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", + "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/proto:embedding_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + +cc_library( + name = "image_embedder", + srcs = ["image_embedder.cc"], + hdrs = ["image_embedder.h"], + deps = [ + ":image_embedder_graph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/tool:options_map", + "//mediapipe/tasks/cc/components:embedder_options", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto", + "//mediapipe/tasks/cc/components/utils:cosine_similarity", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], +) + +# TODO: This test fails in OSS diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc new file mode 100644 index 000000000..24fd2862c --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc @@ -0,0 +1,218 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/image_embedder/image_embedder.h" + +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/tool/options_map.h" +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" +#include "mediapipe/tasks/cc/components/embedder_options.h" +#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" +#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/core/running_mode.h" +#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" +#include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace image_embedder { + +namespace { + +constexpr char kEmbeddingResultStreamName[] = "embedding_result_out"; +constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT"; +constexpr char kImageInStreamName[] = "image_in"; +constexpr char kImageOutStreamName[] = "image_out"; +constexpr char kImageTag[] = "IMAGE"; +constexpr char kNormRectStreamName[] = "norm_rect_in"; +constexpr char kNormRectTag[] = "NORM_RECT"; + +constexpr char kGraphTypeName[] = + "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph"; +constexpr int kMicroSecondsPerMilliSecond = 1000; + +using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry; +using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; +using ::mediapipe::tasks::core::PacketMap; +using ::mediapipe::tasks::vision::image_embedder::proto:: + ImageEmbedderGraphOptions; + +// Builds a NormalizedRect covering the entire image. +NormalizedRect BuildFullImageNormRect() { + NormalizedRect norm_rect; + norm_rect.set_x_center(0.5); + norm_rect.set_y_center(0.5); + norm_rect.set_width(1); + norm_rect.set_height(1); + return norm_rect; +} + +// Creates a MediaPipe graph config that contains a single node of type +// "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph". If the task is +// running in the live stream mode, a "FlowLimiterCalculator" will be added to +// limit the number of frames in flight. 
+CalculatorGraphConfig CreateGraphConfig( + std::unique_ptr options_proto, + bool enable_flow_limiting) { + api2::builder::Graph graph; + graph.In(kImageTag).SetName(kImageInStreamName); + graph.In(kNormRectTag).SetName(kNormRectStreamName); + auto& task_graph = graph.AddNode(kGraphTypeName); + task_graph.GetOptions().Swap(options_proto.get()); + task_graph.Out(kEmbeddingResultTag).SetName(kEmbeddingResultStreamName) >> + graph.Out(kEmbeddingResultTag); + task_graph.Out(kImageTag).SetName(kImageOutStreamName) >> + graph.Out(kImageTag); + if (enable_flow_limiting) { + return tasks::core::AddFlowLimiterCalculator( + graph, task_graph, {kImageTag, kNormRectTag}, kEmbeddingResultTag); + } + graph.In(kImageTag) >> task_graph.In(kImageTag); + graph.In(kNormRectTag) >> task_graph.In(kNormRectTag); + return graph.GetConfig(); +} + +// Converts the user-facing ImageEmbedderOptions struct to the internal +// ImageEmbedderGraphOptions proto. +std::unique_ptr ConvertImageEmbedderOptionsToProto( + ImageEmbedderOptions* options) { + auto options_proto = std::make_unique(); + auto base_options_proto = std::make_unique( + tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); + options_proto->mutable_base_options()->Swap(base_options_proto.get()); + options_proto->mutable_base_options()->set_use_stream_mode( + options->running_mode != core::RunningMode::IMAGE); + auto embedder_options_proto = + std::make_unique( + components::ConvertEmbedderOptionsToProto( + &(options->embedder_options))); + options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get()); + return options_proto; +} + +} // namespace + +absl::StatusOr> ImageEmbedder::Create( + std::unique_ptr options) { + auto options_proto = ConvertImageEmbedderOptionsToProto(options.get()); + tasks::core::PacketsCallback packets_callback = nullptr; + if (options->result_callback) { + auto result_callback = options->result_callback; + packets_callback = + [=](absl::StatusOr status_or_packets) { + if (!status_or_packets.ok()) { + Image image; + result_callback(status_or_packets.status(), image, + Timestamp::Unset().Value()); + } + if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { + return; + } + Packet embedding_result_packet = + status_or_packets.value()[kEmbeddingResultStreamName]; + Packet image_packet = status_or_packets.value()[kImageOutStreamName]; + result_callback(embedding_result_packet.Get(), + image_packet.Get(), + embedding_result_packet.Timestamp().Value() / + kMicroSecondsPerMilliSecond); + }; + } + return core::VisionTaskApiFactory::Create( + CreateGraphConfig( + std::move(options_proto), + options->running_mode == core::RunningMode::LIVE_STREAM), + std::move(options->base_options.op_resolver), options->running_mode, + std::move(packets_callback)); +} + +absl::StatusOr ImageEmbedder::Embed( + Image image, std::optional roi) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "GPU input images are currently not supported.", + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + NormalizedRect norm_rect = + roi.has_value() ? 
roi.value() : BuildFullImageNormRect(); + ASSIGN_OR_RETURN( + auto output_packets, + ProcessImageData( + {{kImageInStreamName, MakePacket(std::move(image))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect))}})); + return output_packets[kEmbeddingResultStreamName].Get(); +} + +absl::StatusOr ImageEmbedder::EmbedForVideo( + Image image, int64 timestamp_ms, std::optional roi) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "GPU input images are currently not supported.", + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + NormalizedRect norm_rect = + roi.has_value() ? roi.value() : BuildFullImageNormRect(); + ASSIGN_OR_RETURN( + auto output_packets, + ProcessVideoData( + {{kImageInStreamName, + MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); + return output_packets[kEmbeddingResultStreamName].Get(); +} + +absl::Status ImageEmbedder::EmbedAsync(Image image, int64 timestamp_ms, + std::optional roi) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "GPU input images are currently not supported.", + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + NormalizedRect norm_rect = + roi.has_value() ? roi.value() : BuildFullImageNormRect(); + return SendLiveStreamData( + {{kImageInStreamName, + MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); +} + +absl::StatusOr ImageEmbedder::CosineSimilarity( + const EmbeddingEntry& u, const EmbeddingEntry& v) { + return components::utils::CosineSimilarity(u, v); +} + +} // namespace image_embedder +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h new file mode 100644 index 000000000..13f4702d1 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h @@ -0,0 +1,161 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
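[Editor's note] The video-mode methods above take millisecond timestamps and convert them internally to microsecond MediaPipe timestamps via kMicroSecondsPerMilliSecond. A hypothetical caller loop deriving monotonically increasing millisecond timestamps from a frame index could look like the sketch below; DecodeFrame() is a placeholder for the caller's own decoding, not an API in this patch.

// Hypothetical sketch: feeding decoded frames to EmbedForVideo with
// monotonically increasing millisecond timestamps.
absl::Status EmbedVideo(
    mediapipe::tasks::vision::image_embedder::ImageEmbedder& embedder,
    int num_frames, double fps) {
  for (int i = 0; i < num_frames; ++i) {
    mediapipe::Image frame = DecodeFrame(i);  // placeholder decoder
    int64 timestamp_ms = static_cast<int64>(i * 1000.0 / fps);
    ASSIGN_OR_RETURN(auto result, embedder.EmbedForVideo(frame, timestamp_ms));
    // ... consume `result` (an EmbeddingResult proto) ...
  }
  return embedder.Close();
}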
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_EMBEDDER_IMAGE_EMBEDDER_H_ +#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_EMBEDDER_IMAGE_EMBEDDER_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" +#include "mediapipe/tasks/cc/components/embedder_options.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/running_mode.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace image_embedder { + +// The options for configuring a MediaPipe image embedder task. +struct ImageEmbedderOptions { + // Base options for configuring MediaPipe Tasks, such as specifying the model + // file with metadata, accelerator options, op resolver, etc. + tasks::core::BaseOptions base_options; + + // The running mode of the task. Default to the image mode. + // Image embedder has three running modes: + // 1) The image mode for embedding image on single image inputs. + // 2) The video mode for embedding image on the decoded frames of a video. + // 3) The live stream mode for embedding image on the live stream of input + // data, such as from camera. In this mode, the "result_callback" below must + // be specified to receive the embedding results asynchronously. + core::RunningMode running_mode = core::RunningMode::IMAGE; + + // Options for configuring the embedder behavior, such as L2-normalization or + // scalar-quantization. + components::EmbedderOptions embedder_options; + + // The user-defined result callback for processing live stream data. + // The result callback should only be specified when the running mode is set + // to RunningMode::LIVE_STREAM. + std::function, + const Image&, int64)> + result_callback = nullptr; +}; + +// Performs embedding extraction on images. +// +// The API expects a TFLite model with optional, but strongly recommended, +// TFLite Model Metadata. +// +// Input tensor: +// (kTfLiteUInt8/kTfLiteFloat32) +// - image input of size `[batch x height x width x channels]`. +// - batch inference is not supported (`batch` is required to be 1). +// - only RGB inputs are supported (`channels` is required to be 3). +// - if type is kTfLiteFloat32, NormalizationOptions are required to be +// attached to the metadata for input normalization. +// At least one output tensor with: +// (kTfLiteUInt8/kTfLiteFloat32) +// - `N` components corresponding to the `N` dimensions of the returned +// feature vector for this output layer. +// - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`. +class ImageEmbedder : core::BaseVisionTaskApi { + public: + using BaseVisionTaskApi::BaseVisionTaskApi; + + // Creates an ImageEmbedder from the provided options. A non-default + // OpResolver can be specified in the BaseOptions in order to support custom + // Ops or specify a subset of built-in Ops. + static absl::StatusOr> Create( + std::unique_ptr options); + + // Performs embedding extraction on the provided single image. Extraction + // is performed on the region of interest specified by the `roi` argument if + // provided, or on the entire image otherwise. + // + // Only use this method when the ImageEmbedder is created with the image + // running mode. + // + // The image can be of any size with format RGB or RGBA. 
+  absl::StatusOr<components::containers::proto::EmbeddingResult> Embed(
+      mediapipe::Image image,
+      std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
+
+  // Performs embedding extraction on the provided video frame. Extraction
+  // is performed on the region of interest specified by the `roi` argument if
+  // provided, or on the entire image otherwise.
+  //
+  // Only use this method when the ImageEmbedder is created with the video
+  // running mode.
+  //
+  // The image can be of any size with format RGB or RGBA. It's required to
+  // provide the video frame's timestamp (in milliseconds). The input timestamps
+  // must be monotonically increasing.
+  absl::StatusOr<components::containers::proto::EmbeddingResult> EmbedForVideo(
+      mediapipe::Image image, int64 timestamp_ms,
+      std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
+
+  // Sends live image data to embedder, and the results will be available via
+  // the "result_callback" provided in the ImageEmbedderOptions. Embedding
+  // extraction is performed on the region of interest specified by the `roi`
+  // argument if provided, or on the entire image otherwise.
+  //
+  // Only use this method when the ImageEmbedder is created with the live
+  // stream running mode.
+  //
+  // The image can be of any size with format RGB or RGBA. It's required to
+  // provide a timestamp (in milliseconds) to indicate when the input image is
+  // sent to the image embedder. The input timestamps must be monotonically
+  // increasing.
+  //
+  // The "result_callback" provides
+  //   - The embedding results as a
+  //     components::containers::proto::EmbeddingResult object.
+  //   - The const reference to the corresponding input image that the image
+  //     embedder runs on. Note that the const reference to the image will no
+  //     longer be valid when the callback returns. To access the image data
+  //     outside of the callback, callers need to make a copy of the image.
+  //   - The input timestamp in milliseconds.
+  absl::Status EmbedAsync(
+      mediapipe::Image image, int64 timestamp_ms,
+      std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
+
+  // Shuts down the ImageEmbedder when all the work is done.
+  absl::Status Close() { return runner_->Close(); }
+
+  // Utility function to compute cosine similarity [1] between two embedding
+  // entries. May return an InvalidArgumentError if e.g. the feature vectors are
+  // of different types (quantized vs. float), have different sizes, or have
+  // an L2-norm of 0.
+  //
+  // [1]: https://en.wikipedia.org/wiki/Cosine_similarity
+  static absl::StatusOr<double> CosineSimilarity(
+      const components::containers::proto::EmbeddingEntry& u,
+      const components::containers::proto::EmbeddingEntry& v);
+};
+
+}  // namespace image_embedder
+}  // namespace vision
+}  // namespace tasks
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_TASKS_CC_VISION_IMAGE_EMBEDDER_IMAGE_EMBEDDER_H_
diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc
new file mode 100644
index 000000000..fff0f4366
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc
@@ -0,0 +1,172 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
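[Editor's note] A natural end-to-end use of the embedder API above is comparing two images. The hypothetical sketch below (image decoding omitted, names chosen for illustration) embeds two images in image mode and scores them with the static CosineSimilarity helper; it is not part of the patch.

// Hypothetical sketch: embed two already-decoded images and compare them with
// cosine similarity. Assumes single-head models, hence embeddings(0).
absl::StatusOr<double> CompareImages(
    std::unique_ptr<mediapipe::tasks::vision::image_embedder::ImageEmbedder>
        embedder,
    mediapipe::Image image_a, mediapipe::Image image_b) {
  ASSIGN_OR_RETURN(auto result_a, embedder->Embed(std::move(image_a)));
  ASSIGN_OR_RETURN(auto result_b, embedder->Embed(std::move(image_b)));
  ASSIGN_OR_RETURN(
      double similarity,
      mediapipe::tasks::vision::image_embedder::ImageEmbedder::CosineSimilarity(
          result_a.embeddings(0).entries(0), result_b.embeddings(0).entries(0)));
  MP_RETURN_IF_ERROR(embedder->Close());
  return similarity;
}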
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" +#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" +#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace image_embedder { + +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::GenericNode; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; +using ::mediapipe::tasks::components::proto:: + EmbeddingPostprocessingGraphOptions; + +constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT"; +constexpr char kImageTag[] = "IMAGE"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kTensorsTag[] = "TENSORS"; + +// Struct holding the different output streams produced by the image embedder +// graph. +struct ImageEmbedderOutputStreams { + Source embedding_result; + Source image; +}; + +} // namespace + +// An ImageEmbedderGraph performs image embedding extraction. +// - Accepts CPU input images and outputs embeddings on CPU. +// +// Inputs: +// IMAGE - Image +// Image to perform embedding extraction on. +// NORM_RECT - NormalizedRect @Optional +// Describes region of image to perform embedding extraction on. +// @Optional: rect covering the whole image is used if not specified. +// Outputs: +// EMBEDDING_RESULT - EmbeddingResult +// The embedding result. +// IMAGE - Image +// The image that embedding extraction runs on. +// +// Example: +// node { +// calculator: "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph" +// input_stream: "IMAGE:image_in" +// output_stream: "EMBEDDING_RESULT:embedding_result_out" +// output_stream: "IMAGE:image_out" +// options { +// [mediapipe.tasks.vision.image_embedder.proto.ImageEmbedderOptions.ext] +// { +// base_options { +// model_asset { +// file_name: "/path/to/model.tflite" +// } +// } +// embedder_options { +// l2_normalize: true +// } +// } +// } +// } +class ImageEmbedderGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + ASSIGN_OR_RETURN( + const auto* model_resources, + CreateModelResources(sc)); + Graph graph; + ASSIGN_OR_RETURN( + auto output_streams, + BuildImageEmbedderTask( + sc->Options(), *model_resources, + graph[Input(kImageTag)], + graph[Input::Optional(kNormRectTag)], graph)); + output_streams.embedding_result >> + graph[Output(kEmbeddingResultTag)]; + output_streams.image >> graph[Output(kImageTag)]; + return graph.GetConfig(); + } + + private: + // Adds a mediapipe image embedding teask graph into the provided + // builder::Graph instance. 
+ // The image embedding task takes images
+ // (mediapipe::Image) and optional region-of-interest
+ // (mediapipe::NormalizedRect) as inputs and returns one embedding result per
+ // input image.
+ //
+ // task_options: the mediapipe tasks ImageEmbedderGraphOptions.
+ // model_resources: the ModelResources object initialized from an image
+ // embedding model file with optional model metadata.
+ // image_in: (mediapipe::Image) stream to run embedding extraction on.
+ // norm_rect_in: (mediapipe::NormalizedRect) optional region-of-interest to
+ // perform embedding extraction on.
+ // graph: the mediapipe builder::Graph instance to be updated.
+ absl::StatusOr<ImageEmbedderOutputStreams> BuildImageEmbedderTask(
+ const proto::ImageEmbedderGraphOptions& task_options,
+ const core::ModelResources& model_resources, Source<Image> image_in,
+ Source<NormalizedRect> norm_rect_in, Graph& graph) {
+ // Adds preprocessing calculators and connects them to the graph input image
+ // stream.
+ auto& preprocessing =
+ graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
+ MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
+ model_resources,
+ &preprocessing
+ .GetOptions<tasks::components::ImagePreprocessingOptions>()));
+ image_in >> preprocessing.In(kImageTag);
+ norm_rect_in >> preprocessing.In(kNormRectTag);
+
+ // Adds inference subgraph and connects its input stream to the output
+ // tensors produced by the ImageToTensorCalculator.
+ auto& inference = AddInference(
+ model_resources, task_options.base_options().acceleration(), graph);
+ preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag);
+
+ // Adds postprocessing calculators and connects its input stream to the
+ // inference results.
+ auto& postprocessing = graph.AddNode(
+ "mediapipe.tasks.components.EmbeddingPostprocessingGraph");
+ MP_RETURN_IF_ERROR(components::ConfigureEmbeddingPostprocessing(
+ model_resources, task_options.embedder_options(),
+ &postprocessing.GetOptions<EmbeddingPostprocessingGraphOptions>()));
+ inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
+
+ // Outputs the embedding results.
+ return ImageEmbedderOutputStreams{
+ /*embedding_result=*/postprocessing[Output<EmbeddingResult>(
+ kEmbeddingResultTag)],
+ /*image=*/preprocessing[Output<Image>(kImageTag)]};
+ }
+};
+REGISTER_MEDIAPIPE_GRAPH(
+ ::mediapipe::tasks::vision::image_embedder::ImageEmbedderGraph);
+
+} // namespace image_embedder
+} // namespace vision
+} // namespace tasks
+} // namespace mediapipe
diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc
new file mode 100644
index 000000000..08a0d6a25
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc
@@ -0,0 +1,557 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/image_embedder/image_embedder.h" + +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" +#include "mediapipe/tasks/cc/vision/core/running_mode.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/kernels/builtin_op_kernels.h" +#include "tensorflow/lite/mutable_op_resolver.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace image_embedder { +namespace { + +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; +using ::testing::HasSubstr; +using ::testing::Optional; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kMobileNetV3Embedder[] = + "mobilenet_v3_small_100_224_embedder.tflite"; +constexpr double kSimilarityTolerancy = 1e-6; + +// Utility function to check the sizes, head_index and head_names of a result +// procuded by kMobileNetV3Embedder. +void CheckMobileNetV3Result(const EmbeddingResult& result, bool quantized) { + EXPECT_EQ(result.embeddings().size(), 1); + EXPECT_EQ(result.embeddings(0).head_index(), 0); + EXPECT_EQ(result.embeddings(0).head_name(), "feature"); + EXPECT_EQ(result.embeddings(0).entries().size(), 1); + if (quantized) { + EXPECT_EQ( + result.embeddings(0).entries(0).quantized_embedding().values().size(), + 1024); + } else { + EXPECT_EQ(result.embeddings(0).entries(0).float_embedding().values().size(), + 1024); + } +} + +// A custom OpResolver only containing the Ops required by the test model. +class MobileNetV3OpResolver : public ::tflite::MutableOpResolver { + public: + MobileNetV3OpResolver() { + AddBuiltin(::tflite::BuiltinOperator_MUL, + ::tflite::ops::builtin::Register_MUL()); + AddBuiltin(::tflite::BuiltinOperator_SUB, + ::tflite::ops::builtin::Register_SUB()); + AddBuiltin(::tflite::BuiltinOperator_CONV_2D, + ::tflite::ops::builtin::Register_CONV_2D()); + AddBuiltin(::tflite::BuiltinOperator_HARD_SWISH, + ::tflite::ops::builtin::Register_HARD_SWISH()); + AddBuiltin(::tflite::BuiltinOperator_DEPTHWISE_CONV_2D, + ::tflite::ops::builtin::Register_DEPTHWISE_CONV_2D()); + AddBuiltin(::tflite::BuiltinOperator_MEAN, + ::tflite::ops::builtin::Register_MEAN()); + AddBuiltin(::tflite::BuiltinOperator_ADD, + ::tflite::ops::builtin::Register_ADD()); + AddBuiltin(::tflite::BuiltinOperator_AVERAGE_POOL_2D, + ::tflite::ops::builtin::Register_AVERAGE_POOL_2D()); + AddBuiltin(::tflite::BuiltinOperator_RESHAPE, + ::tflite::ops::builtin::Register_RESHAPE()); + } + + MobileNetV3OpResolver(const MobileNetV3OpResolver& r) = delete; +}; + +// A custom OpResolver missing Ops required by the test model. 
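+// It is used below to verify that ImageEmbedder::Create surfaces a clear
+// initialization error when the interpreter cannot be built from the provided
+// op resolver.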
+class MobileNetV3OpResolverMissingOps : public ::tflite::MutableOpResolver { + public: + MobileNetV3OpResolverMissingOps() { + AddBuiltin(::tflite::BuiltinOperator_SOFTMAX, + ::tflite::ops::builtin::Register_SOFTMAX()); + } + + MobileNetV3OpResolverMissingOps(const MobileNetV3OpResolverMissingOps& r) = + delete; +}; + +class CreateTest : public tflite_shims::testing::Test {}; + +TEST_F(CreateTest, SucceedsWithSelectiveOpResolver) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); + options->base_options.op_resolver = std::make_unique(); + + MP_ASSERT_OK(ImageEmbedder::Create(std::move(options))); +} + +TEST_F(CreateTest, FailsWithSelectiveOpResolverMissingOps) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); + options->base_options.op_resolver = + std::make_unique(); + + auto image_embedder = ImageEmbedder::Create(std::move(options)); + + EXPECT_EQ(image_embedder.status().code(), absl::StatusCode::kInternal); + EXPECT_THAT(image_embedder.status().message(), + HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk")); +} + +TEST_F(CreateTest, FailsWithMissingModel) { + auto image_embedder = + ImageEmbedder::Create(std::make_unique()); + + EXPECT_EQ(image_embedder.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT( + image_embedder.status().message(), + HasSubstr("ExternalFile must specify at least one of 'file_content', " + "'file_name' or 'file_descriptor_meta'.")); + EXPECT_THAT(image_embedder.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerInitializationError)))); +} + +TEST_F(CreateTest, FailsWithIllegalCallbackInImageOrVideoMode) { + for (auto running_mode : + {core::RunningMode::IMAGE, core::RunningMode::VIDEO}) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); + options->running_mode = running_mode; + options->result_callback = [](absl::StatusOr, + const Image& image, int64 timestamp_ms) {}; + + auto image_embedder = ImageEmbedder::Create(std::move(options)); + + EXPECT_EQ(image_embedder.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT( + image_embedder.status().message(), + HasSubstr("a user-defined result callback shouldn't be provided")); + EXPECT_THAT(image_embedder.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kInvalidTaskGraphConfigError)))); + } +} + +TEST_F(CreateTest, FailsWithMissingCallbackInLiveStreamMode) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); + options->running_mode = core::RunningMode::LIVE_STREAM; + + auto image_embedder = ImageEmbedder::Create(std::move(options)); + + EXPECT_EQ(image_embedder.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(image_embedder.status().message(), + HasSubstr("a user-defined result callback must be provided")); + EXPECT_THAT(image_embedder.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kInvalidTaskGraphConfigError)))); +} + +class ImageModeTest : public tflite_shims::testing::Test {}; + +TEST_F(ImageModeTest, FailsWithCallingWrongMethod) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", 
kTestDataDirectory, "burger.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_embedder, + ImageEmbedder::Create(std::move(options))); + + auto results = image_embedder->EmbedForVideo(image, 0); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the video mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + + results = image_embedder->EmbedAsync(image, 0); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the live stream mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + MP_ASSERT_OK(image_embedder->Close()); +} + +TEST_F(ImageModeTest, SucceedsWithoutL2Normalization) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_embedder, + ImageEmbedder::Create(std::move(options))); + // Load images: one is a crop of the other. + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + MP_ASSERT_OK_AND_ASSIGN( + Image crop, DecodeImageFromFile( + JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); + + // Extract both embeddings. + MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, + image_embedder->Embed(image)); + MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, + image_embedder->Embed(crop)); + + // Check results. + CheckMobileNetV3Result(image_result, false); + CheckMobileNetV3Result(crop_result, false); + // CheckCosineSimilarity. + MP_ASSERT_OK_AND_ASSIGN( + double similarity, + ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), + crop_result.embeddings(0).entries(0))); + double expected_similarity = 0.925519; + EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); +} + +TEST_F(ImageModeTest, SucceedsWithL2Normalization) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); + options->embedder_options.l2_normalize = true; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_embedder, + ImageEmbedder::Create(std::move(options))); + // Load images: one is a crop of the other. + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + MP_ASSERT_OK_AND_ASSIGN( + Image crop, DecodeImageFromFile( + JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); + + // Extract both embeddings. + MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, + image_embedder->Embed(image)); + MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, + image_embedder->Embed(crop)); + + // Check results. + CheckMobileNetV3Result(image_result, false); + CheckMobileNetV3Result(crop_result, false); + // CheckCosineSimilarity. 
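+ // Cosine similarity is dot(u, v) / (||u|| * ||v||) and is scale-invariant,
+ // so enabling l2_normalize is expected to leave the similarity between the
+ // two embeddings unchanged compared to the un-normalized case above.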
+ MP_ASSERT_OK_AND_ASSIGN( + double similarity, + ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), + crop_result.embeddings(0).entries(0))); + double expected_similarity = 0.925519; + EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); +} + +TEST_F(ImageModeTest, SucceedsWithQuantization) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); + options->embedder_options.quantize = true; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_embedder, + ImageEmbedder::Create(std::move(options))); + // Load images: one is a crop of the other. + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + MP_ASSERT_OK_AND_ASSIGN( + Image crop, DecodeImageFromFile( + JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); + + // Extract both embeddings. + MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, + image_embedder->Embed(image)); + MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, + image_embedder->Embed(crop)); + + // Check results. + CheckMobileNetV3Result(image_result, true); + CheckMobileNetV3Result(crop_result, true); + // CheckCosineSimilarity. + MP_ASSERT_OK_AND_ASSIGN( + double similarity, + ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), + crop_result.embeddings(0).entries(0))); + double expected_similarity = 0.926791; + EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); +} + +TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_embedder, + ImageEmbedder::Create(std::move(options))); + // Load images: one is a crop of the other. + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + MP_ASSERT_OK_AND_ASSIGN( + Image crop, DecodeImageFromFile( + JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); + // Bounding box in "burger.jpg" corresponding to "burger_crop.jpg". + NormalizedRect roi; + roi.set_x_center(200.0 / 480); + roi.set_y_center(0.5); + roi.set_width(400.0 / 480); + roi.set_height(1.0f); + + // Extract both embeddings. + MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, + image_embedder->Embed(image, roi)); + MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, + image_embedder->Embed(crop)); + + // Check results. + CheckMobileNetV3Result(image_result, false); + CheckMobileNetV3Result(crop_result, false); + // CheckCosineSimilarity. 
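+ // The ROI defined above selects the region of "burger.jpg" that matches
+ // "burger_crop.jpg", so the two embeddings should be nearly identical and
+ // their cosine similarity should be close to 1.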
+ MP_ASSERT_OK_AND_ASSIGN( + double similarity, + ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), + crop_result.embeddings(0).entries(0))); + double expected_similarity = 0.999931; + EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); +} + +class VideoModeTest : public tflite_shims::testing::Test {}; + +TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); + options->running_mode = core::RunningMode::VIDEO; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_embedder, + ImageEmbedder::Create(std::move(options))); + + auto results = image_embedder->Embed(image); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the image mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + + results = image_embedder->EmbedAsync(image, 0); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the live stream mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + MP_ASSERT_OK(image_embedder->Close()); +} + +TEST_F(VideoModeTest, FailsWithOutOfOrderInputTimestamps) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); + options->running_mode = core::RunningMode::VIDEO; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_embedder, + ImageEmbedder::Create(std::move(options))); + + MP_ASSERT_OK(image_embedder->EmbedForVideo(image, 1)); + auto results = image_embedder->EmbedForVideo(image, 0); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("timestamp must be monotonically increasing")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerInvalidTimestampError)))); + MP_ASSERT_OK(image_embedder->EmbedForVideo(image, 2)); + MP_ASSERT_OK(image_embedder->Close()); +} + +TEST_F(VideoModeTest, Succeeds) { + int iterations = 100; + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); + options->running_mode = core::RunningMode::VIDEO; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_embedder, + ImageEmbedder::Create(std::move(options))); + + EmbeddingResult previous_results; + for (int i = 0; i < iterations; ++i) { + MP_ASSERT_OK_AND_ASSIGN(auto results, + image_embedder->EmbedForVideo(image, i)); + CheckMobileNetV3Result(results, false); + if (i > 0) { + MP_ASSERT_OK_AND_ASSIGN(double similarity, + ImageEmbedder::CosineSimilarity( + results.embeddings(0).entries(0), + previous_results.embeddings(0).entries(0))); + double expected_similarity = 
1.000000; + EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); + } + previous_results = results; + } + MP_ASSERT_OK(image_embedder->Close()); +} + +class LiveStreamModeTest : public tflite_shims::testing::Test {}; + +TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); + options->running_mode = core::RunningMode::LIVE_STREAM; + options->result_callback = [](absl::StatusOr, + const Image& image, int64 timestamp_ms) {}; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_embedder, + ImageEmbedder::Create(std::move(options))); + + auto results = image_embedder->Embed(image); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the image mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + + results = image_embedder->EmbedForVideo(image, 0); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the video mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + MP_ASSERT_OK(image_embedder->Close()); +} + +TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); + options->running_mode = core::RunningMode::LIVE_STREAM; + options->result_callback = [](absl::StatusOr, + const Image& image, int64 timestamp_ms) {}; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_embedder, + ImageEmbedder::Create(std::move(options))); + + MP_ASSERT_OK(image_embedder->EmbedAsync(image, 1)); + auto status = image_embedder->EmbedAsync(image, 0); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), + HasSubstr("timestamp must be monotonically increasing")); + EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerInvalidTimestampError)))); + MP_ASSERT_OK(image_embedder->EmbedAsync(image, 2)); + MP_ASSERT_OK(image_embedder->Close()); +} + +struct LiveStreamModeResults { + EmbeddingResult embedding_result; + std::pair image_size; + int64 timestamp_ms; +}; + +TEST_F(LiveStreamModeTest, Succeeds) { + int iterations = 100; + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + std::vector results; + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); + options->running_mode = core::RunningMode::LIVE_STREAM; + options->result_callback = + [&results](absl::StatusOr embedding_result, + const Image& image, int64 timestamp_ms) { + MP_ASSERT_OK(embedding_result.status()); + results.push_back( + {.embedding_result = std::move(embedding_result).value(), + .image_size = {image.width(), image.height()}, + .timestamp_ms = 
timestamp_ms}); + }; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_embedder, + ImageEmbedder::Create(std::move(options))); + + for (int i = 0; i < iterations; ++i) { + MP_ASSERT_OK(image_embedder->EmbedAsync(image, i)); + } + MP_ASSERT_OK(image_embedder->Close()); + + // Due to the flow limiter, the total of outputs will be smaller than the + // number of iterations. + ASSERT_LE(results.size(), iterations); + ASSERT_GT(results.size(), 0); + int64 timestamp_ms = -1; + for (int i = 0; i < results.size(); ++i) { + const auto& result = results[i]; + EXPECT_GT(result.timestamp_ms, timestamp_ms); + timestamp_ms = result.timestamp_ms; + EXPECT_EQ(result.image_size.first, image.width()); + EXPECT_EQ(result.image_size.second, image.height()); + CheckMobileNetV3Result(result.embedding_result, false); + if (i > 0) { + MP_ASSERT_OK_AND_ASSIGN( + double similarity, + ImageEmbedder::CosineSimilarity( + result.embedding_result.embeddings(0).entries(0), + results[i - 1].embedding_result.embeddings(0).entries(0))); + double expected_similarity = 1.000000; + EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); + } + } +} + +} // namespace +} // namespace image_embedder +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_embedder/proto/BUILD b/mediapipe/tasks/cc/vision/image_embedder/proto/BUILD new file mode 100644 index 000000000..83407001f --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_embedder/proto/BUILD @@ -0,0 +1,30 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_proto_library( + name = "image_embedder_graph_options_proto", + srcs = ["image_embedder_graph_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/components/proto:embedder_options_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + ], +) diff --git a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto new file mode 100644 index 000000000..e5e31a729 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto @@ -0,0 +1,35 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.vision.image_embedder.proto; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/components/proto/embedder_options.proto"; +import "mediapipe/tasks/cc/core/proto/base_options.proto"; + +message ImageEmbedderGraphOptions { + extend mediapipe.CalculatorOptions { + optional ImageEmbedderGraphOptions ext = 476348187; + } + // Base options for configuring MediaPipe Tasks, such as specifying the TfLite + // model file with metadata, accelerator options, etc. + optional core.proto.BaseOptions base_options = 1; + + // Options for configuring the embedder behavior, such as normalization or + // quantization. + optional components.proto.EmbedderOptions embedder_options = 2; +} diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index cb0482e42..6af733657 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -24,7 +24,7 @@ cc_library( ":image_segmenter_graph", "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:image", - "//mediapipe/tasks/cc/components:segmenter_options_cc_proto", + "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/vision/core:base_vision_task_api", @@ -53,11 +53,12 @@ cc_library( "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", - "//mediapipe/tasks/cc/components:segmenter_options_cc_proto", "//mediapipe/tasks/cc/components/calculators/tensor:tensors_to_segmentation_calculator", "//mediapipe/tasks/cc/components/calculators/tensor:tensors_to_segmentation_calculator_cc_proto", + "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/metadata:metadata_extractor", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index 090149d92..84ceea88a 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/formats/image.h" -#include "mediapipe/tasks/cc/components/segmenter_options.pb.h" +#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" @@ -34,9 +34,11 @@ constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageTag[] = "IMAGE"; constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.ImageSegmenterGraph"; +constexpr int kMicroSecondsPerMilliSecond = 1000; using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::Image; +using ::mediapipe::tasks::components::proto::SegmenterOptions; using ImageSegmenterOptionsProto = image_segmenter::proto::ImageSegmenterOptions; @@ -105,6 +107,28 @@ absl::StatusOr> ImageSegmenter::Create( std::unique_ptr options) { auto options_proto = ConvertImageSegmenterOptionsToProto(options.get()); tasks::core::PacketsCallback packets_callback = nullptr; + if (options->result_callback) { + auto result_callback = options->result_callback; + packets_callback = + [=](absl::StatusOr status_or_packets) { + if (!status_or_packets.ok()) { + Image image; + result_callback(status_or_packets.status(), image, + Timestamp::Unset().Value()); + return; + } + if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { + return; + } + Packet segmented_masks = + status_or_packets.value()[kSegmentationStreamName]; + Packet image_packet = status_or_packets.value()[kImageOutStreamName]; + result_callback(segmented_masks.Get>(), + image_packet.Get(), + segmented_masks.Timestamp().Value() / + kMicroSecondsPerMilliSecond); + }; + } return core::VisionTaskApiFactory::Create( CreateGraphConfig( @@ -129,6 +153,36 @@ absl::StatusOr> ImageSegmenter::Segment( return output_packets[kSegmentationStreamName].Get>(); } +absl::StatusOr> ImageSegmenter::SegmentForVideo( + mediapipe::Image image, int64 timestamp_ms) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrCat("GPU input images are currently not supported."), + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + ASSIGN_OR_RETURN( + auto output_packets, + ProcessVideoData( + {{kImageInStreamName, + MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); + return output_packets[kSegmentationStreamName].Get>(); +} + +absl::Status ImageSegmenter::SegmentAsync(Image image, int64 timestamp_ms) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrCat("GPU input images are currently not supported."), + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + return SendLiveStreamData( + {{kImageInStreamName, + MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); +} + } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index 00c63953a..ce9cb104c 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -112,8 +112,60 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { static absl::StatusOr> Create( std::unique_ptr options); - // Runs the actual segmentation task. + // Performs image segmentation on the provided single image. 
+ // Only use this method when the ImageSegmenter is created with the image + // running mode. + // + // The image can be of any size with format RGB or RGBA. + // TODO: Describes how the input image will be preprocessed + // after the yuv support is implemented. + // + // If the output_type is CATEGORY_MASK, the returned vector of images is + // per-category segmented image mask. + // If the output_type is CONFIDENCE_MASK, the returned vector of images + // contains only one confidence image mask. absl::StatusOr> Segment(mediapipe::Image image); + + // Performs image segmentation on the provided video frame. + // Only use this method when the ImageSegmenter is created with the video + // running mode. + // + // The image can be of any size with format RGB or RGBA. It's required to + // provide the video frame's timestamp (in milliseconds). The input timestamps + // must be monotonically increasing. + // + // If the output_type is CATEGORY_MASK, the returned vector of images is + // per-category segmented image mask. + // If the output_type is CONFIDENCE_MASK, the returned vector of images + // contains only one confidence image mask. + absl::StatusOr> SegmentForVideo( + mediapipe::Image image, int64 timestamp_ms); + + // Sends live image data to perform image segmentation, and the results will + // be available via the "result_callback" provided in the + // ImageSegmenterOptions. Only use this method when the ImageSegmenter is + // created with the live stream running mode. + // + // The image can be of any size with format RGB or RGBA. It's required to + // provide a timestamp (in milliseconds) to indicate when the input image is + // sent to the image segmenter. The input timestamps must be monotonically + // increasing. + // + // The "result_callback" prvoides + // - A vector of segmented image masks. + // If the output_type is CATEGORY_MASK, the returned vector of images is + // per-category segmented image mask. + // If the output_type is CONFIDENCE_MASK, the returned vector of images + // contains only one confidence image mask. + // - The const reference to the corresponding input image that the image + // segmentation runs on. Note that the const reference to the image will + // no longer be valid when the callback returns. To access the image data + // outside of the callback, callers need to make a copy of the image. + // - The input timestamp in milliseconds. + absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms); + + // Shuts down the ImageSegmenter when all works are done. + absl::Status Close() { return runner_->Close(); } }; } // namespace vision diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index d843689e2..1678dd083 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -28,9 +28,10 @@ limitations under the License. 
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/segmenter_options.pb.h" +#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" @@ -51,7 +52,7 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::MultiSource; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::SegmenterOptions; +using ::mediapipe::tasks::components::proto::SegmenterOptions; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; using ::mediapipe::tasks::vision::image_segmenter::proto::ImageSegmenterOptions; using ::tflite::Tensor; @@ -176,6 +177,11 @@ absl::StatusOr GetOutputTensor( // options { // [mediapipe.tasks.vision.image_segmenter.proto.ImageSegmenterOptions.ext] // { +// base_options { +// model_asset { +// file_name: "/path/to/model.tflite" +// } +// } // segmenter_options { // output_type: CONFIDENCE_MASK // activation: SOFTMAX @@ -228,19 +234,22 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { // Adds preprocessing calculators and connects them to the graph input image // stream. auto& preprocessing = - graph.AddNode("mediapipe.tasks.ImagePreprocessingSubgraph"); + graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( model_resources, - &preprocessing.GetOptions())); + &preprocessing + .GetOptions())); image_in >> preprocessing.In(kImageTag); // Adds inference subgraph and connects its input stream to the output // tensors produced by the ImageToTensorCalculator. - auto& inference = AddInference(model_resources, graph); + auto& inference = AddInference( + model_resources, task_options.base_options().acceleration(), graph); preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag); // Adds segmentation calculators for output streams. - auto& tensor_to_images = graph.AddNode("TensorsToSegmentationCalculator"); + auto& tensor_to_images = + graph.AddNode("mediapipe.tasks.TensorsToSegmentationCalculator"); RET_CHECK_OK(ConfigureTensorsToSegmentationCalculator( task_options, model_resources, &tensor_to_images diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index f43d28fca..2f1c26a79 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -29,7 +29,6 @@ limitations under the License. 
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" -#include "mediapipe/tasks/cc/components/segmenter_options.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h" @@ -164,7 +163,7 @@ class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver { TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) { auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); options->base_options.op_resolver = absl::make_unique(); MP_ASSERT_OK(ImageSegmenter::Create(std::move(options))); @@ -172,7 +171,7 @@ TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) { TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) { auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); options->base_options.op_resolver = absl::make_unique(); @@ -199,15 +198,15 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { MediaPipeTasksStatus::kRunnerInitializationError)))); } -class SegmentationTest : public tflite_shims::testing::Test {}; +class ImageModeTest : public tflite_shims::testing::Test {}; -TEST_F(SegmentationTest, SucceedsWithCategoryMask) { +TEST_F(ImageModeTest, SucceedsWithCategoryMask) { MP_ASSERT_OK_AND_ASSIGN( Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "segmentation_input_rotation0.jpg"))); auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; @@ -227,12 +226,12 @@ TEST_F(SegmentationTest, SucceedsWithCategoryMask) { kGoldenMaskMagnificationFactor)); } -TEST_F(SegmentationTest, SucceedsWithConfidenceMask) { +TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { MP_ASSERT_OK_AND_ASSIGN( Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "cat.jpg"))); auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; options->activation = ImageSegmenterOptions::Activation::SOFTMAX; @@ -255,11 +254,11 @@ TEST_F(SegmentationTest, SucceedsWithConfidenceMask) { SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } -TEST_F(SegmentationTest, SucceedsSelfie128x128Segmentation) { +TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) { Image image = GetSRGBImage(JoinPath("./", kTestDataDirectory, "mozart_square.jpg")); auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata); options->base_options.op_resolver = absl::make_unique(); @@ -285,11 +284,11 @@ TEST_F(SegmentationTest, SucceedsSelfie128x128Segmentation) { SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } -TEST_F(SegmentationTest, SucceedsSelfie144x256Segmentations) { 
+TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { Image image = GetSRGBImage(JoinPath("./", kTestDataDirectory, "mozart_square.jpg")); auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata); options->base_options.op_resolver = absl::make_unique(); @@ -313,6 +312,185 @@ TEST_F(SegmentationTest, SucceedsSelfie144x256Segmentations) { SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } +class VideoModeTest : public tflite_shims::testing::Test {}; + +TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "segmentation_input_rotation0.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); + options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; + options->running_mode = core::RunningMode::VIDEO; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + ImageSegmenter::Create(std::move(options))); + auto results = segmenter->Segment(image); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the image mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + + results = segmenter->SegmentAsync(image, 0); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the live stream mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + MP_ASSERT_OK(segmenter->Close()); +} + +TEST_F(VideoModeTest, Succeeds) { + constexpr int iterations = 100; + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "segmentation_input_rotation0.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); + options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; + options->running_mode = core::RunningMode::VIDEO; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + ImageSegmenter::Create(std::move(options))); + cv::Mat expected_mask = cv::imread( + JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"), + cv::IMREAD_GRAYSCALE); + for (int i = 0; i < iterations; ++i) { + MP_ASSERT_OK_AND_ASSIGN(auto category_masks, + segmenter->SegmentForVideo(image, i)); + EXPECT_EQ(category_masks.size(), 1); + cv::Mat actual_mask = mediapipe::formats::MatView( + category_masks[0].GetImageFrameSharedPtr().get()); + EXPECT_THAT(actual_mask, + SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, + kGoldenMaskMagnificationFactor)); + } + MP_ASSERT_OK(segmenter->Close()); +} + +class LiveStreamModeTest : public tflite_shims::testing::Test {}; + +TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { + MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath( + "./", kTestDataDirectory, + "cats_and_dogs_no_resizing.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); + 
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; + options->running_mode = core::RunningMode::LIVE_STREAM; + options->result_callback = + [](absl::StatusOr> segmented_masks, const Image& image, + int64 timestamp_ms) {}; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + ImageSegmenter::Create(std::move(options))); + + auto results = segmenter->Segment(image); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the image mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + + results = segmenter->SegmentForVideo(image, 0); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("not initialized with the video mode")); + EXPECT_THAT(results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); + MP_ASSERT_OK(segmenter->Close()); +} + +TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) { + MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath( + "./", kTestDataDirectory, + "cats_and_dogs_no_resizing.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); + options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; + options->running_mode = core::RunningMode::LIVE_STREAM; + options->result_callback = + [](absl::StatusOr> segmented_masks, const Image& image, + int64 timestamp_ms) {}; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + ImageSegmenter::Create(std::move(options))); + MP_ASSERT_OK(segmenter->SegmentAsync(image, 1)); + + auto status = segmenter->SegmentAsync(image, 0); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), + HasSubstr("timestamp must be monotonically increasing")); + EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerInvalidTimestampError)))); + MP_ASSERT_OK(segmenter->SegmentAsync(image, 2)); + MP_ASSERT_OK(segmenter->Close()); +} + +TEST_F(LiveStreamModeTest, Succeeds) { + constexpr int iterations = 100; + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "segmentation_input_rotation0.jpg"))); + std::vector> segmented_masks_results; + std::vector> image_sizes; + std::vector timestamps; + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); + options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; + options->running_mode = core::RunningMode::LIVE_STREAM; + options->result_callback = + [&segmented_masks_results, &image_sizes, ×tamps]( + absl::StatusOr> segmented_masks, + const Image& image, int64 timestamp_ms) { + MP_ASSERT_OK(segmented_masks.status()); + segmented_masks_results.push_back(std::move(segmented_masks).value()); + image_sizes.push_back({image.width(), image.height()}); + timestamps.push_back(timestamp_ms); + }; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + ImageSegmenter::Create(std::move(options))); + for (int i = 0; i < iterations; ++i) { + MP_ASSERT_OK(segmenter->SegmentAsync(image, i)); + } + MP_ASSERT_OK(segmenter->Close()); + // Due to the flow 
limiter, the total of outputs will be smaller than the + // number of iterations. + ASSERT_LE(segmented_masks_results.size(), iterations); + ASSERT_GT(segmented_masks_results.size(), 0); + cv::Mat expected_mask = cv::imread( + JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"), + cv::IMREAD_GRAYSCALE); + for (const auto& segmented_masks : segmented_masks_results) { + EXPECT_EQ(segmented_masks.size(), 1); + cv::Mat actual_mask = mediapipe::formats::MatView( + segmented_masks[0].GetImageFrameSharedPtr().get()); + EXPECT_THAT(actual_mask, + SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, + kGoldenMaskMagnificationFactor)); + } + for (const auto& image_size : image_sizes) { + EXPECT_EQ(image_size.first, image.width()); + EXPECT_EQ(image_size.second, image.height()); + } + int64 timestamp_ms = -1; + for (const auto& timestamp : timestamps) { + EXPECT_GT(timestamp, timestamp_ms); + timestamp_ms = timestamp; + } +} + // TODO: Add test for hair segmentation model. } // namespace diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD index b9b8ea436..d768c2bb1 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD @@ -24,7 +24,7 @@ mediapipe_proto_library( deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components:segmenter_options_proto", + "//mediapipe/tasks/cc/components/proto:segmenter_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.proto b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.proto index fcb2914cf..6e24a6665 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.proto @@ -18,7 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_segmenter.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/segmenter_options.proto"; +import "mediapipe/tasks/cc/components/proto/segmenter_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; message ImageSegmenterOptions { @@ -34,5 +34,5 @@ message ImageSegmenterOptions { optional string display_names_locale = 2 [default = "en"]; // Segmentation output options. 
- optional SegmenterOptions segmenter_options = 3; + optional components.proto.SegmenterOptions segmenter_options = 3; } diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index 515608418..6a9a25fc1 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -20,6 +20,7 @@ cc_library( name = "object_detector_graph", srcs = ["object_detector_graph.cc"], deps = [ + "//mediapipe/calculators/core:split_vector_calculator_cc_proto", "//mediapipe/calculators/tensor:inference_calculator", "//mediapipe/calculators/tensor:tensors_to_detections_calculator", "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto", @@ -32,11 +33,16 @@ cc_library( "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto", + "//mediapipe/tasks/cc/components/calculators:score_calibration_utils", + "//mediapipe/tasks/cc/components/utils:source_or_node_output", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/metadata:metadata_extractor", "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto", @@ -55,10 +61,13 @@ cc_library( hdrs = ["object_detector.h"], deps = [ ":object_detector_graph", + "//mediapipe/calculators/core:concatenate_vector_calculator", + "//mediapipe/calculators/core:split_vector_calculator", "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:image", "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc index d56c25066..8b7473d48 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc @@ -152,7 +152,7 @@ absl::StatusOr> ObjectDetector::Detect( return output_packets[kDetectionsOutStreamName].Get>(); } -absl::StatusOr> ObjectDetector::Detect( +absl::StatusOr> ObjectDetector::DetectForVideo( mediapipe::Image image, int64 timestamp_ms) { if (image.UsesGpu()) { return CreateStatusWithPayload( diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.h b/mediapipe/tasks/cc/vision/object_detector/object_detector.h index e98013223..0fa1b087b 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.h +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.h @@ -113,11 +113,17 @@ struct ObjectDetectorOptions { // (kTfLiteFloat32) // - scores tensor of size `[num_results]`, each value representing the score // of the detected object. +// - optional score calibration can be attached using ScoreCalibrationOptions +// and an AssociatedFile with type TENSOR_AXIS_SCORE_CALIBRATION. See +// metadata_schema.fbs [1] for more details. 
// (kTfLiteFloat32) // - integer num_results as a tensor of size `[1]` // // An example of such model can be found at: // https://tfhub.dev/google/lite-model/object_detection/mobile_object_localizer_v1/1/metadata/1 +// +// [1]: +// https://github.com/google/mediapipe/blob/6cdc6443b6a7ed662744e2a2ce2d58d9c83e6d6f/mediapipe/tasks/metadata/metadata_schema.fbs#L456 class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { public: using BaseVisionTaskApi::BaseVisionTaskApi; @@ -166,7 +172,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // unrotated input frame of reference coordinates system, i.e. in `[0, // image_width) x [0, image_height)`, which are the dimensions of the // underlying image data. - absl::StatusOr> Detect( + absl::StatusOr> DetectForVideo( mediapipe::Image image, int64 timestamp_ms); // Sends live image data to perform object detection, and the results will be diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc index 94f217378..b0533e469 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc @@ -14,11 +14,13 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "mediapipe/calculators/core/split_vector_calculator.pb.h" #include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h" #include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h" #include "mediapipe/framework/api2/builder.h" @@ -26,10 +28,15 @@ limitations under the License. 
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" +#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h" #include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" @@ -59,6 +66,8 @@ using ::tflite::TensorMetadata; using LabelItems = mediapipe::proto_ns::Map; using ObjectDetectorOptionsProto = object_detector::proto::ObjectDetectorOptions; +using TensorsSource = + mediapipe::tasks::SourceOrNodeOutput>; constexpr int kDefaultLocationsIndex = 0; constexpr int kDefaultCategoriesIndex = 1; @@ -72,12 +81,15 @@ constexpr char kCategoryTensorName[] = "category"; constexpr char kScoreTensorName[] = "score"; constexpr char kNumberOfDetectionsTensorName[] = "number of detections"; +constexpr char kCalibratedScoresTag[] = "CALIBRATED_SCORES"; constexpr char kDetectionsTag[] = "DETECTIONS"; -constexpr char kImageTag[] = "IMAGE"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +constexpr char kImageTag[] = "IMAGE"; +constexpr char kIndicesTag[] = "INDICES"; constexpr char kMatrixTag[] = "MATRIX"; constexpr char kPixelDetectionsTag[] = "PIXEL_DETECTIONS"; constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX"; +constexpr char kScoresTag[] = "SCORES"; constexpr char kTensorTag[] = "TENSORS"; // Struct holding the different output streams produced by the object detection @@ -111,7 +123,8 @@ struct PostProcessingSpecs { absl::flat_hash_set allow_or_deny_categories; // Indicates `allow_or_deny_categories` is an allowlist or a denylist. bool is_allowlist; - // TODO: Adds score calibration. + // Score calibration options, if any. + std::optional score_calibration_options; }; absl::Status SanityCheckOptions(const ObjectDetectorOptionsProto& options) { @@ -265,6 +278,43 @@ absl::StatusOr> GetAllowOrDenyCategoryIndicesIfAny( return category_indices; } +absl::StatusOr> +GetScoreCalibrationOptionsIfAny( + const ModelMetadataExtractor& metadata_extractor, + const TensorMetadata& tensor_metadata) { + // Get ScoreCalibrationOptions, if any. + ASSIGN_OR_RETURN( + const ProcessUnit* score_calibration_process_unit, + metadata_extractor.FindFirstProcessUnit( + tensor_metadata, tflite::ProcessUnitOptions_ScoreCalibrationOptions)); + if (score_calibration_process_unit == nullptr) { + return std::nullopt; + } + auto* score_calibration_options = + score_calibration_process_unit->options_as_ScoreCalibrationOptions(); + // Get corresponding AssociatedFile. 
+ auto score_calibration_filename = + metadata_extractor.FindFirstAssociatedFileName( + tensor_metadata, + tflite::AssociatedFileType_TENSOR_AXIS_SCORE_CALIBRATION); + if (score_calibration_filename.empty()) { + return CreateStatusWithPayload( + absl::StatusCode::kNotFound, + "Found ScoreCalibrationOptions but missing required associated " + "parameters file with type TENSOR_AXIS_SCORE_CALIBRATION.", + MediaPipeTasksStatus::kMetadataAssociatedFileNotFoundError); + } + ASSIGN_OR_RETURN( + absl::string_view score_calibration_file, + metadata_extractor.GetAssociatedFile(score_calibration_filename)); + ScoreCalibrationCalculatorOptions score_calibration_calculator_options; + MP_RETURN_IF_ERROR(ConfigureScoreCalibration( + score_calibration_options->score_transformation(), + score_calibration_options->default_score(), score_calibration_file, + &score_calibration_calculator_options)); + return score_calibration_calculator_options; +} + std::vector GetOutputTensorIndices( const flatbuffers::Vector>* tensor_metadatas) { @@ -353,6 +403,12 @@ absl::StatusOr BuildPostProcessingSpecs( *output_tensors_metadata->Get( specs.output_tensor_indices[2]))); } + // Builds score calibration options (if available) from metadata. + ASSIGN_OR_RETURN( + specs.score_calibration_options, + GetScoreCalibrationOptionsIfAny( + *metadata_extractor, + *output_tensors_metadata->Get(specs.output_tensor_indices[2]))); return specs; } @@ -417,6 +473,11 @@ void ConfigureTensorsToDetectionsCalculator( // options { // [mediapipe.tasks.vision.object_detector.proto.ObjectDetectorOptions.ext] // { +// base_options { +// model_asset { +// file_name: "/path/to/model.tflite" +// } +// } // max_results: 4 // score_threshold: 0.5 // category_allowlist: "foo" @@ -460,8 +521,25 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { const core::ModelResources& model_resources, Source image_in, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); - auto metadata_extractor = model_resources.GetMetadataExtractor(); + // Checks that the model has 4 outputs. + auto& model = *model_resources.GetTfLiteModel(); + if (model.subgraphs()->size() != 1 || + (*model.subgraphs())[0]->outputs()->size() != 4) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Expected a model with a single subgraph, found %d.", + model.subgraphs()->size()), + MediaPipeTasksStatus::kInvalidArgumentError); + } + if (model.subgraphs()->Get(0)->outputs()->size() != 4) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Expected a model with 4 output tensors, found %d.", + model.subgraphs()->Get(0)->outputs()->size()), + MediaPipeTasksStatus::kInvalidArgumentError); + } // Checks that metadata is available. + auto* metadata_extractor = model_resources.GetMetadataExtractor(); if (metadata_extractor->GetModelMetadata() == nullptr || metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) { @@ -475,21 +553,65 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { // Adds preprocessing calculators and connects them to the graph input image // stream. 
auto& preprocessing = - graph.AddNode("mediapipe.tasks.ImagePreprocessingSubgraph"); + graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( model_resources, - &preprocessing.GetOptions())); + &preprocessing + .GetOptions())); image_in >> preprocessing.In(kImageTag); // Adds inference subgraph and connects its input stream to the output // tensors produced by the ImageToTensorCalculator. - auto& inference = AddInference(model_resources, graph); + auto& inference = AddInference( + model_resources, task_options.base_options().acceleration(), graph); preprocessing.Out(kTensorTag) >> inference.In(kTensorTag); // Adds post processing calculators. ASSIGN_OR_RETURN( auto post_processing_specs, BuildPostProcessingSpecs(task_options, metadata_extractor)); + // Calculators to perform score calibration, if specified in the metadata. + TensorsSource calibrated_tensors = {&inference, kTensorTag}; + if (post_processing_specs.score_calibration_options.has_value()) { + // Split tensors. + auto* split_tensor_vector_node = + &graph.AddNode("SplitTensorVectorCalculator"); + auto& split_tensor_vector_options = + split_tensor_vector_node + ->GetOptions(); + for (int i = 0; i < 4; ++i) { + auto* range = split_tensor_vector_options.add_ranges(); + range->set_begin(i); + range->set_end(i + 1); + } + calibrated_tensors >> split_tensor_vector_node->In(0); + + // Add score calibration calculator. + auto* score_calibration_node = + &graph.AddNode("ScoreCalibrationCalculator"); + score_calibration_node->GetOptions() + .CopyFrom(*post_processing_specs.score_calibration_options); + split_tensor_vector_node->Out( + post_processing_specs.output_tensor_indices[1]) >> + score_calibration_node->In(kIndicesTag); + split_tensor_vector_node->Out( + post_processing_specs.output_tensor_indices[2]) >> + score_calibration_node->In(kScoresTag); + + // Re-concatenate tensors. + auto* concatenate_tensor_vector_node = + &graph.AddNode("ConcatenateTensorVectorCalculator"); + for (int i = 0; i < 4; ++i) { + if (i == post_processing_specs.output_tensor_indices[2]) { + score_calibration_node->Out(kCalibratedScoresTag) >> + concatenate_tensor_vector_node->In(i); + } else { + split_tensor_vector_node->Out(i) >> + concatenate_tensor_vector_node->In(i); + } + } + calibrated_tensors = {concatenate_tensor_vector_node, 0}; + } // Calculator to convert output tensors to a detection proto vector. // Connects TensorsToDetectionsCalculator's input stream to the output // tensors produced by the inference subgraph. @@ -499,7 +621,7 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { post_processing_specs, &tensors_to_detections .GetOptions()); - inference.Out(kTensorTag) >> tensors_to_detections.In(kTensorTag); + calibrated_tensors >> tensors_to_detections.In(kTensorTag); // Calculator to projects detections back to the original coordinate system. 
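For readers less familiar with score calibration: the per-class parameters loaded from the TENSOR_AXIS_SCORE_CALIBRATION file drive a sigmoid remapping of the raw detection scores, and the score_transformation field chooses how a raw score is mapped before the sigmoid. The snippet below is a minimal Python sketch under those assumptions; the authoritative behavior lives in ScoreCalibrationCalculator and the metadata schema, not in this sketch.

import math

def calibrate_score(raw_score, slope, offset, scale, transformation="LOG"):
    # Illustrative only: parameter names and the sigmoid form are assumptions
    # based on the ScoreCalibrationOptions described in metadata_schema.fbs.
    if transformation == "IDENTITY":
        x = raw_score
    elif transformation == "LOG":
        x = math.log(raw_score)
    else:  # "INVERSE_LOGISTIC"
        x = math.log(raw_score / (1.0 - raw_score))
    return scale / (1.0 + math.exp(-(slope * x + offset)))

# Mirrors the wiring above: only the scores tensor (output_tensor_indices[2])
# is rewritten; locations, categories and num_detections pass through.
raw_scores = [0.3, 0.8]
calibrated_scores = [calibrate_score(s, slope=1.0, offset=0.0, scale=1.0)
                     for s in raw_scores]

Classes that have no parameters in the associated file presumably fall back to the default_score forwarded to ConfigureScoreCalibration above.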
auto& detection_projection = graph.AddNode("DetectionProjectionCalculator"); diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc index faca6ef24..463c92566 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc @@ -68,8 +68,9 @@ using ::testing::Optional; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; constexpr char kMobileSsdWithMetadata[] = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite"; -constexpr char kMobileSsdWithMetadataDummyScoreCalibration[] = - "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_score_calibration.tflite"; +constexpr char kMobileSsdWithDummyScoreCalibration[] = + "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration." + "tflite"; // The model has different output tensor order. constexpr char kEfficientDetWithMetadata[] = "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite"; @@ -153,7 +154,7 @@ class CreateFromOptionsTest : public tflite_shims::testing::Test {}; TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) { auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->base_options.op_resolver = absl::make_unique(); @@ -185,7 +186,7 @@ class MobileSsdQuantizedOpResolverMissingOps TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) { auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->base_options.op_resolver = absl::make_unique(); @@ -194,7 +195,7 @@ TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) { // interpreter errors (e.g., "Encountered unresolved custom op"). 
EXPECT_EQ(object_detector.status().code(), absl::StatusCode::kInternal); EXPECT_THAT(object_detector.status().message(), - HasSubstr("interpreter->AllocateTensors() == kTfLiteOk")); + HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk")); } TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { @@ -215,7 +216,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) { auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->max_results = 0; @@ -233,7 +234,7 @@ TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) { TEST_F(CreateFromOptionsTest, FailsWithCombinedAllowlistAndDenylist) { auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->category_allowlist.push_back("foo"); options->category_denylist.push_back("bar"); @@ -253,7 +254,7 @@ TEST_F(CreateFromOptionsTest, FailsWithIllegalCallbackInImageOrVideoMode) { for (auto running_mode : {core::RunningMode::IMAGE, core::RunningMode::VIDEO}) { auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->running_mode = running_mode; options->result_callback = @@ -274,7 +275,7 @@ TEST_F(CreateFromOptionsTest, FailsWithIllegalCallbackInImageOrVideoMode) { TEST_F(CreateFromOptionsTest, FailsWithMissingCallbackInLiveStreamMode) { auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->running_mode = core::RunningMode::LIVE_STREAM; absl::StatusOr> object_detector = @@ -299,11 +300,11 @@ TEST_F(ImageModeTest, FailsWithCallingWrongMethod) { "./", kTestDataDirectory, "cats_and_dogs_no_resizing.jpg"))); auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); - auto results = object_detector->Detect(image, 0); + auto results = object_detector->DetectForVideo(image, 0); EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(results.status().message(), HasSubstr("not initialized with the video mode")); @@ -327,7 +328,7 @@ TEST_F(ImageModeTest, Succeeds) { "cats_and_dogs.jpg"))); auto options = std::make_unique(); options->max_results = 4; - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); @@ -370,7 +371,7 @@ TEST_F(ImageModeTest, SucceedsEfficientDetModel) { "cats_and_dogs.jpg"))); auto options = std::make_unique(); options->max_results = 4; - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kEfficientDetWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); @@ -413,7 +414,7 @@ TEST_F(ImageModeTest, SucceedsWithoutImageResizing) { "cats_and_dogs_no_resizing.jpg"))); auto options = 
std::make_unique(); options->max_results = 4; - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); @@ -423,8 +424,27 @@ TEST_F(ImageModeTest, SucceedsWithoutImageResizing) { results, GenerateMobileSsdNoImageResizingFullExpectedResults()); } -// TODO: Add SucceedswithScoreCalibrations after score calibration -// is implemented. +TEST_F(ImageModeTest, SucceedsWithScoreCalibration) { + MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath( + "./", kTestDataDirectory, + "cats_and_dogs_no_resizing.jpg"))); + auto options = std::make_unique(); + options->max_results = 1; + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileSsdWithDummyScoreCalibration); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, + ObjectDetector::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); + MP_ASSERT_OK(object_detector->Close()); + ExpectApproximatelyEqual( + results, {ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.6531269142 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 } + })pb")}); +} TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) { MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath( @@ -432,7 +452,7 @@ TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) { "cats_and_dogs_no_resizing.jpg"))); auto options = std::make_unique(); options->score_threshold = 0.5; - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); @@ -451,7 +471,7 @@ TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) { "cats_and_dogs_no_resizing.jpg"))); auto options = std::make_unique(); options->max_results = 2; - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); @@ -470,7 +490,7 @@ TEST_F(ImageModeTest, SucceedsWithAllowlistOption) { auto options = std::make_unique(); options->max_results = 1; options->category_allowlist.push_back("dog"); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); @@ -488,7 +508,7 @@ TEST_F(ImageModeTest, SucceedsWithDenylistOption) { auto options = std::make_unique(); options->max_results = 1; options->category_denylist.push_back("cat"); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); @@ -506,7 +526,7 @@ TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { "./", kTestDataDirectory, "cats_and_dogs_no_resizing.jpg"))); auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->running_mode = core::RunningMode::VIDEO; @@ -538,12 +558,13 @@ 
TEST_F(VideoModeTest, Succeeds) { auto options = std::make_unique(); options->max_results = 2; options->running_mode = core::RunningMode::VIDEO; - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); for (int i = 0; i < iterations; ++i) { - MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image, i)); + MP_ASSERT_OK_AND_ASSIGN(auto results, + object_detector->DetectForVideo(image, i)); std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); ExpectApproximatelyEqual( @@ -559,7 +580,7 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { "./", kTestDataDirectory, "cats_and_dogs_no_resizing.jpg"))); auto options = std::make_unique(); - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->running_mode = core::RunningMode::LIVE_STREAM; options->result_callback = @@ -576,7 +597,7 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { Optional(absl::Cord(absl::StrCat( MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError)))); - results = object_detector->Detect(image, 0); + results = object_detector->DetectForVideo(image, 0); EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(results.status().message(), HasSubstr("not initialized with the video mode")); @@ -592,7 +613,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) { "cats_and_dogs_no_resizing.jpg"))); auto options = std::make_unique(); options->running_mode = core::RunningMode::LIVE_STREAM; - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->result_callback = [](absl::StatusOr> detections, const Image& image, @@ -623,7 +644,7 @@ TEST_F(LiveStreamModeTest, Succeeds) { std::vector> detection_results; std::vector> image_sizes; std::vector timestamps; - options->base_options.model_file_name = + options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->result_callback = [&detection_results, &image_sizes, ×tamps]( diff --git a/mediapipe/tasks/java/BUILD b/mediapipe/tasks/java/BUILD new file mode 100644 index 000000000..024510737 --- /dev/null +++ b/mediapipe/tasks/java/BUILD @@ -0,0 +1 @@ +# dummy file for tap test to find the pattern diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD new file mode 100644 index 000000000..65c1214af --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
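The VideoModeTest and LiveStreamModeTest cases earlier in this hunk show the intended calling patterns: DetectForVideo returns results synchronously for monotonically increasing timestamps, while live-stream mode delivers results through the registered callback. Below is a rough sketch of the equivalent usage with the Python task API added later in this patch; the import paths and the VisionTaskRunningMode name are inferred from the py_test BUILD deps, so treat them as assumptions rather than the definitive API.

# Module paths below are assumptions based on the py_test deps in this patch.
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import object_detector
from mediapipe.tasks.python.vision.core import vision_task_running_mode

def detect_on_video(model_path, frames, frame_interval_ms=30):
    # Video mode: synchronous calls, timestamps must increase monotonically.
    options = object_detector.ObjectDetectorOptions(
        base_options=base_options_module.BaseOptions(
            model_asset_path=model_path),
        running_mode=vision_task_running_mode.VisionTaskRunningMode.VIDEO,
        max_results=4)
    with object_detector.ObjectDetector.create_from_options(options) as detector:
        return [detector.detect_for_video(frame, i * frame_interval_ms)
                for i, frame in enumerate(frames)]

def detect_on_live_stream(model_path, frames_with_timestamps):
    # Live-stream mode: detect_async returns immediately and results arrive via
    # result_callback(result, output_image, timestamp_ms), as in the tests
    # above; some inputs may be dropped to keep latency low.
    collected = []

    def on_result(result, output_image, timestamp_ms):
        collected.append((timestamp_ms, result))

    options = object_detector.ObjectDetectorOptions(
        base_options=base_options_module.BaseOptions(
            model_asset_path=model_path),
        running_mode=vision_task_running_mode.VisionTaskRunningMode.LIVE_STREAM,
        max_results=4,
        result_callback=on_result)
    with object_detector.ObjectDetector.create_from_options(options) as detector:
        for image, timestamp_ms in frames_with_timestamps:
            detector.detect_async(image, timestamp_ms)
    return collected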
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD new file mode 100644 index 000000000..65c1214af --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD new file mode 100644 index 000000000..65c1214af --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/AndroidManifest.xml new file mode 100644 index 000000000..3b52683e8 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/BUILD new file mode 100644 index 000000000..65c1214af --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/BUILD @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/mediapipe/tasks/javatests/BUILD b/mediapipe/tasks/javatests/BUILD new file mode 100644 index 000000000..024510737 --- /dev/null +++ b/mediapipe/tasks/javatests/BUILD @@ -0,0 +1 @@ +# dummy file for tap test to find the pattern diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/core/BUILD new file mode 100644 index 000000000..65c1214af --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/core/BUILD @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/AndroidManifest.xml new file mode 100644 index 000000000..3e5e81920 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/BUILD new file mode 100644 index 000000000..1bec2be3e --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/BUILD @@ -0,0 +1,15 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/metadata/metadata_schema.fbs b/mediapipe/tasks/metadata/metadata_schema.fbs index 776b960d5..933fdfb2a 100644 --- a/mediapipe/tasks/metadata/metadata_schema.fbs +++ b/mediapipe/tasks/metadata/metadata_schema.fbs @@ -1,4 +1,4 @@ -// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/python/components/containers/bounding_box.py b/mediapipe/tasks/python/components/containers/bounding_box.py index f41fdb386..7cfbdf794 100644 --- a/mediapipe/tasks/python/components/containers/bounding_box.py +++ b/mediapipe/tasks/python/components/containers/bounding_box.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/python/components/containers/category.py b/mediapipe/tasks/python/components/containers/category.py index ac94491dc..00f68e532 100644 --- a/mediapipe/tasks/python/components/containers/category.py +++ b/mediapipe/tasks/python/components/containers/category.py @@ -1,4 +1,4 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/python/components/containers/detections.py b/mediapipe/tasks/python/components/containers/detections.py index 39a0fe81a..b4d550633 100644 --- a/mediapipe/tasks/python/components/containers/detections.py +++ b/mediapipe/tasks/python/components/containers/detections.py @@ -1,4 +1,4 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/python/core/base_options.py b/mediapipe/tasks/python/core/base_options.py index 75b42ab3c..122dc620f 100644 --- a/mediapipe/tasks/python/core/base_options.py +++ b/mediapipe/tasks/python/core/base_options.py @@ -28,39 +28,39 @@ _ExternalFileProto = external_file_pb2.ExternalFile class BaseOptions: """Base options for MediaPipe Tasks' Python APIs. - Represents external files used by the Task APIs (e.g. TF Lite FlatBuffer or - plain-text labels file). The files can be specified by one of the following - two ways: + Represents external model asset used by the Task APIs. The files can be + specified by one of the following two ways: - (1) file contents loaded in `file_content`. - (2) file path in `file_name`. + (1) model asset file path in `model_asset_path`. + (2) model asset contents loaded in `model_asset_buffer`. If more than one field of these fields is provided, they are used in this precedence order. Attributes: - file_name: Path to the index. - file_content: The index file contents as bytes. + model_asset_path: Path to the model asset file. + model_asset_buffer: The model asset file contents as bytes. """ - file_name: Optional[str] = None - file_content: Optional[bytes] = None + model_asset_path: Optional[str] = None + model_asset_buffer: Optional[bytes] = None # TODO: Allow Python API to specify acceleration settings. @doc_controls.do_not_generate_docs def to_pb2(self) -> _BaseOptionsProto: """Generates a BaseOptions protobuf object.""" return _BaseOptionsProto( - model_file=_ExternalFileProto( - file_name=self.file_name, file_content=self.file_content)) + model_asset=_ExternalFileProto( + file_name=self.model_asset_path, + file_content=self.model_asset_buffer)) @classmethod @doc_controls.do_not_generate_docs def create_from_pb2(cls, pb2_obj: _BaseOptionsProto) -> 'BaseOptions': """Creates a `BaseOptions` object from the given protobuf object.""" return BaseOptions( - file_name=pb2_obj.model_file.file_name, - file_content=pb2_obj.model_file.file_content) + model_asset_path=pb2_obj.model_asset.file_name, + model_asset_buffer=pb2_obj.model_asset.file_content) def __eq__(self, other: Any) -> bool: """Checks if this object is equal to the given object. 
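In practice the two renamed fields map onto the two ways of supplying a model. A short illustrative sketch (the model path is a placeholder):

from mediapipe.tasks.python.core import base_options as base_options_module

# (1) Reference the model by path; the runtime loads it from disk.
options_from_path = base_options_module.BaseOptions(
    model_asset_path="/path/to/model.tflite")

# (2) Or pass the raw model bytes that are already in memory.
with open("/path/to/model.tflite", "rb") as f:
    options_from_buffer = base_options_module.BaseOptions(
        model_asset_buffer=f.read())

# Both forms serialize to the same ExternalFile-backed proto.
proto = options_from_path.to_pb2()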
diff --git a/mediapipe/tasks/python/test/test_util.py b/mediapipe/tasks/python/test/test_util.py index 0e2063a8c..cf1dfec2e 100644 --- a/mediapipe/tasks/python/test/test_util.py +++ b/mediapipe/tasks/python/test/test_util.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index bb495338d..6980a12a0 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -18,4 +18,21 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -# TODO: This test fails in OSS +py_test( + name = "object_detector_test", + srcs = ["object_detector_test.py"], + data = [ + "//mediapipe/tasks/testdata/vision:test_images", + "//mediapipe/tasks/testdata/vision:test_models", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/components/containers:bounding_box", + "//mediapipe/tasks/python/components/containers:category", + "//mediapipe/tasks/python/components/containers:detections", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_util", + "//mediapipe/tasks/python/vision:object_detector", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + ], +) diff --git a/mediapipe/tasks/python/test/vision/object_detector_test.py b/mediapipe/tasks/python/test/vision/object_detector_test.py index daab7a183..a83031342 100644 --- a/mediapipe/tasks/python/test/vision/object_detector_test.py +++ b/mediapipe/tasks/python/test/vision/object_detector_test.py @@ -14,6 +14,7 @@ """Tests for object detector.""" import enum +from unittest import mock from absl.testing import absltest from absl.testing import parameterized @@ -108,7 +109,7 @@ class ObjectDetectorTest(parameterized.TestCase): def test_create_from_options_succeeds_with_valid_model_path(self): # Creates with options containing model file successfully. - base_options = _BaseOptions(file_name=self.model_path) + base_options = _BaseOptions(model_asset_path=self.model_path) options = _ObjectDetectorOptions(base_options=base_options) with _ObjectDetector.create_from_options(options) as detector: self.assertIsInstance(detector, _ObjectDetector) @@ -119,14 +120,14 @@ class ObjectDetectorTest(parameterized.TestCase): ValueError, r"ExternalFile must specify at least one of 'file_content', " r"'file_name' or 'file_descriptor_meta'."): - base_options = _BaseOptions(file_name='') + base_options = _BaseOptions(model_asset_path='') options = _ObjectDetectorOptions(base_options=base_options) _ObjectDetector.create_from_options(options) def test_create_from_options_succeeds_with_valid_model_content(self): # Creates with options containing model content successfully. with open(self.model_path, 'rb') as f: - base_options = _BaseOptions(file_content=f.read()) + base_options = _BaseOptions(model_asset_buffer=f.read()) options = _ObjectDetectorOptions(base_options=base_options) detector = _ObjectDetector.create_from_options(options) self.assertIsInstance(detector, _ObjectDetector) @@ -138,11 +139,11 @@ class ObjectDetectorTest(parameterized.TestCase): expected_detection_result): # Creates detector. 
if model_file_type is ModelFileType.FILE_NAME: - base_options = _BaseOptions(file_name=self.model_path) + base_options = _BaseOptions(model_asset_path=self.model_path) elif model_file_type is ModelFileType.FILE_CONTENT: with open(self.model_path, 'rb') as f: model_content = f.read() - base_options = _BaseOptions(file_content=model_content) + base_options = _BaseOptions(model_asset_buffer=model_content) else: # Should never happen raise ValueError('model_file_type is invalid.') @@ -165,11 +166,11 @@ class ObjectDetectorTest(parameterized.TestCase): def test_detect_in_context(self, model_file_type, max_results, expected_detection_result): if model_file_type is ModelFileType.FILE_NAME: - base_options = _BaseOptions(file_name=self.model_path) + base_options = _BaseOptions(model_asset_path=self.model_path) elif model_file_type is ModelFileType.FILE_CONTENT: with open(self.model_path, 'rb') as f: - model_content = f.read() - base_options = _BaseOptions(file_content=model_content) + model_contents = f.read() + base_options = _BaseOptions(model_asset_buffer=model_contents) else: # Should never happen raise ValueError('model_file_type is invalid.') @@ -184,7 +185,7 @@ class ObjectDetectorTest(parameterized.TestCase): def test_score_threshold_option(self): options = _ObjectDetectorOptions( - base_options=_BaseOptions(file_name=self.model_path), + base_options=_BaseOptions(model_asset_path=self.model_path), score_threshold=_SCORE_THRESHOLD) with _ObjectDetector.create_from_options(options) as detector: # Performs object detection on the input. @@ -199,7 +200,7 @@ class ObjectDetectorTest(parameterized.TestCase): def test_max_results_option(self): options = _ObjectDetectorOptions( - base_options=_BaseOptions(file_name=self.model_path), + base_options=_BaseOptions(model_asset_path=self.model_path), max_results=_MAX_RESULTS) with _ObjectDetector.create_from_options(options) as detector: # Performs object detection on the input. @@ -211,7 +212,7 @@ class ObjectDetectorTest(parameterized.TestCase): def test_allow_list_option(self): options = _ObjectDetectorOptions( - base_options=_BaseOptions(file_name=self.model_path), + base_options=_BaseOptions(model_asset_path=self.model_path), category_allowlist=_ALLOW_LIST) with _ObjectDetector.create_from_options(options) as detector: # Performs object detection on the input. @@ -225,7 +226,7 @@ class ObjectDetectorTest(parameterized.TestCase): def test_deny_list_option(self): options = _ObjectDetectorOptions( - base_options=_BaseOptions(file_name=self.model_path), + base_options=_BaseOptions(model_asset_path=self.model_path), category_denylist=_DENY_LIST) with _ObjectDetector.create_from_options(options) as detector: # Performs object detection on the input. 
@@ -244,7 +245,7 @@ class ObjectDetectorTest(parameterized.TestCase): r'`category_allowlist` and `category_denylist` are mutually ' r'exclusive options.'): options = _ObjectDetectorOptions( - base_options=_BaseOptions(file_name=self.model_path), + base_options=_BaseOptions(model_asset_path=self.model_path), category_allowlist=['foo'], category_denylist=['bar']) with _ObjectDetector.create_from_options(options) as unused_detector: @@ -252,7 +253,8 @@ class ObjectDetectorTest(parameterized.TestCase): def test_empty_detection_outputs(self): options = _ObjectDetectorOptions( - base_options=_BaseOptions(file_name=self.model_path), score_threshold=1) + base_options=_BaseOptions(model_asset_path=self.model_path), + score_threshold=1) with _ObjectDetector.create_from_options(options) as detector: # Performs object detection on the input. image_result = detector.detect(self.test_image) @@ -260,7 +262,7 @@ class ObjectDetectorTest(parameterized.TestCase): def test_missing_result_callback(self): options = _ObjectDetectorOptions( - base_options=_BaseOptions(file_name=self.model_path), + base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM) with self.assertRaisesRegex(ValueError, r'result callback must be provided'): @@ -269,31 +271,99 @@ class ObjectDetectorTest(parameterized.TestCase): @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO)) def test_illegal_result_callback(self, running_mode): - - def pass_through(unused_result: _DetectionResult, - unused_output_image: _Image, unused_timestamp_ms: int): - pass - options = _ObjectDetectorOptions( - base_options=_BaseOptions(file_name=self.model_path), + base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=running_mode, - result_callback=pass_through) + result_callback=mock.MagicMock()) with self.assertRaisesRegex(ValueError, r'result callback should not be provided'): with _ObjectDetector.create_from_options(options) as unused_detector: pass - def test_detect_async_calls_with_illegal_timestamp(self): - - def pass_through(unused_result: _DetectionResult, - unused_output_image: _Image, unused_timestamp_ms: int): - pass - + def test_calling_detect_for_video_in_image_mode(self): options = _ObjectDetectorOptions( - base_options=_BaseOptions(file_name=self.model_path), + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.IMAGE) + with _ObjectDetector.create_from_options(options) as detector: + with self.assertRaisesRegex(ValueError, + r'not initialized with the video mode'): + detector.detect_for_video(self.test_image, 0) + + def test_calling_detect_async_in_image_mode(self): + options = _ObjectDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.IMAGE) + with _ObjectDetector.create_from_options(options) as detector: + with self.assertRaisesRegex(ValueError, + r'not initialized with the live stream mode'): + detector.detect_async(self.test_image, 0) + + def test_calling_detect_in_video_mode(self): + options = _ObjectDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO) + with _ObjectDetector.create_from_options(options) as detector: + with self.assertRaisesRegex(ValueError, + r'not initialized with the image mode'): + detector.detect(self.test_image) + + def test_calling_detect_async_in_video_mode(self): + options = _ObjectDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + 
running_mode=_RUNNING_MODE.VIDEO) + with _ObjectDetector.create_from_options(options) as detector: + with self.assertRaisesRegex(ValueError, + r'not initialized with the live stream mode'): + detector.detect_async(self.test_image, 0) + + def test_detect_for_video_with_out_of_order_timestamp(self): + options = _ObjectDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO) + with _ObjectDetector.create_from_options(options) as detector: + unused_result = detector.detect_for_video(self.test_image, 1) + with self.assertRaisesRegex( + ValueError, r'Input timestamp must be monotonically increasing'): + detector.detect_for_video(self.test_image, 0) + + # TODO: Tests how `detect_for_video` handles the temporal data + # with a real video. + def test_detect_for_video(self): + options = _ObjectDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO, + max_results=4) + with _ObjectDetector.create_from_options(options) as detector: + for timestamp in range(0, 300, 30): + detection_result = detector.detect_for_video(self.test_image, timestamp) + self.assertEqual(detection_result, _EXPECTED_DETECTION_RESULT) + + def test_calling_detect_in_live_stream_mode(self): + options = _ObjectDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock()) + with _ObjectDetector.create_from_options(options) as detector: + with self.assertRaisesRegex(ValueError, + r'not initialized with the image mode'): + detector.detect(self.test_image) + + def test_calling_detect_for_video_in_live_stream_mode(self): + options = _ObjectDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock()) + with _ObjectDetector.create_from_options(options) as detector: + with self.assertRaisesRegex(ValueError, + r'not initialized with the video mode'): + detector.detect_for_video(self.test_image, 0) + + def test_detect_async_calls_with_illegal_timestamp(self): + options = _ObjectDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, max_results=4, - result_callback=pass_through) + result_callback=mock.MagicMock()) with _ObjectDetector.create_from_options(options) as detector: detector.detect_async(self.test_image, 100) with self.assertRaisesRegex( @@ -315,7 +385,7 @@ class ObjectDetectorTest(parameterized.TestCase): self.observed_timestamp_ms = timestamp_ms options = _ObjectDetectorOptions( - base_options=_BaseOptions(file_name=self.model_path), + base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, max_results=4, score_threshold=threshold, diff --git a/mediapipe/tasks/python/vision/object_detector.py b/mediapipe/tasks/python/vision/object_detector.py index cdf36f386..a50e55861 100644 --- a/mediapipe/tasks/python/vision/object_detector.py +++ b/mediapipe/tasks/python/vision/object_detector.py @@ -121,7 +121,7 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi): file such as invalid file path. RuntimeError: If other types of error occurred. 
""" - base_options = _BaseOptions(file_name=model_path) + base_options = _BaseOptions(model_asset_path=model_path) options = ObjectDetectorOptions( base_options=base_options, running_mode=_RunningMode.IMAGE) return cls.create_from_options(options) @@ -175,6 +175,9 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi): image: image_module.Image) -> detections_module.DetectionResult: """Performs object detection on the provided MediaPipe Image. + Only use this method when the ObjectDetector is created with the image + running mode. + Args: image: MediaPipe Image. @@ -197,15 +200,54 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi): for result in detection_proto_list ]) + def detect_for_video(self, image: image_module.Image, + timestamp_ms: int) -> detections_module.DetectionResult: + """Performs object detection on the provided video frames. + + Only use this method when the ObjectDetector is created with the video + running mode. It's required to provide the video frame's timestamp (in + milliseconds) along with the video frame. The input timestamps should be + monotonically increasing for adjacent calls of this method. + + Args: + image: MediaPipe Image. + timestamp_ms: The timestamp of the input video frame in milliseconds. + + Returns: + A detection result object that contains a list of detections, each + detection has a bounding box that is expressed in the unrotated input + frame of reference coordinates system, i.e. in `[0,image_width) x [0, + image_height)`, which are the dimensions of the underlying image data. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If object detection failed to run. + """ + output_packets = self._process_video_data({ + _IMAGE_IN_STREAM_NAME: + packet_creator.create_image(image).at(timestamp_ms) + }) + detection_proto_list = packet_getter.get_proto_list( + output_packets[_DETECTIONS_OUT_STREAM_NAME]) + return detections_module.DetectionResult([ + detections_module.Detection.create_from_pb2(result) + for result in detection_proto_list + ]) + def detect_async(self, image: image_module.Image, timestamp_ms: int) -> None: """Sends live image data (an Image with a unique timestamp) to perform object detection. - This method will return immediately after the input image is accepted. The - results will be available via the `result_callback` provided in the - `ObjectDetectorOptions`. The `detect_async` method is designed to process - live stream data such as camera input. To lower the overall latency, object - detector may drop the input images if needed. In other words, it's not - guaranteed to have output per input image. The `result_callback` prvoides: + Only use this method when the ObjectDetector is created with the live stream + running mode. The input timestamps should be monotonically increasing for + adjacent calls of this method. This method will return immediately after the + input image is accepted. The results will be available via the + `result_callback` provided in the `ObjectDetectorOptions`. The + `detect_async` method is designed to process live stream data such as camera + input. To lower the overall latency, object detector may drop the input + images if needed. In other words, it's not guaranteed to have output per + input image. + + The `result_callback` prvoides: - A detection result object that contains a list of detections, each detection has a bounding box that is expressed in the unrotated input frame of reference coordinates system, i.e. 
in `[0,image_width) x [0, diff --git a/mediapipe/tasks/testdata/text/BUILD b/mediapipe/tasks/testdata/text/BUILD index 5dbb27327..80f1163db 100644 --- a/mediapipe/tasks/testdata/text/BUILD +++ b/mediapipe/tasks/testdata/text/BUILD @@ -25,9 +25,9 @@ package( mediapipe_files(srcs = [ "30k-clean.model", "albert_with_metadata.tflite", - "bert_nl_classifier.tflite", + "bert_text_classifier.tflite", "mobilebert_with_metadata.tflite", - "test_model_nl_classifier_with_regex_tokenizer.tflite", + "test_model_text_classifier_with_regex_tokenizer.tflite", ]) exports_files(srcs = [ @@ -72,13 +72,13 @@ filegroup( ) filegroup( - name = "nl_classifier_models", + name = "text_classifier_models", srcs = glob([ - "test_model_nl_classifier*.tflite", + "test_model_text_classifier*.tflite", ]), ) filegroup( - name = "bert_nl_classifier_models", - srcs = ["bert_nl_classifier.tflite"], + name = "bert_text_classifier_models", + srcs = ["bert_text_classifier.tflite"], ) diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index b52604c2b..41eb44c21 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -24,13 +24,14 @@ package( mediapipe_files(srcs = [ "burger.jpg", + "burger_crop.jpg", "cat.jpg", "cat_mask.jpg", "cats_and_dogs.jpg", "cats_and_dogs_no_resizing.jpg", "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite", - "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_score_calibration.tflite", + "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite", "deeplabv3.tflite", "hand_landmark_full.tflite", "hand_landmark_lite.tflite", @@ -39,9 +40,13 @@ mediapipe_files(srcs = [ "mobilenet_v1_0.25_224_1_default_1.tflite", "mobilenet_v1_0.25_224_1_metadata_1.tflite", "mobilenet_v1_0.25_224_quant.tflite", + "mobilenet_v1_0.25_224_quant_with_dummy_score_calibration.tflite", "mobilenet_v1_0.25_224_quant_without_subgraph_metadata.tflite", "mobilenet_v2_1.0_224.tflite", + "mobilenet_v3_small_100_224_embedder.tflite", "mozart_square.jpg", + "multi_objects.jpg", + "palm_detection_full.tflite", "right_hands.jpg", "segmentation_golden_rotation0.png", "segmentation_input_rotation0.jpg", @@ -64,6 +69,7 @@ filegroup( name = "test_images", srcs = [ "burger.jpg", + "burger_crop.jpg", "cat.jpg", "cat_mask.jpg", "cats_and_dogs.jpg", @@ -72,6 +78,7 @@ filegroup( "hand_landmark_lite.tflite", "left_hands.jpg", "mozart_square.jpg", + "multi_objects.jpg", "right_hands.jpg", "segmentation_golden_rotation0.png", "segmentation_input_rotation0.jpg", @@ -86,7 +93,7 @@ filegroup( srcs = [ "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite", - "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_score_calibration.tflite", + "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite", "deeplabv3.tflite", "hand_landmark_full.tflite", "hand_landmark_lite.tflite", @@ -94,8 +101,11 @@ filegroup( "mobilenet_v1_0.25_224_1_default_1.tflite", "mobilenet_v1_0.25_224_1_metadata_1.tflite", "mobilenet_v1_0.25_224_quant.tflite", + "mobilenet_v1_0.25_224_quant_with_dummy_score_calibration.tflite", "mobilenet_v1_0.25_224_quant_without_subgraph_metadata.tflite", "mobilenet_v2_1.0_224.tflite", + "mobilenet_v3_small_100_224_embedder.tflite", + "palm_detection_full.tflite", "selfie_segm_128_128_3.tflite", "selfie_segm_144_256_3.tflite", ], @@ -108,6 +118,8 @@ filegroup( "expected_left_up_hand_landmarks.prototxt", 
"expected_right_down_hand_landmarks.prototxt", "expected_right_up_hand_landmarks.prototxt", + "hand_detector_result_one_hand.pbtxt", + "hand_detector_result_two_hands.pbtxt", "pointing_up_landmarks.pbtxt", "thumb_up_landmarks.pbtxt", ], diff --git a/mediapipe/tasks/testdata/vision/hand_detector_result_one_hand.pbtxt b/mediapipe/tasks/testdata/vision/hand_detector_result_one_hand.pbtxt new file mode 100644 index 000000000..41b2f5584 --- /dev/null +++ b/mediapipe/tasks/testdata/vision/hand_detector_result_one_hand.pbtxt @@ -0,0 +1,33 @@ +detections { + label: "Palm" + score: 0.9694994 + location_data { + format: RELATIVE_BOUNDING_BOX + relative_bounding_box { + xmin: 0.64567286 + ymin: 0.04196969 + width: 0.22876495 + height: 0.43088135 + } + } +} +detections { + label: "Palm" + score: 0.9691078 + location_data { + format: RELATIVE_BOUNDING_BOX + relative_bounding_box { + xmin: 0.1332045 + ymin: 0.53324974 + width: 0.22739677 + height: 0.42832413 + } + } +} +hand_rects { + x_center: 0.6807422 + y_center: 0.41254658 + height: 1.121068 + width: 0.59478885 + rotation: -2.374725 +} diff --git a/mediapipe/tasks/testdata/vision/hand_detector_result_two_hands.pbtxt b/mediapipe/tasks/testdata/vision/hand_detector_result_two_hands.pbtxt new file mode 100644 index 000000000..ba46fc666 --- /dev/null +++ b/mediapipe/tasks/testdata/vision/hand_detector_result_two_hands.pbtxt @@ -0,0 +1,40 @@ +detections { + label: "Palm" + score: 0.9694994 + location_data { + format: RELATIVE_BOUNDING_BOX + relative_bounding_box { + xmin: 0.64567286 + ymin: 0.04196969 + width: 0.22876495 + height: 0.43088135 + } + } +} +detections { + label: "Palm" + score: 0.9691078 + location_data { + format: RELATIVE_BOUNDING_BOX + relative_bounding_box { + xmin: 0.1332045 + ymin: 0.53324974 + width: 0.22739677 + height: 0.42832413 + } + } +} +hand_rects { + x_center: 0.32884726 + y_center: 0.5990523 + height: 1.1143632 + width: 0.5912316 + rotation: 0.80550915 +} +hand_rects { + x_center: 0.6807422 + y_center: 0.41254658 + height: 1.121068 + width: 0.59478885 + rotation: -2.374725 +} diff --git a/mediapipe/util/sequence/media_sequence.h b/mediapipe/util/sequence/media_sequence.h index 6b80c519f..620d6d483 100644 --- a/mediapipe/util/sequence/media_sequence.h +++ b/mediapipe/util/sequence/media_sequence.h @@ -203,6 +203,10 @@ const char kClipLabelIndexKey[] = "clip/label/index"; const char kClipLabelStringKey[] = "clip/label/string"; // A list of label confidences for this clip. const char kClipLabelConfidenceKey[] = "clip/label/confidence"; +// A list of label start timestamps for this clip. +const char kClipLabelStartTimestampKey[] = "clip/label/start/timestamp"; +// A list of label end timestamps for this clip. 
+const char kClipLabelEndTimestampKey[] = "clip/label/end/timestamp"; BYTES_CONTEXT_FEATURE(ExampleId, kExampleIdKey); BYTES_CONTEXT_FEATURE(ExampleDatasetName, kExampleDatasetNameKey); @@ -220,6 +224,9 @@ INT64_CONTEXT_FEATURE(ClipEndTimestamp, kClipEndTimestampKey); VECTOR_BYTES_CONTEXT_FEATURE(ClipLabelString, kClipLabelStringKey); VECTOR_INT64_CONTEXT_FEATURE(ClipLabelIndex, kClipLabelIndexKey); VECTOR_FLOAT_CONTEXT_FEATURE(ClipLabelConfidence, kClipLabelConfidenceKey); +VECTOR_INT64_CONTEXT_FEATURE(ClipLabelStartTimestamp, + kClipLabelStartTimestampKey); +VECTOR_INT64_CONTEXT_FEATURE(ClipLabelEndTimestamp, kClipLabelEndTimestampKey); // *********************** SEGMENTS ************************************* // Context Keys: diff --git a/mediapipe/util/sequence/media_sequence.py b/mediapipe/util/sequence/media_sequence.py index 9aea821eb..1b96383d6 100644 --- a/mediapipe/util/sequence/media_sequence.py +++ b/mediapipe/util/sequence/media_sequence.py @@ -188,6 +188,10 @@ CLIP_LABEL_INDEX_KEY = "clip/label/index" CLIP_LABEL_STRING_KEY = "clip/label/string" # A list of label confidences for this clip. CLIP_LABEL_CONFIDENCE_KEY = "clip/label/confidence" +# A list of label start timestamps for this clip. +CLIP_LABEL_START_TIMESTAMP_KEY = "clip/label/start/timestamp" +# A list of label end timestamps for this clip. +CLIP_LABEL_END_TIMESTAMP_KEY = "clip/label/end/timestamp" msu.create_bytes_context_feature( "example_id", EXAMPLE_ID_KEY, module_dict=globals()) msu.create_bytes_context_feature( @@ -218,6 +222,14 @@ msu.create_int_list_context_feature( "clip_label_index", CLIP_LABEL_INDEX_KEY, module_dict=globals()) msu.create_float_list_context_feature( "clip_label_confidence", CLIP_LABEL_CONFIDENCE_KEY, module_dict=globals()) +msu.create_int_list_context_feature( + "clip_label_start_timestamp", + CLIP_LABEL_START_TIMESTAMP_KEY, + module_dict=globals()) +msu.create_int_list_context_feature( + "clip_label_end_timestamp", + CLIP_LABEL_END_TIMESTAMP_KEY, + module_dict=globals()) ################################## SEGMENTS ################################# # A list of segment start times in microseconds. 
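The new keys are parallel lists alongside the existing clip label string/index/confidence lists. A sketch of how they could be populated from Python, assuming the setter names follow the usual pattern generated by the create_*_context_feature helpers and that the timestamps are in microseconds like the other clip timestamp keys (both assumptions):

import tensorflow as tf
from mediapipe.util.sequence import media_sequence as ms

example = tf.train.SequenceExample()
ms.set_clip_label_string([b"run", b"jump"], example)
ms.set_clip_label_index([4, 11], example)
ms.set_clip_label_confidence([0.9, 0.7], example)
# New parallel lists: when each labeled activity starts and ends within the clip.
ms.set_clip_label_start_timestamp([0, 2000000], example)
ms.set_clip_label_end_timestamp([1500000, 3500000], example)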
diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl
index 33230fe26..e246bbd8d 100644
--- a/third_party/external_files.bzl
+++ b/third_party/external_files.bzl
@@ -13,7 +13,7 @@ def external_files():
     http_file(
         name = "com_google_mediapipe_30k-clean_model",
         sha256 = "fefb02b667a6c5c2fe27602d28e5fb3428f66ab89c7d6f388e7c8d44a02d0336",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/30k-clean.model?generation=1661875643984613"],
+        urls = ["https://storage.googleapis.com/mediapipe-assets/30k-clean.model?generation=1663006350848402"],
     )
 
     http_file(
@@ -23,9 +23,9 @@ def external_files():
     )
 
     http_file(
-        name = "com_google_mediapipe_bert_nl_classifier_tflite",
+        name = "com_google_mediapipe_bert_text_classifier_tflite",
         sha256 = "1e5a550c09bff0a13e61858bcfac7654d7fcc6d42106b4f15e11117695069600",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/bert_nl_classifier.tflite?generation=1661875658827092"],
+        urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier.tflite?generation=1663009542017720"],
     )
 
     http_file(
@@ -34,6 +34,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=1661875663693976"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_burger_crop_jpg",
+        sha256 = "8f58de573f0bf59a49c3d86cfabb9ad4061481f574aa049177e8da3963dddc50",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/burger_crop.jpg?generation=1664184735043119"],
+    )
+
     http_file(
         name = "com_google_mediapipe_burger_jpg",
         sha256 = "97c15bbbf3cf3615063b1031c85d669de55839f59262bbe145d15ca75b36ecbf",
@@ -70,18 +76,18 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite?generation=1661875692679200"],
     )
 
-    http_file(
-        name = "com_google_mediapipe_coco_ssd_mobilenet_v1_1_0_quant_2018_06_29_score_calibration_tflite",
-        sha256 = "072b44c01f35ba4274adfab69bd8b0f21e7481168782279105426a25b6da5d4a",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_score_calibration.tflite?generation=1661875697443279"],
-    )
-
     http_file(
         name = "com_google_mediapipe_coco_ssd_mobilenet_v1_1_0_quant_2018_06_29_tflite",
         sha256 = "61d598093ed03ed41aa47c3a39a28ac01e960d6a810a5419b9a5016a1e9c469b",
         urls = ["https://storage.googleapis.com/mediapipe-assets/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite?generation=1661875702588267"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_coco_ssd_mobilenet_v1_1_0_quant_2018_06_29_with_dummy_score_calibration_tflite",
+        sha256 = "81b2681e3631c3813769396ff914a8f333b191fefcd8c61297fd165bc81e7e19",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite?generation=1662653237233967"],
+    )
+
     http_file(
         name = "com_google_mediapipe_corrupted_mobilenet_v1_0_25_224_1_default_1_tflite",
         sha256 = "f0cbeb8061f4c693e20de779ce255af923508492e8a24f6db320845a52facb51",
@@ -166,6 +172,18 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/hair_segmentation.tflite?generation=1661875756623461"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_hand_detector_result_one_hand_pbtxt",
+        sha256 = "4b2deb84992bbfe68e3409d2b76914960d1c65aa6edd4524ff3455ca489df5f1",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_one_hand.pbtxt?generation=1662745351291628"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_hand_detector_result_two_hands_pbtxt",
+        sha256 = "2589cb08b0ee027dc24649fe597adcfa2156a21d12ea2480f83832714ebdf95f",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_two_hands.pbtxt?generation=1662745353586157"],
+    )
+
     http_file(
         name = "com_google_mediapipe_hand_landmark_full_tflite",
         sha256 = "11c272b891e1a99ab034208e23937a8008388cf11ed2a9d776ed3d01d0ba00e3",
@@ -286,6 +304,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v1_0.25_224_quant.tflite?generation=1661875831485992"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_mobilenet_v1_0_25_224_quant_with_dummy_score_calibration_tflite",
+        sha256 = "1fc6578a8f85f1f0454af6d908fba897fe17500c921e4d79434395abfb0e92f1",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v1_0.25_224_quant_with_dummy_score_calibration.tflite?generation=1662650659741978"],
+    )
+
     http_file(
         name = "com_google_mediapipe_mobilenet_v1_0_25_224_quant_without_subgraph_metadata_tflite",
         sha256 = "78f8b9bb5c873d3ad53ffc03b27651213016e45b6a2df42010c93191293bf694",
@@ -298,6 +322,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224.tflite?generation=1661875840611150"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_mobilenet_v3_small_100_224_embedder_tflite",
+        sha256 = "f7b9a563cb803bdcba76e8c7e82abde06f5c7a8e67b5e54e43e23095dfe79a78",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v3_small_100_224_embedder.tflite?generation=1664184739429109"],
+    )
+
     http_file(
         name = "com_google_mediapipe_mobile_object_classifier_v0_2_3-metadata-no-name_tflite",
         sha256 = "27fdb2dce68b8bd9a0f16583eefc4df13605808c1417cec268d1e838920c1a81",
@@ -322,6 +352,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/mozart_square.jpg?generation=1661875853838871"],
    )
 
+    http_file(
+        name = "com_google_mediapipe_multi_objects_jpg",
+        sha256 = "ada6e36b40519cf0a4fbdf1b535de7fa7d0c472f2c0a08ada8ee5728e16c0c68",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/multi_objects.jpg?generation=1663251779213308"],
+    )
+
     http_file(
         name = "com_google_mediapipe_object_detection_3d_camera_tflite",
         sha256 = "f66e92e81ed3f4698f74d565a7668e016e2288ea92fb42938e33b778bd1e110d",
@@ -366,8 +402,8 @@ def external_files():
 
     http_file(
         name = "com_google_mediapipe_palm_detection_full_tflite",
-        sha256 = "2f25e740121983f68ffc05f99991d524dc0ea812134f6316a26125816941ee85",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/palm_detection_full.tflite?generation=1661875883244842"],
+        sha256 = "1b14e9422c6ad006cde6581a46c8b90dd573c07ab7f3934b5589e7cea3f89a54",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/palm_detection_full.tflite?generation=1662745358034050"],
     )
 
     http_file(
@@ -376,6 +412,18 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/palm_detection_lite.tflite?generation=1661875885885770"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_pointing_up_jpg",
+        sha256 = "ecf8ca2611d08fa25948a4fc10710af9120e88243a54da6356bacea17ff3e36e",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up.jpg?generation=1662650662527717"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_pointing_up_landmarks_pbtxt",
+        sha256 = "1255b6ba17b4ef7a9b3ce92c0a139e74fbcec272dc251b049b2f06732f9fed83",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_landmarks.pbtxt?generation=1662650664573638"],
+    )
+
     http_file(
         name = "com_google_mediapipe_pose_detection_tflite",
         sha256 = "a63c614bef30d35947f13be361820b1e4e3bec9cfeebf4d11216a18373108e85",
@@ -497,9 +545,9 @@ def external_files():
     )
 
     http_file(
-        name = "com_google_mediapipe_test_model_nl_classifier_with_regex_tokenizer_tflite",
+        name = "com_google_mediapipe_test_model_text_classifier_with_regex_tokenizer_tflite",
         sha256 = "cb12618d084b813cb7b90ceb39c9fe4b18dae4de9880b912cdcd4b577cd65b4f",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/test_model_nl_classifier_with_regex_tokenizer.tflite?generation=1661875953222362"],
+        urls = ["https://storage.googleapis.com/mediapipe-assets/test_model_text_classifier_with_regex_tokenizer.tflite?generation=1663009546758456"],
     )
 
     http_file(
@@ -514,6 +562,18 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/test_model_without_custom_op.tflite?generation=1661875959757731"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_thumb_up_jpg",
+        sha256 = "5d673c081ab13b8a1812269ff57047066f9c33c07db5f4178089e8cb3fdc0291",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up.jpg?generation=1662650667349746"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_thumb_up_landmarks_pbtxt",
+        sha256 = "bf1913df6ac7cc14b492c10411c827832839985c057b112789e04ce7c1fdd0fa",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_landmarks.pbtxt?generation=1662650669387278"],
+    )
+
     http_file(
         name = "com_google_mediapipe_two_heads_16000_hz_mono_wav",
         sha256 = "a291a9c22c39bba30138a26915e154a96286ba6ca3b413053123c504a58cce3b",
diff --git a/third_party/pffft.BUILD b/third_party/pffft.BUILD
new file mode 100644
index 000000000..3d6828e63
--- /dev/null
+++ b/third_party/pffft.BUILD
@@ -0,0 +1,19 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"])  # FFTPACKv5 License.
+
+exports_files(["LICENSE"])
+
+cc_library(
+    name = "pffft",
+    srcs = ["pffft.c"],
+    hdrs = ["pffft.h"],
+    copts = select({
+        "@bazel_tools//src/conditions:windows": [
+            "/D_USE_MATH_DEFINES",
+            "/W0",
+        ],
+        "//conditions:default": [],
+    }),
+    visibility = ["//visibility:public"],
+)