From 28b48a3f7d082c8728cb6f51d71af5fa4bfc7f87 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 29 Sep 2022 14:15:47 -0700 Subject: [PATCH 001/132] Replace Protobuf target in MediaPipe Tasks Java PiperOrigin-RevId: 477820107 --- mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD index 86dd080a1..8da7b8561 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD @@ -31,7 +31,7 @@ android_library( "//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", "//third_party:autovalue", - "//third_party/java/protobuf:protobuf_lite", + "@com_google_protobuf//:protobuf_javalite", "@maven//:com_google_guava_guava", ], ) From 2e8bec69d4a1e96577905f312dc0a9c0cde06e75 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 29 Sep 2022 21:59:42 +0000 Subject: [PATCH 002/132] Internal change PiperOrigin-RevId: 477830748 --- mediapipe/calculators/util/BUILD | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 2ed158f89..3a9ddc36f 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -143,9 +143,7 @@ mediapipe_proto_library( cc_library( name = "packet_frequency_calculator", srcs = ["packet_frequency_calculator.cc"], - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/util:packet_frequency_calculator_cc_proto", "//mediapipe/calculators/util:packet_frequency_cc_proto", @@ -190,9 +188,7 @@ cc_test( cc_library( name = "packet_latency_calculator", srcs = ["packet_latency_calculator.cc"], - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/util:latency_cc_proto", "//mediapipe/calculators/util:packet_latency_calculator_cc_proto", From e7acc0a857b333d485b02276a1fded0b2db1686c Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 29 Sep 2022 23:10:17 +0000 Subject: [PATCH 003/132] Remove reference to internal-only build rule PiperOrigin-RevId: 477846586 --- .../com/google/mediapipe/tasks/vision/objectdetector/BUILD | 2 -- 1 file changed, 2 deletions(-) diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/BUILD index 5a6522a25..a7f804c64 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/BUILD +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/BUILD @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-load("@build_bazel_rules_android//android:rules.bzl", "android_library_test") - package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) From 8af4cca4138e5637f1dc7f963c1202ac23a115c0 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 30 Sep 2022 00:33:19 +0000 Subject: [PATCH 004/132] Internal change PiperOrigin-RevId: 477863040 --- .../util/labels_to_render_data_calculator.cc | 11 +++++++++++ .../util/labels_to_render_data_calculator.proto | 7 +++++++ mediapipe/util/annotation_renderer.cc | 10 ++++++++++ mediapipe/util/render_data.proto | 6 ++++++ 4 files changed, 34 insertions(+) diff --git a/mediapipe/calculators/util/labels_to_render_data_calculator.cc b/mediapipe/calculators/util/labels_to_render_data_calculator.cc index 4aab3b676..dcd76d47b 100644 --- a/mediapipe/calculators/util/labels_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/labels_to_render_data_calculator.cc @@ -184,6 +184,17 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) { text->set_left(label_left_px_); text->set_baseline(label_baseline_px + i * label_height_px_); text->set_font_face(options_.font_face()); + if (options_.outline_thickness() > 0) { + text->set_outline_thickness(options_.outline_thickness()); + if (options_.outline_color_size() > 0) { + *(text->mutable_outline_color()) = + options_.outline_color(i % options_.outline_color_size()); + } else { + text->mutable_outline_color()->set_r(0); + text->mutable_outline_color()->set_g(0); + text->mutable_outline_color()->set_b(0); + } + } } cc->Outputs() .Tag(kRenderDataTag) diff --git a/mediapipe/calculators/util/labels_to_render_data_calculator.proto b/mediapipe/calculators/util/labels_to_render_data_calculator.proto index cf0ada9c2..7946ff683 100644 --- a/mediapipe/calculators/util/labels_to_render_data_calculator.proto +++ b/mediapipe/calculators/util/labels_to_render_data_calculator.proto @@ -30,6 +30,13 @@ message LabelsToRenderDataCalculatorOptions { // Thickness for drawing the label(s). optional double thickness = 2 [default = 2]; + // Color of outline around each character, if any. One per label, as with + // color attribute. + repeated Color outline_color = 12; + + // Thickness of outline around each character. + optional double outline_thickness = 11; + // The font height in absolute pixels. 
optional int32 font_height_px = 3 [default = 50]; diff --git a/mediapipe/util/annotation_renderer.cc b/mediapipe/util/annotation_renderer.cc index 19fbbc14d..671f47505 100644 --- a/mediapipe/util/annotation_renderer.cc +++ b/mediapipe/util/annotation_renderer.cc @@ -552,6 +552,16 @@ void AnnotationRenderer::DrawText(const RenderAnnotation& annotation) { origin.y += text_size.height / 2; } + if (text.outline_thickness() > 0.0) { + const int background_thickness = ClampThickness( + round((annotation.thickness() + 2.0 * text.outline_thickness()) * + scale_factor_)); + const cv::Scalar outline_color = + MediapipeColorToOpenCVColor(text.outline_color()); + cv::putText(mat_image_, text.display_text(), origin, font_face, font_scale, + outline_color, background_thickness, /*lineType=*/8, + /*bottomLeftOrigin=*/flip_text_vertically_); + } cv::putText(mat_image_, text.display_text(), origin, font_face, font_scale, color, thickness, /*lineType=*/8, /*bottomLeftOrigin=*/flip_text_vertically_); diff --git a/mediapipe/util/render_data.proto b/mediapipe/util/render_data.proto index 0ff6b3409..62cb750b0 100644 --- a/mediapipe/util/render_data.proto +++ b/mediapipe/util/render_data.proto @@ -168,6 +168,12 @@ message RenderAnnotation { // [left, baseline] represent [center_x, center_y]. optional bool center_horizontally = 7 [default = false]; optional bool center_vertically = 8 [default = false]; + + // Thickness of the text outline. + optional double outline_thickness = 11 [default = 0.0]; + + // Color of the text outline. + optional Color outline_color = 12; } // The RenderAnnotation can be one of the below formats. From 382158298bc09dbe387043b4bc31715fdd881d10 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 30 Sep 2022 01:43:38 +0000 Subject: [PATCH 005/132] Update default Tflite model OpResolver in BaseOptions. 
PiperOrigin-RevId: 477873299 --- mediapipe/tasks/cc/core/BUILD | 16 +++++++++ mediapipe/tasks/cc/core/base_options.h | 3 +- .../mediapipe_builtin_op_resolver.cc} | 12 +++---- .../mediapipe_builtin_op_resolver.h} | 18 +++++----- mediapipe/tasks/cc/vision/hand_detector/BUILD | 12 ------- .../hand_detector/hand_detector_graph_test.cc | 6 ++-- .../hand_detector_op_resolver.cc | 35 ------------------- .../hand_detector/hand_detector_op_resolver.h | 34 ------------------ .../tasks/cc/vision/image_segmenter/BUILD | 16 --------- .../vision/image_segmenter/image_segmenter.h | 1 - .../image_segmenter/image_segmenter_test.cc | 5 --- 11 files changed, 33 insertions(+), 125 deletions(-) rename mediapipe/tasks/cc/{vision/image_segmenter/image_segmenter_op_resolvers.cc => core/mediapipe_builtin_op_resolver.cc} (87%) rename mediapipe/tasks/cc/{vision/image_segmenter/image_segmenter_op_resolvers.h => core/mediapipe_builtin_op_resolver.h} (65%) delete mode 100644 mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.cc delete mode 100644 mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index 38030c525..8d19227f1 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -23,6 +23,7 @@ cc_library( srcs = ["base_options.cc"], hdrs = ["base_options.h"], deps = [ + ":mediapipe_builtin_op_resolver", "//mediapipe/calculators/tensor:inference_calculator_cc_proto", "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", @@ -50,6 +51,21 @@ cc_library( ], ) +cc_library( + name = "mediapipe_builtin_op_resolver", + srcs = ["mediapipe_builtin_op_resolver.cc"], + hdrs = ["mediapipe_builtin_op_resolver.h"], + deps = [ + "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix", + "//mediapipe/util/tflite/operations:max_pool_argmax", + "//mediapipe/util/tflite/operations:max_unpooling", + "//mediapipe/util/tflite/operations:transform_landmarks", + "//mediapipe/util/tflite/operations:transform_tensor_bilinear", + "//mediapipe/util/tflite/operations:transpose_conv_bias", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], +) + # TODO: Switch to use cc_library_with_tflite after the MediaPipe InferenceCalculator # supports TFLite-in-GMSCore. cc_library( diff --git a/mediapipe/tasks/cc/core/base_options.h b/mediapipe/tasks/cc/core/base_options.h index 67a03385b..4717ea50e 100644 --- a/mediapipe/tasks/cc/core/base_options.h +++ b/mediapipe/tasks/cc/core/base_options.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/kernels/register.h" @@ -63,7 +64,7 @@ struct BaseOptions { // A non-default OpResolver to support custom Ops or specify a subset of // built-in Ops. std::unique_ptr op_resolver = - absl::make_unique(); + absl::make_unique(); }; // Converts a BaseOptions to a BaseOptionsProto. 
diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.cc b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc similarity index 87% rename from mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.cc rename to mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc index cd3b5690f..62898a005 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.cc +++ b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h" #include "mediapipe/util/tflite/operations/max_pool_argmax.h" @@ -21,14 +21,11 @@ limitations under the License. #include "mediapipe/util/tflite/operations/transform_landmarks.h" #include "mediapipe/util/tflite/operations/transform_tensor_bilinear.h" #include "mediapipe/util/tflite/operations/transpose_conv_bias.h" -#include "tensorflow/lite/kernels/register.h" namespace mediapipe { namespace tasks { -namespace vision { - -SelfieSegmentationModelOpResolver::SelfieSegmentationModelOpResolver() - : BuiltinOpResolver() { +namespace core { +MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() { AddCustom("MaxPoolingWithArgmax2D", mediapipe::tflite_operations::RegisterMaxPoolingWithArgmax2D()); AddCustom("MaxUnpooling2D", @@ -46,7 +43,6 @@ SelfieSegmentationModelOpResolver::SelfieSegmentationModelOpResolver() mediapipe::tflite_operations::RegisterLandmarksToTransformMatrixV2(), /*version=*/2); } - -} // namespace vision +} // namespace core } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h similarity index 65% rename from mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h rename to mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h index a0538a674..a7c28aa71 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h +++ b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h @@ -13,25 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_ -#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_ +#ifndef MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_ +#define MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_ #include "tensorflow/lite/kernels/register.h" namespace mediapipe { namespace tasks { -namespace vision { - -class SelfieSegmentationModelOpResolver +namespace core { +class MediaPipeBuiltinOpResolver : public tflite::ops::builtin::BuiltinOpResolver { public: - SelfieSegmentationModelOpResolver(); - SelfieSegmentationModelOpResolver( - const SelfieSegmentationModelOpResolver& r) = delete; + MediaPipeBuiltinOpResolver(); + MediaPipeBuiltinOpResolver(const MediaPipeBuiltinOpResolver& r) = delete; }; -} // namespace vision +} // namespace core } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_ +#endif // MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_ diff --git a/mediapipe/tasks/cc/vision/hand_detector/BUILD b/mediapipe/tasks/cc/vision/hand_detector/BUILD index 23cf5f72d..c87cc50a6 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/BUILD +++ b/mediapipe/tasks/cc/vision/hand_detector/BUILD @@ -18,18 +18,6 @@ package(default_visibility = [ licenses(["notice"]) -cc_library( - name = "hand_detector_op_resolver", - srcs = ["hand_detector_op_resolver.cc"], - hdrs = ["hand_detector_op_resolver.h"], - deps = [ - "//mediapipe/util/tflite/operations:max_pool_argmax", - "//mediapipe/util/tflite/operations:max_unpooling", - "//mediapipe/util/tflite/operations:transpose_conv_bias", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", - ], -) - cc_library( name = "hand_detector_graph", srcs = ["hand_detector_graph.cc"], diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc index a2fbd7c54..3fa97664e 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc @@ -35,11 +35,11 @@ limitations under the License. 
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" -#include "mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" @@ -121,8 +121,8 @@ absl::StatusOr> CreateTaskRunner( hand_detection.Out(kHandNormRectsTag).SetName(kHandNormRectsName) >> graph[Output>(kHandNormRectsTag)]; - return TaskRunner::Create(graph.GetConfig(), - absl::make_unique()); + return TaskRunner::Create( + graph.GetConfig(), std::make_unique()); } HandDetectorResult GetExpectedHandDetectorResult(absl::string_view file_name) { diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.cc deleted file mode 100644 index 262fb2c75..000000000 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.cc +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h" - -#include "mediapipe/util/tflite/operations/max_pool_argmax.h" -#include "mediapipe/util/tflite/operations/max_unpooling.h" -#include "mediapipe/util/tflite/operations/transpose_conv_bias.h" - -namespace mediapipe { -namespace tasks { -namespace vision { -HandDetectorOpResolver::HandDetectorOpResolver() { - AddCustom("MaxPoolingWithArgmax2D", - mediapipe::tflite_operations::RegisterMaxPoolingWithArgmax2D()); - AddCustom("MaxUnpooling2D", - mediapipe::tflite_operations::RegisterMaxUnpooling2D()); - AddCustom("Convolution2DTransposeBias", - mediapipe::tflite_operations::RegisterConvolution2DTransposeBias()); -} -} // namespace vision -} // namespace tasks -} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h deleted file mode 100644 index a55661fa3..000000000 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_ -#define MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_ - -#include "tensorflow/lite/kernels/register.h" - -namespace mediapipe { -namespace tasks { -namespace vision { -class HandDetectorOpResolver : public tflite::ops::builtin::BuiltinOpResolver { - public: - HandDetectorOpResolver(); - HandDetectorOpResolver(const HandDetectorOpResolver& r) = delete; -}; - -} // namespace vision -} // namespace tasks -} // namespace mediapipe - -#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_ diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 6af733657..6bdbf41da 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -33,7 +33,6 @@ cc_library( "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", ], ) @@ -73,19 +72,4 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "image_segmenter_op_resolvers", - srcs = ["image_segmenter_op_resolvers.cc"], - hdrs = ["image_segmenter_op_resolvers.h"], - deps = [ - "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix", - "//mediapipe/util/tflite/operations:max_pool_argmax", - "//mediapipe/util/tflite/operations:max_unpooling", - "//mediapipe/util/tflite/operations:transform_landmarks", - "//mediapipe/util/tflite/operations:transform_tensor_bilinear", - "//mediapipe/util/tflite/operations:transpose_conv_bias", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", - ], -) - # TODO: This test fails in OSS diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index ce9cb104c..e2734c4e4 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -26,7 +26,6 @@ limitations under the License. #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" -#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/kernels/register.h" namespace mediapipe { diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index 2f1c26a79..1d3f3e786 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -31,7 +31,6 @@ limitations under the License. 
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" -#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" @@ -260,8 +259,6 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata); - options->base_options.op_resolver = - absl::make_unique(); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; options->activation = ImageSegmenterOptions::Activation::SOFTMAX; @@ -290,8 +287,6 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata); - options->base_options.op_resolver = - absl::make_unique(); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; options->activation = ImageSegmenterOptions::Activation::NONE; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, From 133c3b3c00deed45b396e286134e8ab70662bc0a Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 30 Sep 2022 03:26:32 +0000 Subject: [PATCH 006/132] Internal change PiperOrigin-RevId: 477887963 --- mediapipe/gpu/gl_texture_buffer.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index 69d2fab7a..7d095a5d4 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -185,7 +185,10 @@ void GlTextureBuffer::Updated(std::shared_ptr prod_token) { << "Updated existing texture which had not been marked for reuse!"; CHECK(prod_token); producer_sync_ = std::move(prod_token); - producer_context_ = producer_sync_->GetContext(); + const auto& synced_context = producer_sync_->GetContext(); + if (synced_context) { + producer_context_ = synced_context; + } } void GlTextureBuffer::DidRead(std::shared_ptr cons_token) const { From 3a3a470a0c7e1e1346b53f0d90b35d7881b305a0 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 30 Sep 2022 05:13:59 +0000 Subject: [PATCH 007/132] Move hand_association_calculator to open source MP PiperOrigin-RevId: 477901001 --- .../vision/hand_landmarker/calculators/BUILD | 49 +++ .../hand_association_calculator.cc | 125 ++++++++ .../hand_association_calculator.proto | 28 ++ .../hand_association_calculator_test.cc | 302 ++++++++++++++++++ 4 files changed, 504 insertions(+) create mode 100644 mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD create mode 100644 mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc create mode 100644 mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.proto create mode 100644 mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD new file mode 100644 index 000000000..dea81bae3 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD @@ -0,0 +1,49 @@ +# Copyright 2022 The MediaPipe 
Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +package(default_visibility = [ + "//mediapipe/app/xeno:__subpackages__", + "//mediapipe/tasks:internal", +]) + +licenses(["notice"]) + +mediapipe_proto_library( + name = "hand_association_calculator_proto", + srcs = ["hand_association_calculator.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "hand_association_calculator", + srcs = ["hand_association_calculator.cc"], + deps = [ + ":hand_association_calculator_cc_proto", + "//mediapipe/calculators/util:association_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:rectangle", + "//mediapipe/framework/port:status", + "//mediapipe/util:rectangle_util", + ], + alwayslink = 1, +) + +# TODO: Enable this test diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc new file mode 100644 index 000000000..b6df80588 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc @@ -0,0 +1,125 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/rectangle.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.pb.h" +#include "mediapipe/util/rectangle_util.h" + +namespace mediapipe::api2 { + +// HandAssociationCalculator accepts multiple inputs of vectors of +// NormalizedRect. The output is a vector of NormalizedRect that contains +// rects from the input vectors that don't overlap with each other. When two +// rects overlap, the rect that comes in from an earlier input stream is +// kept in the output. If a rect has no ID (i.e. from detection stream), +// then a unique rect ID is assigned for it. + +// The rects in multiple input streams are effectively flattened to a single +// list. 
For example: +// Stream1 : rect 1, rect 2 +// Stream2: rect 3, rect 4 +// Stream3: rect 5, rect 6 +// (Conceptually) flattened list : rect 1, 2, 3, 4, 5, 6 +// In the flattened list, if a rect with a higher index overlaps with a rect a +// lower index, beyond a specified IOU threshold, the rect with the lower +// index will be in the output, and the rect with higher index will be +// discarded. +// TODO: Upgrade this to latest API for calculators +class HandAssociationCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc) { + // Initialize input and output streams. + for (auto& input_stream : cc->Inputs()) { + input_stream.Set>(); + } + cc->Outputs().Index(0).Set>(); + + return absl::OkStatus(); + } + + absl::Status Open(CalculatorContext* cc) override { + cc->SetOffset(TimestampDiff(0)); + + options_ = cc->Options(); + CHECK_GT(options_.min_similarity_threshold(), 0.0); + CHECK_LE(options_.min_similarity_threshold(), 1.0); + + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) override { + ASSIGN_OR_RETURN(auto result, GetNonOverlappingElements(cc)); + + auto output = + std::make_unique>(std::move(result)); + cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); + + return absl::OkStatus(); + } + + private: + HandAssociationCalculatorOptions options_; + + // Return a list of non-overlapping elements from all input streams, with + // decreasing order of priority based on input stream index and indices + // within an input stream. + absl::StatusOr> GetNonOverlappingElements( + CalculatorContext* cc) { + std::vector result; + + for (const auto& input_stream : cc->Inputs()) { + if (input_stream.IsEmpty()) { + continue; + } + + for (auto rect : input_stream.Get>()) { + ASSIGN_OR_RETURN( + bool is_overlapping, + mediapipe::DoesRectOverlap(rect, result, + options_.min_similarity_threshold())); + if (!is_overlapping) { + if (!rect.has_rect_id()) { + rect.set_rect_id(GetNextRectId()); + } + result.push_back(rect); + } + } + } + + return result; + } + + private: + // Each NormalizedRect processed by the calculator will be assigned + // an unique id, if it does not already have an ID. The starting ID will be 1. + // Note: This rect_id_ is local to an instance of this calculator. And it is + // expected that the hand tracking graph to have only one instance of + // this association calculator. + int64 rect_id_ = 1; + + inline int GetNextRectId() { return rect_id_++; } +}; + +MEDIAPIPE_REGISTER_NODE(HandAssociationCalculator); + +} // namespace mediapipe::api2 diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.proto b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.proto new file mode 100644 index 000000000..e7229b4a2 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.proto @@ -0,0 +1,28 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message HandAssociationCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional HandAssociationCalculatorOptions ext = 408244367; + } + + optional float min_similarity_threshold = 1 [default = 1.0]; +} diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc new file mode 100644 index 000000000..cb3130854 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc @@ -0,0 +1,302 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" + +namespace mediapipe { +namespace { + +class HandAssociationCalculatorTest : public testing::Test { + protected: + HandAssociationCalculatorTest() { + // 0.4 ================ + // | | | | + // 0.3 ===================== | NR2 | | + // | | | NR1 | | | NR4 | + // 0.2 | NR0 | =========== ================ + // | | | | | | + // 0.1 =====|=============== | + // | NR3 | | | + // 0.0 ================ | + // | NR5 | + // -0.1 =========== + // 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1 1.2 + + // NormalizedRect nr_0. + nr_0_.set_x_center(0.2); + nr_0_.set_y_center(0.2); + nr_0_.set_width(0.2); + nr_0_.set_height(0.2); + + // NormalizedRect nr_1. + nr_1_.set_x_center(0.4); + nr_1_.set_y_center(0.2); + nr_1_.set_width(0.2); + nr_1_.set_height(0.2); + + // NormalizedRect nr_2. + nr_2_.set_x_center(1.0); + nr_2_.set_y_center(0.3); + nr_2_.set_width(0.2); + nr_2_.set_height(0.2); + + // NormalizedRect nr_3. + nr_3_.set_x_center(0.35); + nr_3_.set_y_center(0.15); + nr_3_.set_width(0.3); + nr_3_.set_height(0.3); + + // NormalizedRect nr_4. + nr_4_.set_x_center(1.1); + nr_4_.set_y_center(0.3); + nr_4_.set_width(0.2); + nr_4_.set_height(0.2); + + // NormalizedRect nr_5. + nr_5_.set_x_center(0.5); + nr_5_.set_y_center(0.05); + nr_5_.set_width(0.3); + nr_5_.set_height(0.3); + } + + NormalizedRect nr_0_, nr_1_, nr_2_, nr_3_, nr_4_, nr_5_; +}; + +TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "HandAssociationCalculator" + input_stream: "input_vec_0" + input_stream: "input_vec_1" + input_stream: "input_vec_2" + output_stream: "output_vec" + options { + [mediapipe.HandAssociationCalculatorOptions.ext] { + min_similarity_threshold: 0.1 + } + } + )pb")); + + // Input Stream 0: nr_0, nr_1, nr_2. 
+ auto input_vec_0 = std::make_unique>(); + input_vec_0->push_back(nr_0_); + input_vec_0->push_back(nr_1_); + input_vec_0->push_back(nr_2_); + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(input_vec_0.release()).At(Timestamp(1))); + + // Input Stream 1: nr_3, nr_4. + auto input_vec_1 = std::make_unique>(); + input_vec_1->push_back(nr_3_); + input_vec_1->push_back(nr_4_); + runner.MutableInputs()->Index(1).packets.push_back( + Adopt(input_vec_1.release()).At(Timestamp(1))); + + // Input Stream 2: nr_5. + auto input_vec_2 = std::make_unique>(); + input_vec_2->push_back(nr_5_); + runner.MutableInputs()->Index(2).packets.push_back( + Adopt(input_vec_2.release()).At(Timestamp(1))); + + MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = runner.Outputs().Index(0).packets; + EXPECT_EQ(1, output.size()); + auto assoc_rects = output[0].Get>(); + + // Rectangles are added in the following sequence: + // nr_0 is added 1st. + // nr_1 is added because it does not overlap with nr_0. + // nr_2 is added because it does not overlap with nr_0 or nr_1. + // nr_3 is NOT added because it overlaps with nr_0. + // nr_4 is NOT added because it overlaps with nr_2. + // nr_5 is NOT added because it overlaps with nr_1. + EXPECT_EQ(3, assoc_rects.size()); + + // Check that IDs are filled in and contents match. + EXPECT_EQ(assoc_rects[0].rect_id(), 1); + assoc_rects[0].clear_rect_id(); + EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_)); + + EXPECT_EQ(assoc_rects[1].rect_id(), 2); + assoc_rects[1].clear_rect_id(); + EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_)); + + EXPECT_EQ(assoc_rects[2].rect_id(), 3); + assoc_rects[2].clear_rect_id(); + EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_)); +} + +TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "HandAssociationCalculator" + input_stream: "input_vec_0" + input_stream: "input_vec_1" + input_stream: "input_vec_2" + output_stream: "output_vec" + options { + [mediapipe.HandAssociationCalculatorOptions.ext] { + min_similarity_threshold: 0.1 + } + } + )pb")); + + // Input Stream 0: nr_0, nr_1. Tracked hands. + auto input_vec_0 = std::make_unique>(); + // Setting ID to a negative number for test only, since newly generated + // ID by HandAssociationCalculator are positive numbers. + nr_0_.set_rect_id(-2); + input_vec_0->push_back(nr_0_); + nr_1_.set_rect_id(-1); + input_vec_0->push_back(nr_1_); + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(input_vec_0.release()).At(Timestamp(1))); + + // Input Stream 1: nr_2, nr_3. Newly detected palms. + auto input_vec_1 = std::make_unique>(); + input_vec_1->push_back(nr_2_); + input_vec_1->push_back(nr_3_); + runner.MutableInputs()->Index(1).packets.push_back( + Adopt(input_vec_1.release()).At(Timestamp(1))); + + MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = runner.Outputs().Index(0).packets; + EXPECT_EQ(1, output.size()); + auto assoc_rects = output[0].Get>(); + + // Rectangles are added in the following sequence: + // nr_0 is added 1st. + // nr_1 is added because it does not overlap with nr_0. + // nr_2 is added because it does not overlap with nr_0 or nr_1. + // nr_3 is NOT added because it overlaps with nr_0. + EXPECT_EQ(3, assoc_rects.size()); + + // Check that IDs are filled in and contents match. 
+ EXPECT_EQ(assoc_rects[0].rect_id(), -2); + EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_)); + + EXPECT_EQ(assoc_rects[1].rect_id(), -1); + EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_)); + + EXPECT_EQ(assoc_rects[2].rect_id(), 1); + assoc_rects[2].clear_rect_id(); + EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_)); +} + +TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "HandAssociationCalculator" + input_stream: "input_vec_0" + input_stream: "input_vec_1" + input_stream: "input_vec_2" + output_stream: "output_vec" + options { + [mediapipe.HandAssociationCalculatorOptions.ext] { + min_similarity_threshold: 0.1 + } + } + )pb")); + + // Input Stream 0: nr_5. + auto input_vec_0 = std::make_unique>(); + input_vec_0->push_back(nr_5_); + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(input_vec_0.release()).At(Timestamp(1))); + + // Input Stream 1: nr_4, nr_3 + auto input_vec_1 = std::make_unique>(); + input_vec_1->push_back(nr_4_); + input_vec_1->push_back(nr_3_); + runner.MutableInputs()->Index(1).packets.push_back( + Adopt(input_vec_1.release()).At(Timestamp(1))); + + // Input Stream 2: nr_2, nr_1, nr_0. + auto input_vec_2 = std::make_unique>(); + input_vec_2->push_back(nr_2_); + input_vec_2->push_back(nr_1_); + input_vec_2->push_back(nr_0_); + runner.MutableInputs()->Index(2).packets.push_back( + Adopt(input_vec_2.release()).At(Timestamp(1))); + + MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = runner.Outputs().Index(0).packets; + EXPECT_EQ(1, output.size()); + auto assoc_rects = output[0].Get>(); + + // Rectangles are added in the following sequence: + // nr_5 is added 1st. + // nr_4 is added because it does not overlap with nr_5. + // nr_3 is NOT added because it overlaps with nr_5. + // nr_2 is NOT added because it overlaps with nr_4. + // nr_1 is NOT added because it overlaps with nr_5. + // nr_0 is added because it does not overlap with nr_5 or nr_4. + EXPECT_EQ(3, assoc_rects.size()); + + // Outputs are in same order as inputs, and IDs are filled in. + EXPECT_EQ(assoc_rects[0].rect_id(), 1); + assoc_rects[0].clear_rect_id(); + EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_5_)); + + EXPECT_EQ(assoc_rects[1].rect_id(), 2); + assoc_rects[1].clear_rect_id(); + EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_4_)); + + EXPECT_EQ(assoc_rects[2].rect_id(), 3); + assoc_rects[2].clear_rect_id(); + EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_0_)); +} + +TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "HandAssociationCalculator" + input_stream: "input_vec" + output_stream: "output_vec" + options { + [mediapipe.HandAssociationCalculatorOptions.ext] { + min_similarity_threshold: 0.1 + } + } + )pb")); + + // Just one input stream : nr_3, nr_5. + auto input_vec = std::make_unique>(); + input_vec->push_back(nr_3_); + input_vec->push_back(nr_5_); + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(input_vec.release()).At(Timestamp(1))); + + MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = runner.Outputs().Index(0).packets; + EXPECT_EQ(1, output.size()); + auto assoc_rects = output[0].Get>(); + + // Rectangles are added in the following sequence: + // nr_3 is added 1st. + // nr_5 is NOT added because it overlaps with nr_3. 
+ EXPECT_EQ(1, assoc_rects.size()); + + EXPECT_EQ(assoc_rects[0].rect_id(), 1); + assoc_rects[0].clear_rect_id(); + EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_3_)); +} + +} // namespace +} // namespace mediapipe From 3225372c28d9a7d541dc13f2dca8d99b78a8a13c Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 30 Sep 2022 08:10:33 +0000 Subject: [PATCH 008/132] Internal changes PiperOrigin-RevId: 477924417 --- mediapipe/python/image_test.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/mediapipe/python/image_test.py b/mediapipe/python/image_test.py index a21dbdadd..117d20974 100644 --- a/mediapipe/python/image_test.py +++ b/mediapipe/python/image_test.py @@ -187,16 +187,5 @@ class ImageTest(absltest.TestCase): gc.collect() self.assertEqual(sys.getrefcount(rgb_image), initial_ref_count) - def test_image_create_from_file(self): - image_path = os.path.join( - resources.GetRunfilesDir(), - 'mediapipe/tasks/testdata/vision/cat.jpg') - loaded_image = Image.create_from_file(image_path) - self.assertEqual(loaded_image.width, 600) - self.assertEqual(loaded_image.height, 400) - self.assertEqual(loaded_image.channels, 3) - self.assertEqual(loaded_image.image_format, ImageFormat.SRGB) - - if __name__ == '__main__': absltest.main() From 3816951b8cbfcd70a05f2368d72979d2dbac6f78 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 30 Sep 2022 08:17:54 +0000 Subject: [PATCH 009/132] Fix the comment. PiperOrigin-RevId: 477925532 --- .../tasks/vision/objectdetector/ObjectDetector.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java index c78d1baad..463ab4c43 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java @@ -292,7 +292,10 @@ public final class ObjectDetector extends BaseVisionTaskApi { */ public abstract Builder setRunningMode(RunningMode value); - /** Sets the maximum number of top-scored classification results to return. */ + /** + * Sets the locale to use for display names specified through the TFLite Model Metadata, if + * any. Defaults to English. + */ public abstract Builder setDisplayNamesLocale(String value); /** From af2ad1abbe38c07961410a9fe7d000fd15cfa78e Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 30 Sep 2022 08:32:42 +0000 Subject: [PATCH 010/132] Switch MediaPipe Tasks Python and Java base layer to use MediaPipeBuiltinOpResolver by default. 
PiperOrigin-RevId: 477927852 --- .../tasks/java/com/google/mediapipe/tasks/core/jni/BUILD | 2 +- .../java/com/google/mediapipe/tasks/core/jni/BUILD.bazel | 2 +- .../mediapipe/tasks/core/jni/model_resources_cache_jni.cc | 4 +++- mediapipe/tasks/python/core/pybind/BUILD | 1 + mediapipe/tasks/python/core/pybind/task_runner.cc | 3 ++- 5 files changed, 8 insertions(+), 4 deletions(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD index a1ec67517..cb3ef9656 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD @@ -28,10 +28,10 @@ cc_library_with_tflite( ], tflite_deps = [ "//mediapipe/tasks/cc/core:model_resources_cache", - "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", ], deps = [ "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + "//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver", ] + select({ "//conditions:default": ["//third_party/java/jdk:jni"], "//mediapipe:android": [], diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD.bazel b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD.bazel index fba314b28..0eb74e7ff 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD.bazel +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD.bazel @@ -34,10 +34,10 @@ cc_library_with_tflite( }), tflite_deps = [ "//mediapipe/tasks/cc/core:model_resources_cache", - "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", ], deps = [ "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + "//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver", ] + select({ "//conditions:default": [], "//mediapipe:android": [], diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/model_resources_cache_jni.cc b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/model_resources_cache_jni.cc index 74ff4a689..aab022dec 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/model_resources_cache_jni.cc +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/model_resources_cache_jni.cc @@ -17,11 +17,13 @@ #include #include "mediapipe/java/com/google/mediapipe/framework/jni/graph_service_jni.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h" #include "tensorflow/lite/core/shims/cc/kernels/register.h" namespace { using ::mediapipe::tasks::core::kModelResourcesCacheService; +using ::mediapipe::tasks::core::MediaPipeBuiltinOpResolver; using ::mediapipe::tasks::core::ModelResourcesCache; using HandleType = std::shared_ptr*; } // namespace @@ -29,7 +31,7 @@ using HandleType = std::shared_ptr*; JNIEXPORT jlong JNICALL MODEL_RESOURCES_CACHE_METHOD( nativeCreateModelResourcesCache)(JNIEnv* env, jobject thiz) { auto ptr = std::make_shared( - absl::make_unique()); + absl::make_unique()); HandleType handle = new std::shared_ptr(std::move(ptr)); return reinterpret_cast(handle); } diff --git a/mediapipe/tasks/python/core/pybind/BUILD b/mediapipe/tasks/python/core/pybind/BUILD index fab878135..b59635dc3 100644 --- a/mediapipe/tasks/python/core/pybind/BUILD +++ b/mediapipe/tasks/python/core/pybind/BUILD @@ -27,6 +27,7 @@ pybind_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/python/pybind:util", + 
"//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver", "//mediapipe/tasks/cc/core:task_runner", "@org_tensorflow//tensorflow/lite/core/api:op_resolver", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", diff --git a/mediapipe/tasks/python/core/pybind/task_runner.cc b/mediapipe/tasks/python/core/pybind/task_runner.cc index 52834bab2..cb13787c3 100644 --- a/mediapipe/tasks/python/core/pybind/task_runner.cc +++ b/mediapipe/tasks/python/core/pybind/task_runner.cc @@ -16,6 +16,7 @@ #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/python/pybind/util.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/tasks/cc/core/task_runner.h" #include "pybind11/stl.h" #include "pybind11_protobuf/native_proto_caster.h" @@ -75,7 +76,7 @@ mode) or not (synchronous mode).)doc"); } auto task_runner = TaskRunner::Create( std::move(graph_config), - absl::make_unique(), + absl::make_unique(), std::move(callback)); RaisePyErrorIfNotOk(task_runner.status()); return std::move(*task_runner); From 46a5117c6d2fa9710e584532c8187253ff92fedf Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 30 Sep 2022 17:33:44 +0000 Subject: [PATCH 011/132] Add gate utility functions. PiperOrigin-RevId: 478026407 --- mediapipe/tasks/cc/components/utils/BUILD | 13 + mediapipe/tasks/cc/components/utils/gate.h | 158 ++++++++++++ .../tasks/cc/components/utils/gate_test.cc | 229 ++++++++++++++++++ 3 files changed, 400 insertions(+) create mode 100644 mediapipe/tasks/cc/components/utils/gate.h create mode 100644 mediapipe/tasks/cc/components/utils/gate_test.cc diff --git a/mediapipe/tasks/cc/components/utils/BUILD b/mediapipe/tasks/cc/components/utils/BUILD index 0ec7ac945..d16e2fbc4 100644 --- a/mediapipe/tasks/cc/components/utils/BUILD +++ b/mediapipe/tasks/cc/components/utils/BUILD @@ -42,3 +42,16 @@ cc_test( "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", ], ) + +cc_library( + name = "gate", + hdrs = ["gate.h"], + deps = [ + "//mediapipe/calculators/core:gate_calculator", + "//mediapipe/calculators/core:gate_calculator_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + ], +) + +# TODO: Enable this test diff --git a/mediapipe/tasks/cc/components/utils/gate.h b/mediapipe/tasks/cc/components/utils/gate.h new file mode 100644 index 000000000..68a9e781b --- /dev/null +++ b/mediapipe/tasks/cc/components/utils/gate.h @@ -0,0 +1,158 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_GATE_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_GATE_H_ + +#include + +#include "mediapipe/calculators/core/gate_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" + +namespace mediapipe { +namespace tasks { +namespace components { +namespace utils { + +using ::mediapipe::api2::builder::SideSource; +using ::mediapipe::api2::builder::Source; + +// Utility class that simplifies allowing (gating) multiple streams. +class AllowGate { + public: + AllowGate(Source allow, mediapipe::api2::builder::Graph& graph) + : node_(AddSourceGate(allow, graph)) {} + AllowGate(SideSource allow, mediapipe::api2::builder::Graph& graph) + : node_(AddSideSourceGate(allow, graph)) {} + + // Move-only + AllowGate(AllowGate&& allow_gate) = default; + AllowGate& operator=(AllowGate&& allow_gate) = default; + + template + Source Allow(Source source) { + source >> node_.In(index_); + return node_.Out(index_++).Cast(); + } + + private: + template + static mediapipe::api2::builder::GenericNode& AddSourceGate( + T allow, mediapipe::api2::builder::Graph& graph) { + auto& gate_node = graph.AddNode("GateCalculator"); + allow >> gate_node.In("ALLOW"); + return gate_node; + } + + template + static mediapipe::api2::builder::GenericNode& AddSideSourceGate( + T allow, mediapipe::api2::builder::Graph& graph) { + auto& gate_node = graph.AddNode("GateCalculator"); + allow >> gate_node.SideIn("ALLOW"); + return gate_node; + } + + mediapipe::api2::builder::GenericNode& node_; + int index_ = 0; +}; + +// Utility class that simplifies disallowing (gating) multiple streams. +class DisallowGate { + public: + DisallowGate(Source disallow, mediapipe::api2::builder::Graph& graph) + : node_(AddSourceGate(disallow, graph)) {} + DisallowGate(SideSource disallow, + mediapipe::api2::builder::Graph& graph) + : node_(AddSideSourceGate(disallow, graph)) {} + + // Move-only + DisallowGate(DisallowGate&& disallow_gate) = default; + DisallowGate& operator=(DisallowGate&& disallow_gate) = default; + + template + Source Disallow(Source source) { + source >> node_.In(index_); + return node_.Out(index_++).Cast(); + } + + private: + template + static mediapipe::api2::builder::GenericNode& AddSourceGate( + T disallow, mediapipe::api2::builder::Graph& graph) { + auto& gate_node = graph.AddNode("GateCalculator"); + auto& gate_node_opts = + gate_node.GetOptions(); + // Supposedly, the most popular configuration for MediaPipe Tasks team + // graphs. Hence, intentionally hard coded to catch and verify any other use + // case (should help to workout a common approach and have a recommended way + // of blocking streams). + gate_node_opts.set_empty_packets_as_allow(true); + disallow >> gate_node.In("DISALLOW"); + return gate_node; + } + + template + static mediapipe::api2::builder::GenericNode& AddSideSourceGate( + T disallow, mediapipe::api2::builder::Graph& graph) { + auto& gate_node = graph.AddNode("GateCalculator"); + auto& gate_node_opts = + gate_node.GetOptions(); + gate_node_opts.set_empty_packets_as_allow(true); + disallow >> gate_node.SideIn("DISALLOW"); + return gate_node; + } + + mediapipe::api2::builder::GenericNode& node_; + int index_ = 0; +}; + +// Updates graph to drop @value stream packet if corresponding @condition stream +// packet holds true. 
+template +Source DisallowIf(Source value, Source condition, + mediapipe::api2::builder::Graph& graph) { + return DisallowGate(condition, graph).Disallow(value); +} + +// Updates graph to drop @value stream packet if corresponding @condition stream +// packet holds true. +template +Source DisallowIf(Source value, SideSource condition, + mediapipe::api2::builder::Graph& graph) { + return DisallowGate(condition, graph).Disallow(value); +} + +// Updates graph to pass through @value stream packet if corresponding +// @condition stream packet holds true. +template +Source AllowIf(Source value, Source allow, + mediapipe::api2::builder::Graph& graph) { + return AllowGate(allow, graph).Allow(value); +} + +// Updates graph to pass through @value stream packet if corresponding +// @condition side stream packet holds true. +template +Source AllowIf(Source value, SideSource allow, + mediapipe::api2::builder::Graph& graph) { + return AllowGate(allow, graph).Allow(value); +} + +} // namespace utils +} // namespace components +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_GATE_H_ diff --git a/mediapipe/tasks/cc/components/utils/gate_test.cc b/mediapipe/tasks/cc/components/utils/gate_test.cc new file mode 100644 index 000000000..7fdca48e7 --- /dev/null +++ b/mediapipe/tasks/cc/components/utils/gate_test.cc @@ -0,0 +1,229 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/components/utils/gate.h" + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_graph.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" + +namespace mediapipe { +namespace tasks { +namespace components { +namespace utils { +namespace { + +using ::mediapipe::api2::builder::SideSource; +using ::mediapipe::api2::builder::Source; + +TEST(DisallowGate, VerifyConfig) { + mediapipe::api2::builder::Graph graph; + + Source condition = graph.In("CONDITION").Cast(); + Source value1 = graph.In("VALUE_1").Cast(); + Source value2 = graph.In("VALUE_2").Cast(); + Source value3 = graph.In("VALUE_3").Cast(); + + DisallowGate gate(condition, graph); + gate.Disallow(value1).SetName("gated_stream1"); + gate.Disallow(value2).SetName("gated_stream2"); + gate.Disallow(value3).SetName("gated_stream3"); + + EXPECT_THAT(graph.GetConfig(), + testing::EqualsProto( + mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "GateCalculator" + input_stream: "__stream_1" + input_stream: "__stream_2" + input_stream: "__stream_3" + input_stream: "DISALLOW:__stream_0" + output_stream: "gated_stream1" + output_stream: "gated_stream2" + output_stream: "gated_stream3" + options { + [mediapipe.GateCalculatorOptions.ext] { + empty_packets_as_allow: true + } + } + } + input_stream: "CONDITION:__stream_0" + input_stream: "VALUE_1:__stream_1" + input_stream: "VALUE_2:__stream_2" + input_stream: "VALUE_3:__stream_3" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(DisallowIf, VerifyConfig) { + mediapipe::api2::builder::Graph graph; + + Source value = graph.In("VALUE").Cast(); + Source condition = graph.In("CONDITION").Cast(); + + auto gated_stream = DisallowIf(value, condition, graph); + gated_stream.SetName("gated_stream"); + + EXPECT_THAT(graph.GetConfig(), + testing::EqualsProto( + mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "GateCalculator" + input_stream: "__stream_1" + input_stream: "DISALLOW:__stream_0" + output_stream: "gated_stream" + options { + [mediapipe.GateCalculatorOptions.ext] { + empty_packets_as_allow: true + } + } + } + input_stream: "CONDITION:__stream_0" + input_stream: "VALUE:__stream_1" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(DisallowIf, VerifyConfigWithSideCondition) { + mediapipe::api2::builder::Graph graph; + + Source value = graph.In("VALUE").Cast(); + SideSource condition = graph.SideIn("CONDITION").Cast(); + + auto gated_stream = DisallowIf(value, condition, graph); + gated_stream.SetName("gated_stream"); + + EXPECT_THAT(graph.GetConfig(), + testing::EqualsProto( + mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "GateCalculator" + input_stream: "__stream_0" + output_stream: "gated_stream" + input_side_packet: "DISALLOW:__side_packet_1" + options { + [mediapipe.GateCalculatorOptions.ext] { + empty_packets_as_allow: true + } + } + } + input_stream: "VALUE:__stream_0" + input_side_packet: "CONDITION:__side_packet_1" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(AllowGate, VerifyConfig) { + mediapipe::api2::builder::Graph graph; + + Source 
condition = graph.In("CONDITION").Cast(); + Source value1 = graph.In("VALUE_1").Cast(); + Source value2 = graph.In("VALUE_2").Cast(); + Source value3 = graph.In("VALUE_3").Cast(); + + AllowGate gate(condition, graph); + gate.Allow(value1).SetName("gated_stream1"); + gate.Allow(value2).SetName("gated_stream2"); + gate.Allow(value3).SetName("gated_stream3"); + + EXPECT_THAT(graph.GetConfig(), + testing::EqualsProto( + mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "GateCalculator" + input_stream: "__stream_1" + input_stream: "__stream_2" + input_stream: "__stream_3" + input_stream: "ALLOW:__stream_0" + output_stream: "gated_stream1" + output_stream: "gated_stream2" + output_stream: "gated_stream3" + } + input_stream: "CONDITION:__stream_0" + input_stream: "VALUE_1:__stream_1" + input_stream: "VALUE_2:__stream_2" + input_stream: "VALUE_3:__stream_3" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(AllowIf, VerifyConfig) { + mediapipe::api2::builder::Graph graph; + + Source value = graph.In("VALUE").Cast(); + Source condition = graph.In("CONDITION").Cast(); + + auto gated_stream = AllowIf(value, condition, graph); + gated_stream.SetName("gated_stream"); + + EXPECT_THAT(graph.GetConfig(), + testing::EqualsProto( + mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "GateCalculator" + input_stream: "__stream_1" + input_stream: "ALLOW:__stream_0" + output_stream: "gated_stream" + } + input_stream: "CONDITION:__stream_0" + input_stream: "VALUE:__stream_1" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(AllowIf, VerifyConfigWithSideConition) { + mediapipe::api2::builder::Graph graph; + + Source value = graph.In("VALUE").Cast(); + SideSource condition = graph.SideIn("CONDITION").Cast(); + + auto gated_stream = AllowIf(value, condition, graph); + gated_stream.SetName("gated_stream"); + + EXPECT_THAT(graph.GetConfig(), + testing::EqualsProto( + mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "GateCalculator" + input_stream: "__stream_0" + output_stream: "gated_stream" + input_side_packet: "ALLOW:__side_packet_1" + } + input_stream: "VALUE:__stream_0" + input_side_packet: "CONDITION:__side_packet_1" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +} // namespace +} // namespace utils +} // namespace components +} // namespace tasks +} // namespace mediapipe From a3dc91fafeac261ae9079925791d07c124c4df7b Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 30 Sep 2022 15:25:57 -0700 Subject: [PATCH 012/132] Internal change PiperOrigin-RevId: 478093259 --- mediapipe/tasks/cc/components/utils/gate.h | 62 +++++++++++----------- 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/mediapipe/tasks/cc/components/utils/gate.h b/mediapipe/tasks/cc/components/utils/gate.h index 68a9e781b..139205fc5 100644 --- a/mediapipe/tasks/cc/components/utils/gate.h +++ b/mediapipe/tasks/cc/components/utils/gate.h @@ -26,15 +26,12 @@ namespace tasks { namespace components { namespace utils { -using ::mediapipe::api2::builder::SideSource; -using ::mediapipe::api2::builder::Source; - // Utility class that simplifies allowing (gating) multiple streams. 
class AllowGate { public: - AllowGate(Source allow, mediapipe::api2::builder::Graph& graph) + AllowGate(api2::builder::Source allow, api2::builder::Graph& graph) : node_(AddSourceGate(allow, graph)) {} - AllowGate(SideSource allow, mediapipe::api2::builder::Graph& graph) + AllowGate(api2::builder::SideSource allow, api2::builder::Graph& graph) : node_(AddSideSourceGate(allow, graph)) {} // Move-only @@ -42,39 +39,40 @@ class AllowGate { AllowGate& operator=(AllowGate&& allow_gate) = default; template - Source Allow(Source source) { + api2::builder::Source Allow(api2::builder::Source source) { source >> node_.In(index_); return node_.Out(index_++).Cast(); } private: template - static mediapipe::api2::builder::GenericNode& AddSourceGate( - T allow, mediapipe::api2::builder::Graph& graph) { + static api2::builder::GenericNode& AddSourceGate( + T allow, api2::builder::Graph& graph) { auto& gate_node = graph.AddNode("GateCalculator"); allow >> gate_node.In("ALLOW"); return gate_node; } template - static mediapipe::api2::builder::GenericNode& AddSideSourceGate( - T allow, mediapipe::api2::builder::Graph& graph) { + static api2::builder::GenericNode& AddSideSourceGate( + T allow, api2::builder::Graph& graph) { auto& gate_node = graph.AddNode("GateCalculator"); allow >> gate_node.SideIn("ALLOW"); return gate_node; } - mediapipe::api2::builder::GenericNode& node_; + api2::builder::GenericNode& node_; int index_ = 0; }; // Utility class that simplifies disallowing (gating) multiple streams. class DisallowGate { public: - DisallowGate(Source disallow, mediapipe::api2::builder::Graph& graph) + DisallowGate(api2::builder::Source disallow, + api2::builder::Graph& graph) : node_(AddSourceGate(disallow, graph)) {} - DisallowGate(SideSource disallow, - mediapipe::api2::builder::Graph& graph) + DisallowGate(api2::builder::SideSource disallow, + api2::builder::Graph& graph) : node_(AddSideSourceGate(disallow, graph)) {} // Move-only @@ -82,15 +80,15 @@ class DisallowGate { DisallowGate& operator=(DisallowGate&& disallow_gate) = default; template - Source Disallow(Source source) { + api2::builder::Source Disallow(api2::builder::Source source) { source >> node_.In(index_); return node_.Out(index_++).Cast(); } private: template - static mediapipe::api2::builder::GenericNode& AddSourceGate( - T disallow, mediapipe::api2::builder::Graph& graph) { + static api2::builder::GenericNode& AddSourceGate( + T disallow, api2::builder::Graph& graph) { auto& gate_node = graph.AddNode("GateCalculator"); auto& gate_node_opts = gate_node.GetOptions(); @@ -104,8 +102,8 @@ class DisallowGate { } template - static mediapipe::api2::builder::GenericNode& AddSideSourceGate( - T disallow, mediapipe::api2::builder::Graph& graph) { + static api2::builder::GenericNode& AddSideSourceGate( + T disallow, api2::builder::Graph& graph) { auto& gate_node = graph.AddNode("GateCalculator"); auto& gate_node_opts = gate_node.GetOptions(); @@ -114,39 +112,43 @@ class DisallowGate { return gate_node; } - mediapipe::api2::builder::GenericNode& node_; + api2::builder::GenericNode& node_; int index_ = 0; }; // Updates graph to drop @value stream packet if corresponding @condition stream // packet holds true. 
template -Source DisallowIf(Source value, Source condition, - mediapipe::api2::builder::Graph& graph) { +api2::builder::Source DisallowIf(api2::builder::Source value, + api2::builder::Source condition, + api2::builder::Graph& graph) { return DisallowGate(condition, graph).Disallow(value); } // Updates graph to drop @value stream packet if corresponding @condition stream // packet holds true. template -Source DisallowIf(Source value, SideSource condition, - mediapipe::api2::builder::Graph& graph) { +api2::builder::Source DisallowIf(api2::builder::Source value, + api2::builder::SideSource condition, + api2::builder::Graph& graph) { return DisallowGate(condition, graph).Disallow(value); } // Updates graph to pass through @value stream packet if corresponding -// @condition stream packet holds true. +// @allow stream packet holds true. template -Source AllowIf(Source value, Source allow, - mediapipe::api2::builder::Graph& graph) { +api2::builder::Source AllowIf(api2::builder::Source value, + api2::builder::Source allow, + api2::builder::Graph& graph) { return AllowGate(allow, graph).Allow(value); } // Updates graph to pass through @value stream packet if corresponding -// @condition side stream packet holds true. +// @allow side stream packet holds true. template -Source AllowIf(Source value, SideSource allow, - mediapipe::api2::builder::Graph& graph) { +api2::builder::Source AllowIf(api2::builder::Source value, + api2::builder::SideSource allow, + api2::builder::Graph& graph) { return AllowGate(allow, graph).Allow(value); } From 9568de0570c74f796a09b427f0fb71b060ad1354 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Sat, 1 Oct 2022 05:26:59 -0700 Subject: [PATCH 013/132] Remove "-Wno-unused-function" and "$(STACK_FRAME_UNLIMITED)" to resolve the "invalid numeric argument '/Wno-unused-function'" error on Windows. PiperOrigin-RevId: 478192521 --- third_party/stblib.BUILD | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/third_party/stblib.BUILD b/third_party/stblib.BUILD index 5169906cc..8a419d1f2 100644 --- a/third_party/stblib.BUILD +++ b/third_party/stblib.BUILD @@ -7,16 +7,19 @@ package( licenses(["notice"]) # MIT license -exports_files(["LICENSE"]) +COPTS = select({ + "@platforms//os:windows": [], + "//conditions:default": [ + "-Wno-unused-function", + "$(STACK_FRAME_UNLIMITED)", + ], +}) cc_library( name = "stb_image", srcs = ["stb_image.c"], hdrs = ["stb_image.h"], - copts = [ - "-Wno-unused-function", - "$(STACK_FRAME_UNLIMITED)", - ], + copts = COPTS, includes = ["."], ) @@ -24,5 +27,6 @@ cc_library( name = "stb_image_write", srcs = ["stb_image_write.c"], hdrs = ["stb_image_write.h"], + copts = COPTS, includes = ["."], ) From 13f6e0c79713bb1d769e9625d0a1b199c5116ea2 Mon Sep 17 00:00:00 2001 From: Yuqi Li Date: Sat, 1 Oct 2022 21:50:22 -0700 Subject: [PATCH 014/132] Migrate base metadata functionality like MetadataPopulator and MetadataDisplayer class into MediaPipe. 
PiperOrigin-RevId: 478279747 --- mediapipe/tasks/metadata/BUILD | 12 +- mediapipe/tasks/python/metadata/BUILD | 40 + mediapipe/tasks/python/metadata/__init__.py | 13 + .../python/metadata/flatbuffers_lib/BUILD | 20 + .../flatbuffers_lib/flatbuffers_lib.cc | 59 ++ mediapipe/tasks/python/metadata/metadata.py | 865 ++++++++++++++++++ .../python/metadata/metadata_displayer_cli.py | 34 + .../metadata/metadata_parser.py.template | 26 + mediapipe/tasks/python/test/BUILD | 4 +- mediapipe/tasks/python/test/metadata/BUILD | 31 + .../test/metadata/metadata_parser_test.py | 37 + .../python/test/metadata/metadata_test.py | 857 +++++++++++++++++ mediapipe/tasks/python/test/test_utils.py | 45 + mediapipe/tasks/python/test/vision/BUILD | 2 +- .../test/vision/object_detector_test.py | 6 +- mediapipe/tasks/testdata/metadata/BUILD | 12 +- .../tasks/testdata/metadata/golden_json.json | 28 + third_party/external_files.bzl | 12 + 18 files changed, 2094 insertions(+), 9 deletions(-) create mode 100644 mediapipe/tasks/python/metadata/BUILD create mode 100644 mediapipe/tasks/python/metadata/__init__.py create mode 100644 mediapipe/tasks/python/metadata/flatbuffers_lib/BUILD create mode 100644 mediapipe/tasks/python/metadata/flatbuffers_lib/flatbuffers_lib.cc create mode 100644 mediapipe/tasks/python/metadata/metadata.py create mode 100644 mediapipe/tasks/python/metadata/metadata_displayer_cli.py create mode 100644 mediapipe/tasks/python/metadata/metadata_parser.py.template create mode 100644 mediapipe/tasks/python/test/metadata/BUILD create mode 100644 mediapipe/tasks/python/test/metadata/metadata_parser_test.py create mode 100644 mediapipe/tasks/python/test/metadata/metadata_test.py create mode 100644 mediapipe/tasks/python/test/test_utils.py create mode 100644 mediapipe/tasks/testdata/metadata/golden_json.json diff --git a/mediapipe/tasks/metadata/BUILD b/mediapipe/tasks/metadata/BUILD index 957bf6b74..abd948809 100644 --- a/mediapipe/tasks/metadata/BUILD +++ b/mediapipe/tasks/metadata/BUILD @@ -1,4 +1,4 @@ -load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") +load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library", "flatbuffer_py_library") package( default_visibility = [ @@ -14,3 +14,13 @@ flatbuffer_cc_library( name = "metadata_schema_cc", srcs = ["metadata_schema.fbs"], ) + +flatbuffer_py_library( + name = "schema_py", + srcs = ["@org_tensorflow//tensorflow/lite/schema:schema.fbs"], +) + +flatbuffer_py_library( + name = "metadata_schema_py", + srcs = ["metadata_schema.fbs"], +) diff --git a/mediapipe/tasks/python/metadata/BUILD b/mediapipe/tasks/python/metadata/BUILD new file mode 100644 index 000000000..34ee63f5e --- /dev/null +++ b/mediapipe/tasks/python/metadata/BUILD @@ -0,0 +1,40 @@ +load("//mediapipe/tasks/metadata:build_defs.bzl", "stamp_metadata_parser_version") + +package( + licenses = ["notice"], # Apache 2.0 +) + +stamp_metadata_parser_version( + name = "metadata_parser_py", + srcs = ["metadata_parser.py.template"], + outs = ["metadata_parser.py"], +) + +py_library( + name = "metadata", + srcs = [ + "metadata.py", + ":metadata_parser_py", + ], + data = ["//mediapipe/tasks/metadata:metadata_schema.fbs"], + srcs_version = "PY3", + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/tasks/cc/metadata/python:_pywrap_metadata_version", + "//mediapipe/tasks/metadata:metadata_schema_py", + "//mediapipe/tasks/metadata:schema_py", + "//mediapipe/tasks/python/metadata/flatbuffers_lib:_pywrap_flatbuffers", + "@flatbuffers//:runtime_py", + ], +) + +py_binary( + name = 
"metadata_displayer_cli", + srcs = ["metadata_displayer_cli.py"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":metadata", + ], +) diff --git a/mediapipe/tasks/python/metadata/__init__.py b/mediapipe/tasks/python/metadata/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/tasks/python/metadata/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/tasks/python/metadata/flatbuffers_lib/BUILD b/mediapipe/tasks/python/metadata/flatbuffers_lib/BUILD new file mode 100644 index 000000000..303ff3224 --- /dev/null +++ b/mediapipe/tasks/python/metadata/flatbuffers_lib/BUILD @@ -0,0 +1,20 @@ +load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension") + +package( + default_visibility = ["//mediapipe/tasks:internal"], + licenses = ["notice"], # Apache 2.0 +) + +pybind_extension( + name = "_pywrap_flatbuffers", + srcs = [ + "flatbuffers_lib.cc", + ], + features = ["-use_header_modules"], + module_name = "_pywrap_flatbuffers", + deps = [ + "@flatbuffers", + "@local_config_python//:python_headers", + "@pybind11", + ], +) diff --git a/mediapipe/tasks/python/metadata/flatbuffers_lib/flatbuffers_lib.cc b/mediapipe/tasks/python/metadata/flatbuffers_lib/flatbuffers_lib.cc new file mode 100644 index 000000000..34407620c --- /dev/null +++ b/mediapipe/tasks/python/metadata/flatbuffers_lib/flatbuffers_lib.cc @@ -0,0 +1,59 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "flatbuffers/flatbuffers.h" +#include "flatbuffers/idl.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" + +namespace tflite { +namespace support { + +PYBIND11_MODULE(_pywrap_flatbuffers, m) { + pybind11::class_(m, "IDLOptions") + .def(pybind11::init<>()) + .def_readwrite("strict_json", &flatbuffers::IDLOptions::strict_json); + pybind11::class_(m, "Parser") + .def(pybind11::init()) + .def("parse", + [](flatbuffers::Parser* self, const std::string& source) { + return self->Parse(source.c_str()); + }) + .def_readonly("builder", &flatbuffers::Parser::builder_) + .def_readonly("error", &flatbuffers::Parser::error_); + pybind11::class_(m, "FlatBufferBuilder") + .def("clear", &flatbuffers::FlatBufferBuilder::Clear) + .def("push_flat_buffer", [](flatbuffers::FlatBufferBuilder* self, + const std::string& contents) { + self->PushFlatBuffer(reinterpret_cast(contents.c_str()), + contents.length()); + }); + m.def("generate_text_file", &flatbuffers::GenerateTextFile); + m.def( + "generate_text", + [](const flatbuffers::Parser& parser, + const std::string& buffer) -> std::string { + std::string text; + if (!flatbuffers::GenerateText( + parser, reinterpret_cast(buffer.c_str()), &text)) { + return ""; + } + return text; + }); +} + +} // namespace support +} // namespace tflite diff --git a/mediapipe/tasks/python/metadata/metadata.py b/mediapipe/tasks/python/metadata/metadata.py new file mode 100644 index 000000000..10a0b9b66 --- /dev/null +++ b/mediapipe/tasks/python/metadata/metadata.py @@ -0,0 +1,865 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow Lite metadata tools.""" + +import copy +import inspect +import io +import os +import shutil +import sys +import tempfile +import warnings +import zipfile + +import flatbuffers +from mediapipe.tasks.cc.metadata.python import _pywrap_metadata_version +from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb +from mediapipe.tasks.metadata import schema_py_generated as _schema_fb +from mediapipe.tasks.python.metadata.flatbuffers_lib import _pywrap_flatbuffers + +try: + # If exists, optionally use TensorFlow to open and check files. Used to + # support more than local file systems. + # In pip requirements, we doesn't necessarily need tensorflow as a dep. + import tensorflow as tf + _open_file = tf.io.gfile.GFile + _exists_file = tf.io.gfile.exists +except ImportError as e: + # If TensorFlow package doesn't exist, fall back to original open and exists. + _open_file = open + _exists_file = os.path.exists + + +def _maybe_open_as_binary(filename, mode): + """Maybe open the binary file, and returns a file-like.""" + if hasattr(filename, "read"): # A file-like has read(). 
+ return filename + openmode = mode if "b" in mode else mode + "b" # Add binary explicitly. + return _open_file(filename, openmode) + + +def _open_as_zipfile(filename, mode="r"): + """Open file as a zipfile. + + Args: + filename: str or file-like or path-like, to the zipfile. + mode: str, common file mode for zip. + (See: https://docs.python.org/3/library/zipfile.html) + + Returns: + A ZipFile object. + """ + file_like = _maybe_open_as_binary(filename, mode) + return zipfile.ZipFile(file_like, mode) + + +def _is_zipfile(filename): + """Checks whether it is a zipfile.""" + with _maybe_open_as_binary(filename, "r") as f: + return zipfile.is_zipfile(f) + + +def get_path_to_datafile(path): + """Gets the path to the specified file in the data dependencies. + + The path is relative to the file calling the function. + + It's a simple replacement of + "tensorflow.python.platform.resource_loader.get_path_to_datafile". + + Args: + path: a string resource path relative to the calling file. + + Returns: + The path to the specified file present in the data attribute of py_test + or py_binary. + """ + data_files_path = os.path.dirname(inspect.getfile(sys._getframe(1))) # pylint: disable=protected-access + return os.path.join(data_files_path, path) + + +_FLATC_TFLITE_METADATA_SCHEMA_FILE = get_path_to_datafile( + "../../metadata/metadata_schema.fbs") + + +# TODO: add delete method for associated files. +class MetadataPopulator(object): + """Packs metadata and associated files into TensorFlow Lite model file. + + MetadataPopulator can be used to populate metadata and model associated files + into a model file or a model buffer (in bytearray). It can also help to + inspect list of files that have been packed into the model or are supposed to + be packed into the model. + + The metadata file (or buffer) should be generated based on the metadata + schema: + third_party/tensorflow/lite/schema/metadata_schema.fbs + + Example usage: + Populate matadata and label file into an image classifier model. + + First, based on metadata_schema.fbs, generate the metadata for this image + classifer model using Flatbuffers API. Attach the label file onto the ouput + tensor (the tensor of probabilities) in the metadata. + + Then, pack the metadata and label file into the model as follows. + + ```python + # Populating a metadata file (or a metadta buffer) and associated files to + a model file: + populator = MetadataPopulator.with_model_file(model_file) + # For metadata buffer (bytearray read from the metadata file), use: + # populator.load_metadata_buffer(metadata_buf) + populator.load_metadata_file(metadata_file) + populator.load_associated_files([label.txt]) + # For associated file buffer (bytearray read from the file), use: + # populator.load_associated_file_buffers({"label.txt": b"file content"}) + populator.populate() + + # Populating a metadata file (or a metadata buffer) and associated files to + a model buffer: + populator = MetadataPopulator.with_model_buffer(model_buf) + populator.load_metadata_file(metadata_file) + populator.load_associated_files([label.txt]) + populator.populate() + # Writing the updated model buffer into a file. 
+ updated_model_buf = populator.get_model_buffer() + with open("updated_model.tflite", "wb") as f: + f.write(updated_model_buf) + + # Transferring metadata and associated files from another TFLite model: + populator = MetadataPopulator.with_model_buffer(model_buf) + populator_dst.load_metadata_and_associated_files(src_model_buf) + populator_dst.populate() + updated_model_buf = populator.get_model_buffer() + with open("updated_model.tflite", "wb") as f: + f.write(updated_model_buf) + ``` + + Note that existing metadata buffer (if applied) will be overridden by the new + metadata buffer. + """ + # As Zip API is used to concatenate associated files after tflite model file, + # the populating operation is developed based on a model file. For in-memory + # model buffer, we create a tempfile to serve the populating operation. + # Creating the deleting such a tempfile is handled by the class, + # _MetadataPopulatorWithBuffer. + + METADATA_FIELD_NAME = "TFLITE_METADATA" + TFLITE_FILE_IDENTIFIER = b"TFL3" + METADATA_FILE_IDENTIFIER = b"M001" + + def __init__(self, model_file): + """Constructor for MetadataPopulator. + + Args: + model_file: valid path to a TensorFlow Lite model file. + + Raises: + IOError: File not found. + ValueError: the model does not have the expected flatbuffer identifer. + """ + _assert_model_file_identifier(model_file) + self._model_file = model_file + self._metadata_buf = None + # _associated_files is a dict of file name and file buffer. + self._associated_files = {} + + @classmethod + def with_model_file(cls, model_file): + """Creates a MetadataPopulator object that populates data to a model file. + + Args: + model_file: valid path to a TensorFlow Lite model file. + + Returns: + MetadataPopulator object. + + Raises: + IOError: File not found. + ValueError: the model does not have the expected flatbuffer identifer. + """ + return cls(model_file) + + # TODO: investigate if type check can be applied to model_buf for + # FB. + @classmethod + def with_model_buffer(cls, model_buf): + """Creates a MetadataPopulator object that populates data to a model buffer. + + Args: + model_buf: TensorFlow Lite model buffer in bytearray. + + Returns: + A MetadataPopulator(_MetadataPopulatorWithBuffer) object. + + Raises: + ValueError: the model does not have the expected flatbuffer identifer. + """ + return _MetadataPopulatorWithBuffer(model_buf) + + def get_model_buffer(self): + """Gets the buffer of the model with packed metadata and associated files. + + Returns: + Model buffer (in bytearray). + """ + with _open_file(self._model_file, "rb") as f: + return f.read() + + def get_packed_associated_file_list(self): + """Gets a list of associated files packed to the model file. + + Returns: + List of packed associated files. + """ + if not _is_zipfile(self._model_file): + return [] + + with _open_as_zipfile(self._model_file, "r") as zf: + return zf.namelist() + + def get_recorded_associated_file_list(self): + """Gets a list of associated files recorded in metadata of the model file. + + Associated files may be attached to a model, a subgraph, or an input/output + tensor. + + Returns: + List of recorded associated files. 
+ """ + if not self._metadata_buf: + return [] + + metadata = _metadata_fb.ModelMetadataT.InitFromObj( + _metadata_fb.ModelMetadata.GetRootAsModelMetadata( + self._metadata_buf, 0)) + + return [ + file.name.decode("utf-8") + for file in self._get_recorded_associated_file_object_list(metadata) + ] + + def load_associated_file_buffers(self, associated_files): + """Loads the associated file buffers (in bytearray) to be populated. + + Args: + associated_files: a dictionary of associated file names and corresponding + file buffers, such as {"file.txt": b"file content"}. If pass in file + paths for the file name, only the basename will be populated. + """ + + self._associated_files.update({ + os.path.basename(name): buffers + for name, buffers in associated_files.items() + }) + + def load_associated_files(self, associated_files): + """Loads associated files that to be concatenated after the model file. + + Args: + associated_files: list of file paths. + + Raises: + IOError: + File not found. + """ + for af_name in associated_files: + _assert_file_exist(af_name) + with _open_file(af_name, "rb") as af: + self.load_associated_file_buffers({af_name: af.read()}) + + def load_metadata_buffer(self, metadata_buf): + """Loads the metadata buffer (in bytearray) to be populated. + + Args: + metadata_buf: metadata buffer (in bytearray) to be populated. + + Raises: + ValueError: The metadata to be populated is empty. + ValueError: The metadata does not have the expected flatbuffer identifer. + ValueError: Cannot get minimum metadata parser version. + ValueError: The number of SubgraphMetadata is not 1. + ValueError: The number of input/output tensors does not match the number + of input/output tensor metadata. + """ + if not metadata_buf: + raise ValueError("The metadata to be populated is empty.") + + self._validate_metadata(metadata_buf) + + # Gets the minimum metadata parser version of the metadata_buf. + min_version = _pywrap_metadata_version.GetMinimumMetadataParserVersion( + bytes(metadata_buf)) + + # Inserts in the minimum metadata parser version into the metadata_buf. + metadata = _metadata_fb.ModelMetadataT.InitFromObj( + _metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0)) + metadata.minParserVersion = min_version + + # Remove local file directory in the `name` field of `AssociatedFileT`, and + # make it consistent with the name of the actual file packed in the model. + self._use_basename_for_associated_files_in_metadata(metadata) + + b = flatbuffers.Builder(0) + b.Finish(metadata.Pack(b), self.METADATA_FILE_IDENTIFIER) + metadata_buf_with_version = b.Output() + + self._metadata_buf = metadata_buf_with_version + + def load_metadata_file(self, metadata_file): + """Loads the metadata file to be populated. + + Args: + metadata_file: path to the metadata file to be populated. + + Raises: + IOError: File not found. + ValueError: The metadata to be populated is empty. + ValueError: The metadata does not have the expected flatbuffer identifer. + ValueError: Cannot get minimum metadata parser version. + ValueError: The number of SubgraphMetadata is not 1. + ValueError: The number of input/output tensors does not match the number + of input/output tensor metadata. + """ + _assert_file_exist(metadata_file) + with _open_file(metadata_file, "rb") as f: + metadata_buf = f.read() + self.load_metadata_buffer(bytearray(metadata_buf)) + + def load_metadata_and_associated_files(self, src_model_buf): + """Loads the metadata and associated files from another model buffer. 
+ + Args: + src_model_buf: source model buffer (in bytearray) with metadata and + associated files. + """ + # Load the model metadata from src_model_buf if exist. + metadata_buffer = get_metadata_buffer(src_model_buf) + if metadata_buffer: + self.load_metadata_buffer(metadata_buffer) + + # Load the associated files from src_model_buf if exist. + if _is_zipfile(io.BytesIO(src_model_buf)): + with _open_as_zipfile(io.BytesIO(src_model_buf)) as zf: + self.load_associated_file_buffers( + {f: zf.read(f) for f in zf.namelist()}) + + def populate(self): + """Populates loaded metadata and associated files into the model file.""" + self._assert_validate() + self._populate_metadata_buffer() + self._populate_associated_files() + + def _assert_validate(self): + """Validates the metadata and associated files to be populated. + + Raises: + ValueError: + File is recorded in the metadata, but is not going to be populated. + File has already been packed. + """ + # Gets files that are recorded in metadata. + recorded_files = self.get_recorded_associated_file_list() + + # Gets files that have been packed to self._model_file. + packed_files = self.get_packed_associated_file_list() + + # Gets the file name of those associated files to be populated. + to_be_populated_files = self._associated_files.keys() + + # Checks all files recorded in the metadata will be populated. + for rf in recorded_files: + if rf not in to_be_populated_files and rf not in packed_files: + raise ValueError("File, '{0}', is recorded in the metadata, but has " + "not been loaded into the populator.".format(rf)) + + for f in to_be_populated_files: + if f in packed_files: + raise ValueError("File, '{0}', has already been packed.".format(f)) + + if f not in recorded_files: + warnings.warn( + "File, '{0}', does not exist in the metadata. But packing it to " + "tflite model is still allowed.".format(f)) + + def _copy_archived_files(self, src_zip, file_list, dst_zip): + """Copy archieved files in file_list from src_zip ro dst_zip.""" + + if not _is_zipfile(src_zip): + raise ValueError("File, '{0}', is not a zipfile.".format(src_zip)) + + with _open_as_zipfile(src_zip, "r") as src_zf, \ + _open_as_zipfile(dst_zip, "a") as dst_zf: + src_list = src_zf.namelist() + for f in file_list: + if f not in src_list: + raise ValueError( + "File, '{0}', does not exist in the zipfile, {1}.".format( + f, src_zip)) + file_buffer = src_zf.read(f) + dst_zf.writestr(f, file_buffer) + + def _get_associated_files_from_process_units(self, table, field_name): + """Gets the files that are attached the process units field of a table. + + Args: + table: a Flatbuffers table object that contains fields of an array of + ProcessUnit, such as TensorMetadata and SubGraphMetadata. + field_name: the name of the field in the table that represents an array of + ProcessUnit. If the table is TensorMetadata, field_name can be + "ProcessUnits". If the table is SubGraphMetadata, field_name can be + either "InputProcessUnits" or "OutputProcessUnits". + + Returns: + A list of AssociatedFileT objects. + """ + + if table is None: + return [] + + file_list = [] + process_units = getattr(table, field_name) + # If the process_units field is not populated, it will be None. Use an + # empty list to skip the check. 
+ for process_unit in process_units or []: + options = process_unit.options + if isinstance(options, (_metadata_fb.BertTokenizerOptionsT, + _metadata_fb.RegexTokenizerOptionsT)): + file_list += self._get_associated_files_from_table(options, "vocabFile") + elif isinstance(options, _metadata_fb.SentencePieceTokenizerOptionsT): + file_list += self._get_associated_files_from_table( + options, "sentencePieceModel") + file_list += self._get_associated_files_from_table(options, "vocabFile") + return file_list + + def _get_associated_files_from_table(self, table, field_name): + """Gets the associated files that are attached a table directly. + + Args: + table: a Flatbuffers table object that contains fields of an array of + AssociatedFile, such as TensorMetadata and BertTokenizerOptions. + field_name: the name of the field in the table that represents an array of + ProcessUnit. If the table is TensorMetadata, field_name can be + "AssociatedFiles". If the table is BertTokenizerOptions, field_name can + be "VocabFile". + + Returns: + A list of AssociatedFileT objects. + """ + + if table is None: + return [] + + # If the associated file field is not populated, + # `getattr(table, field_name)` will be None. Return an empty list. + return getattr(table, field_name) or [] + + def _get_recorded_associated_file_object_list(self, metadata): + """Gets a list of AssociatedFileT objects recorded in the metadata. + + Associated files may be attached to a model, a subgraph, or an input/output + tensor. + + Args: + metadata: the ModelMetadataT object. + + Returns: + List of recorded AssociatedFileT objects. + """ + recorded_files = [] + + # Add associated files attached to ModelMetadata. + recorded_files += self._get_associated_files_from_table( + metadata, "associatedFiles") + + # Add associated files attached to each SubgraphMetadata. + for subgraph in metadata.subgraphMetadata or []: + recorded_files += self._get_associated_files_from_table( + subgraph, "associatedFiles") + + # Add associated files attached to each input tensor. + for tensor_metadata in subgraph.inputTensorMetadata or []: + recorded_files += self._get_associated_files_from_table( + tensor_metadata, "associatedFiles") + recorded_files += self._get_associated_files_from_process_units( + tensor_metadata, "processUnits") + + # Add associated files attached to each output tensor. + for tensor_metadata in subgraph.outputTensorMetadata or []: + recorded_files += self._get_associated_files_from_table( + tensor_metadata, "associatedFiles") + recorded_files += self._get_associated_files_from_process_units( + tensor_metadata, "processUnits") + + # Add associated files attached to the input_process_units. + recorded_files += self._get_associated_files_from_process_units( + subgraph, "inputProcessUnits") + + # Add associated files attached to the output_process_units. + recorded_files += self._get_associated_files_from_process_units( + subgraph, "outputProcessUnits") + + return recorded_files + + def _populate_associated_files(self): + """Concatenates associated files after TensorFlow Lite model file. + + If the MetadataPopulator object is created using the method, + with_model_file(model_file), the model file will be updated. + """ + # Opens up the model file in "appending" mode. + # If self._model_file already has pack files, zipfile will concatenate + # addition files after self._model_file. 
For example, suppose we have + # self._model_file = old_tflite_file | label1.txt | label2.txt + # Then after trigger populate() to add label3.txt, self._model_file becomes + # self._model_file = old_tflite_file | label1.txt | label2.txt | label3.txt + with tempfile.SpooledTemporaryFile() as temp: + # (1) Copy content from model file of to temp file. + with _open_file(self._model_file, "rb") as f: + shutil.copyfileobj(f, temp) + + # (2) Append of to a temp file as a zip. + with _open_as_zipfile(temp, "a") as zf: + for file_name, file_buffer in self._associated_files.items(): + zf.writestr(file_name, file_buffer) + + # (3) Copy temp file to model file. + temp.seek(0) + with _open_file(self._model_file, "wb") as f: + shutil.copyfileobj(temp, f) + + def _populate_metadata_buffer(self): + """Populates the metadata buffer (in bytearray) into the model file. + + Inserts metadata_buf into the metadata field of schema.Model. If the + MetadataPopulator object is created using the method, + with_model_file(model_file), the model file will be updated. + + Existing metadata buffer (if applied) will be overridden by the new metadata + buffer. + """ + + with _open_file(self._model_file, "rb") as f: + model_buf = f.read() + + model = _schema_fb.ModelT.InitFromObj( + _schema_fb.Model.GetRootAsModel(model_buf, 0)) + buffer_field = _schema_fb.BufferT() + buffer_field.data = self._metadata_buf + + is_populated = False + if not model.metadata: + model.metadata = [] + else: + # Check if metadata has already been populated. + for meta in model.metadata: + if meta.name.decode("utf-8") == self.METADATA_FIELD_NAME: + is_populated = True + model.buffers[meta.buffer] = buffer_field + + if not is_populated: + if not model.buffers: + model.buffers = [] + model.buffers.append(buffer_field) + # Creates a new metadata field. + metadata_field = _schema_fb.MetadataT() + metadata_field.name = self.METADATA_FIELD_NAME + metadata_field.buffer = len(model.buffers) - 1 + model.metadata.append(metadata_field) + + # Packs model back to a flatbuffer binaray file. + b = flatbuffers.Builder(0) + b.Finish(model.Pack(b), self.TFLITE_FILE_IDENTIFIER) + model_buf = b.Output() + + # Saves the updated model buffer to model file. + # Gets files that have been packed to self._model_file. + packed_files = self.get_packed_associated_file_list() + if packed_files: + # Writes the updated model buffer and associated files into a new model + # file (in memory). Then overwrites the original model file. + with tempfile.SpooledTemporaryFile() as temp: + temp.write(model_buf) + self._copy_archived_files(self._model_file, packed_files, temp) + temp.seek(0) + with _open_file(self._model_file, "wb") as f: + shutil.copyfileobj(temp, f) + else: + with _open_file(self._model_file, "wb") as f: + f.write(model_buf) + + def _use_basename_for_associated_files_in_metadata(self, metadata): + """Removes any associated file local directory (if exists).""" + for file in self._get_recorded_associated_file_object_list(metadata): + file.name = os.path.basename(file.name) + + def _validate_metadata(self, metadata_buf): + """Validates the metadata to be populated.""" + _assert_metadata_buffer_identifier(metadata_buf) + + # Verify the number of SubgraphMetadata is exactly one. + # TFLite currently only support one subgraph. 
+ model_meta = _metadata_fb.ModelMetadata.GetRootAsModelMetadata( + metadata_buf, 0) + if model_meta.SubgraphMetadataLength() != 1: + raise ValueError("The number of SubgraphMetadata should be exactly one, " + "but got {0}.".format( + model_meta.SubgraphMetadataLength())) + + # Verify if the number of tensor metadata matches the number of tensors. + with _open_file(self._model_file, "rb") as f: + model_buf = f.read() + model = _schema_fb.Model.GetRootAsModel(model_buf, 0) + + num_input_tensors = model.Subgraphs(0).InputsLength() + num_input_meta = model_meta.SubgraphMetadata(0).InputTensorMetadataLength() + if num_input_tensors != num_input_meta: + raise ValueError( + "The number of input tensors ({0}) should match the number of " + "input tensor metadata ({1})".format(num_input_tensors, + num_input_meta)) + num_output_tensors = model.Subgraphs(0).OutputsLength() + num_output_meta = model_meta.SubgraphMetadata( + 0).OutputTensorMetadataLength() + if num_output_tensors != num_output_meta: + raise ValueError( + "The number of output tensors ({0}) should match the number of " + "output tensor metadata ({1})".format(num_output_tensors, + num_output_meta)) + + +class _MetadataPopulatorWithBuffer(MetadataPopulator): + """Subclass of MetadtaPopulator that populates metadata to a model buffer. + + This class is used to populate metadata into a in-memory model buffer. As we + use Zip API to concatenate associated files after tflite model file, the + populating operation is developed based on a model file. For in-memory model + buffer, we create a tempfile to serve the populating operation. This class is + then used to generate this tempfile, and delete the file when the + MetadataPopulator object is deleted. + """ + + def __init__(self, model_buf): + """Constructor for _MetadataPopulatorWithBuffer. + + Args: + model_buf: TensorFlow Lite model buffer in bytearray. + + Raises: + ValueError: model_buf is empty. + ValueError: model_buf does not have the expected flatbuffer identifer. + """ + if not model_buf: + raise ValueError("model_buf cannot be empty.") + + with tempfile.NamedTemporaryFile() as temp: + model_file = temp.name + + with _open_file(model_file, "wb") as f: + f.write(model_buf) + + super().__init__(model_file) + + def __del__(self): + """Destructor of _MetadataPopulatorWithBuffer. + + Deletes the tempfile. + """ + if os.path.exists(self._model_file): + os.remove(self._model_file) + + +class MetadataDisplayer(object): + """Displays metadata and associated file info in human-readable format.""" + + def __init__(self, model_buffer, metadata_buffer, associated_file_list): + """Constructor for MetadataDisplayer. + + Args: + model_buffer: valid buffer of the model file. + metadata_buffer: valid buffer of the metadata file. + associated_file_list: list of associate files in the model file. + """ + _assert_model_buffer_identifier(model_buffer) + _assert_metadata_buffer_identifier(metadata_buffer) + self._model_buffer = model_buffer + self._metadata_buffer = metadata_buffer + self._associated_file_list = associated_file_list + + @classmethod + def with_model_file(cls, model_file): + """Creates a MetadataDisplayer object for the model file. + + Args: + model_file: valid path to a TensorFlow Lite model file. + + Returns: + MetadataDisplayer object. + + Raises: + IOError: File not found. + ValueError: The model does not have metadata. 
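+
+    Example (an illustrative sketch; the file name is an assumption):
+      displayer = MetadataDisplayer.with_model_file("model.tflite")
+      json_string = displayer.get_metadata_json()
+      packed_files = displayer.get_packed_associated_file_list()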
+ """ + _assert_file_exist(model_file) + with _open_file(model_file, "rb") as f: + return cls.with_model_buffer(f.read()) + + @classmethod + def with_model_buffer(cls, model_buffer): + """Creates a MetadataDisplayer object for a file buffer. + + Args: + model_buffer: TensorFlow Lite model buffer in bytearray. + + Returns: + MetadataDisplayer object. + """ + if not model_buffer: + raise ValueError("model_buffer cannot be empty.") + metadata_buffer = get_metadata_buffer(model_buffer) + if not metadata_buffer: + raise ValueError("The model does not have metadata.") + associated_file_list = cls._parse_packed_associted_file_list(model_buffer) + return cls(model_buffer, metadata_buffer, associated_file_list) + + def get_associated_file_buffer(self, filename): + """Get the specified associated file content in bytearray. + + Args: + filename: name of the file to be extracted. + + Returns: + The file content in bytearray. + + Raises: + ValueError: if the file does not exist in the model. + """ + if filename not in self._associated_file_list: + raise ValueError( + "The file, {}, does not exist in the model.".format(filename)) + + with _open_as_zipfile(io.BytesIO(self._model_buffer)) as zf: + return zf.read(filename) + + def get_metadata_buffer(self): + """Get the metadata buffer in bytearray out from the model.""" + return copy.deepcopy(self._metadata_buffer) + + def get_metadata_json(self): + """Converts the metadata into a json string.""" + return convert_to_json(self._metadata_buffer) + + def get_packed_associated_file_list(self): + """Returns a list of associated files that are packed in the model. + + Returns: + A name list of associated files. + """ + return copy.deepcopy(self._associated_file_list) + + @staticmethod + def _parse_packed_associted_file_list(model_buf): + """Gets a list of associated files packed to the model file. + + Args: + model_buf: valid file buffer. + + Returns: + List of packed associated files. + """ + + try: + with _open_as_zipfile(io.BytesIO(model_buf)) as zf: + return zf.namelist() + except zipfile.BadZipFile: + return [] + + +# Create an individual method for getting the metadata json file, so that it can +# be used as a standalone util. +def convert_to_json(metadata_buffer): + """Converts the metadata into a json string. + + Args: + metadata_buffer: valid metadata buffer in bytes. + + Returns: + Metadata in JSON format. + + Raises: + ValueError: error occured when parsing the metadata schema file. + """ + + opt = _pywrap_flatbuffers.IDLOptions() + opt.strict_json = True + parser = _pywrap_flatbuffers.Parser(opt) + with _open_file(_FLATC_TFLITE_METADATA_SCHEMA_FILE) as f: + metadata_schema_content = f.read() + if not parser.parse(metadata_schema_content): + raise ValueError("Cannot parse metadata schema. 
Reason: " + parser.error) + return _pywrap_flatbuffers.generate_text(parser, metadata_buffer) + + +def _assert_file_exist(filename): + """Checks if a file exists.""" + if not _exists_file(filename): + raise IOError("File, '{0}', does not exist.".format(filename)) + + +def _assert_model_file_identifier(model_file): + """Checks if a model file has the expected TFLite schema identifier.""" + _assert_file_exist(model_file) + with _open_file(model_file, "rb") as f: + _assert_model_buffer_identifier(f.read()) + + +def _assert_model_buffer_identifier(model_buf): + if not _schema_fb.Model.ModelBufferHasIdentifier(model_buf, 0): + raise ValueError( + "The model provided does not have the expected identifier, and " + "may not be a valid TFLite model.") + + +def _assert_metadata_buffer_identifier(metadata_buf): + """Checks if a metadata buffer has the expected Metadata schema identifier.""" + if not _metadata_fb.ModelMetadata.ModelMetadataBufferHasIdentifier( + metadata_buf, 0): + raise ValueError( + "The metadata buffer does not have the expected identifier, and may not" + " be a valid TFLite Metadata.") + + +def get_metadata_buffer(model_buf): + """Returns the metadata in the model file as a buffer. + + Args: + model_buf: valid buffer of the model file. + + Returns: + Metadata buffer. Returns `None` if the model does not have metadata. + """ + tflite_model = _schema_fb.Model.GetRootAsModel(model_buf, 0) + + # Gets metadata from the model file. + for i in range(tflite_model.MetadataLength()): + meta = tflite_model.Metadata(i) + if meta.Name().decode("utf-8") == MetadataPopulator.METADATA_FIELD_NAME: + buffer_index = meta.Buffer() + metadata = tflite_model.Buffers(buffer_index) + return metadata.DataAsNumpy().tobytes() + + return None diff --git a/mediapipe/tasks/python/metadata/metadata_displayer_cli.py b/mediapipe/tasks/python/metadata/metadata_displayer_cli.py new file mode 100644 index 000000000..745da1f25 --- /dev/null +++ b/mediapipe/tasks/python/metadata/metadata_displayer_cli.py @@ -0,0 +1,34 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""CLI tool for display metadata.""" + +from absl import app +from absl import flags + +from mediapipe.tasks.python.metadata import metadata + +FLAGS = flags.FLAGS +flags.DEFINE_string('model_path', None, 'Path to the TFLite model file.') +flags.DEFINE_string('export_json_path', None, 'Path to the output JSON file.') + + +def main(_): + displayer = metadata.MetadataDisplayer.with_model_file(FLAGS.model_path) + with open(FLAGS.export_json_path, 'w') as f: + f.write(displayer.get_metadata_json()) + + +if __name__ == '__main__': + app.run(main) diff --git a/mediapipe/tasks/python/metadata/metadata_parser.py.template b/mediapipe/tasks/python/metadata/metadata_parser.py.template new file mode 100644 index 000000000..b5a64dee6 --- /dev/null +++ b/mediapipe/tasks/python/metadata/metadata_parser.py.template @@ -0,0 +1,26 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Information about the metadata parser that this python library depends on.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +class MetadataParser(object): + """Information about the metadata parser.""" + + # The version of the metadata parser. 
+ VERSION = "{LATEST_METADATA_PARSER_VERSION}" diff --git a/mediapipe/tasks/python/test/BUILD b/mediapipe/tasks/python/test/BUILD index 7d5f2451b..e31c7af1c 100644 --- a/mediapipe/tasks/python/test/BUILD +++ b/mediapipe/tasks/python/test/BUILD @@ -19,9 +19,9 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) py_library( - name = "test_util", + name = "test_utils", testonly = 1, - srcs = ["test_util.py"], + srcs = ["test_utils.py"], srcs_version = "PY3", deps = [ "//mediapipe/python:_framework_bindings", diff --git a/mediapipe/tasks/python/test/metadata/BUILD b/mediapipe/tasks/python/test/metadata/BUILD new file mode 100644 index 000000000..c679d86cb --- /dev/null +++ b/mediapipe/tasks/python/test/metadata/BUILD @@ -0,0 +1,31 @@ +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +py_test( + name = "metadata_test", + srcs = ["metadata_test.py"], + data = ["//mediapipe/tasks/testdata/metadata:data_files"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + "//mediapipe/tasks/metadata:metadata_schema_py", + "//mediapipe/tasks/metadata:schema_py", + "//mediapipe/tasks/python/metadata", + "//mediapipe/tasks/python/test:test_utils", + "@flatbuffers//:runtime_py", + ], +) + +py_test( + name = "metadata_parser_test", + srcs = ["metadata_parser_test.py"], + python_version = "PY3", + srcs_version = "PY2AND3", + deps = [ + "//mediapipe/tasks/python/metadata", + ], +) diff --git a/mediapipe/tasks/python/test/metadata/metadata_parser_test.py b/mediapipe/tasks/python/test/metadata/metadata_parser_test.py new file mode 100644 index 000000000..93b851082 --- /dev/null +++ b/mediapipe/tasks/python/test/metadata/metadata_parser_test.py @@ -0,0 +1,37 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for mediapipe.tasks.python.metadata.metadata_parser.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re +from absl.testing import absltest + +from mediapipe.tasks.python.metadata import metadata_parser + + +class MetadataParserTest(absltest.TestCase): + + def testVersionWellFormedSemanticVersion(self): + # Validates that the version is well-formed (x.y.z). + self.assertTrue( + re.match('[0-9]+\\.[0-9]+\\.[0-9]+', + metadata_parser.MetadataParser.VERSION)) + + +if __name__ == '__main__': + absltest.main() diff --git a/mediapipe/tasks/python/test/metadata/metadata_test.py b/mediapipe/tasks/python/test/metadata/metadata_test.py new file mode 100644 index 000000000..00dbe526a --- /dev/null +++ b/mediapipe/tasks/python/test/metadata/metadata_test.py @@ -0,0 +1,857 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for mediapipe.tasks.python.metadata.metadata.""" + +import enum +import os + +from absl.testing import absltest +from absl.testing import parameterized +import six + +import flatbuffers +from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb +from mediapipe.tasks.metadata import schema_py_generated as _schema_fb +from mediapipe.tasks.python.metadata import metadata as _metadata +from mediapipe.tasks.python.test import test_utils + + +class Tokenizer(enum.Enum): + BERT_TOKENIZER = 0 + SENTENCE_PIECE = 1 + + +class TensorType(enum.Enum): + INPUT = 0 + OUTPUT = 1 + + +def _read_file(file_name, mode="rb"): + with open(file_name, mode) as f: + return f.read() + + +class MetadataTest(parameterized.TestCase): + + def setUp(self): + super(MetadataTest, self).setUp() + self._invalid_model_buf = None + self._invalid_file = "not_existed_file" + self._model_buf = self._create_model_buf() + self._model_file = self.create_tempfile().full_path + with open(self._model_file, "wb") as f: + f.write(self._model_buf) + self._metadata_file = self._create_metadata_file() + self._metadata_file_with_version = self._create_metadata_file_with_version( + self._metadata_file, "1.0.0") + self._file1 = self.create_tempfile("file1").full_path + self._file2 = self.create_tempfile("file2").full_path + self._file2_content = b"file2_content" + with open(self._file2, "wb") as f: + f.write(self._file2_content) + self._file3 = self.create_tempfile("file3").full_path + + def _create_model_buf(self): + # Create a model with two inputs and one output, which matches the metadata + # created by _create_metadata_file(). + metadata_field = _schema_fb.MetadataT() + subgraph = _schema_fb.SubGraphT() + subgraph.inputs = [0, 1] + subgraph.outputs = [2] + + metadata_field.name = "meta" + buffer_field = _schema_fb.BufferT() + model = _schema_fb.ModelT() + model.subgraphs = [subgraph] + # Creates the metadata and buffer fields for testing purposes. + model.metadata = [metadata_field, metadata_field] + model.buffers = [buffer_field, buffer_field, buffer_field] + model_builder = flatbuffers.Builder(0) + model_builder.Finish( + model.Pack(model_builder), + _metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER) + return model_builder.Output() + + def _create_metadata_file(self): + associated_file1 = _metadata_fb.AssociatedFileT() + associated_file1.name = b"file1" + associated_file2 = _metadata_fb.AssociatedFileT() + associated_file2.name = b"file2" + self.expected_recorded_files = [ + six.ensure_str(associated_file1.name), + six.ensure_str(associated_file2.name) + ] + + input_meta = _metadata_fb.TensorMetadataT() + output_meta = _metadata_fb.TensorMetadataT() + output_meta.associatedFiles = [associated_file2] + subgraph = _metadata_fb.SubGraphMetadataT() + # Create a model with two inputs and one output. 
+ subgraph.inputTensorMetadata = [input_meta, input_meta] + subgraph.outputTensorMetadata = [output_meta] + + model_meta = _metadata_fb.ModelMetadataT() + model_meta.name = "Mobilenet_quantized" + model_meta.associatedFiles = [associated_file1] + model_meta.subgraphMetadata = [subgraph] + b = flatbuffers.Builder(0) + b.Finish( + model_meta.Pack(b), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + + metadata_file = self.create_tempfile().full_path + with open(metadata_file, "wb") as f: + f.write(b.Output()) + return metadata_file + + def _create_model_buffer_with_wrong_identifier(self): + wrong_identifier = b"widn" + model = _schema_fb.ModelT() + model_builder = flatbuffers.Builder(0) + model_builder.Finish(model.Pack(model_builder), wrong_identifier) + return model_builder.Output() + + def _create_metadata_buffer_with_wrong_identifier(self): + # Creates a metadata with wrong identifier + wrong_identifier = b"widn" + metadata = _metadata_fb.ModelMetadataT() + metadata_builder = flatbuffers.Builder(0) + metadata_builder.Finish(metadata.Pack(metadata_builder), wrong_identifier) + return metadata_builder.Output() + + def _populate_metadata_with_identifier(self, model_buf, metadata_buf, + identifier): + # For testing purposes only. MetadataPopulator cannot populate metadata with + # wrong identifiers. + model = _schema_fb.ModelT.InitFromObj( + _schema_fb.Model.GetRootAsModel(model_buf, 0)) + buffer_field = _schema_fb.BufferT() + buffer_field.data = metadata_buf + model.buffers = [buffer_field] + # Creates a new metadata field. + metadata_field = _schema_fb.MetadataT() + metadata_field.name = _metadata.MetadataPopulator.METADATA_FIELD_NAME + metadata_field.buffer = len(model.buffers) - 1 + model.metadata = [metadata_field] + b = flatbuffers.Builder(0) + b.Finish(model.Pack(b), identifier) + return b.Output() + + def _create_metadata_file_with_version(self, metadata_file, min_version): + # Creates a new metadata file with the specified min_version for testing + # purposes. 
+ metadata_buf = bytearray(_read_file(metadata_file)) + + metadata = _metadata_fb.ModelMetadataT.InitFromObj( + _metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0)) + metadata.minParserVersion = min_version + + b = flatbuffers.Builder(0) + b.Finish( + metadata.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + + metadata_file_with_version = self.create_tempfile().full_path + with open(metadata_file_with_version, "wb") as f: + f.write(b.Output()) + return metadata_file_with_version + + +class MetadataPopulatorTest(MetadataTest): + + def _create_bert_tokenizer(self): + vocab_file_name = "bert_vocab" + vocab = _metadata_fb.AssociatedFileT() + vocab.name = vocab_file_name + vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY + tokenizer = _metadata_fb.ProcessUnitT() + tokenizer.optionsType = _metadata_fb.ProcessUnitOptions.BertTokenizerOptions + tokenizer.options = _metadata_fb.BertTokenizerOptionsT() + tokenizer.options.vocabFile = [vocab] + return tokenizer, [vocab_file_name] + + def _create_sentence_piece_tokenizer(self): + sp_model_name = "sp_model" + vocab_file_name = "sp_vocab" + sp_model = _metadata_fb.AssociatedFileT() + sp_model.name = sp_model_name + vocab = _metadata_fb.AssociatedFileT() + vocab.name = vocab_file_name + vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY + tokenizer = _metadata_fb.ProcessUnitT() + tokenizer.optionsType = ( + _metadata_fb.ProcessUnitOptions.SentencePieceTokenizerOptions) + tokenizer.options = _metadata_fb.SentencePieceTokenizerOptionsT() + tokenizer.options.sentencePieceModel = [sp_model] + tokenizer.options.vocabFile = [vocab] + return tokenizer, [sp_model_name, vocab_file_name] + + def _create_tokenizer(self, tokenizer_type): + if tokenizer_type is Tokenizer.BERT_TOKENIZER: + return self._create_bert_tokenizer() + elif tokenizer_type is Tokenizer.SENTENCE_PIECE: + return self._create_sentence_piece_tokenizer() + else: + raise ValueError( + "The tokenizer type, {0}, is unsupported.".format(tokenizer_type)) + + def _create_tempfiles(self, file_names): + tempfiles = [] + for name in file_names: + tempfiles.append(self.create_tempfile(name).full_path) + return tempfiles + + def _create_model_meta_with_subgraph_meta(self, subgraph_meta): + model_meta = _metadata_fb.ModelMetadataT() + model_meta.subgraphMetadata = [subgraph_meta] + b = flatbuffers.Builder(0) + b.Finish( + model_meta.Pack(b), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + return b.Output() + + def testToValidModelFile(self): + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + self.assertIsInstance(populator, _metadata.MetadataPopulator) + + def testToInvalidModelFile(self): + with self.assertRaises(IOError) as error: + _metadata.MetadataPopulator.with_model_file(self._invalid_file) + self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file), + str(error.exception)) + + def testToValidModelBuffer(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + self.assertIsInstance(populator, _metadata.MetadataPopulator) + + def testToInvalidModelBuffer(self): + with self.assertRaises(ValueError) as error: + _metadata.MetadataPopulator.with_model_buffer(self._invalid_model_buf) + self.assertEqual("model_buf cannot be empty.", str(error.exception)) + + def testToModelBufferWithWrongIdentifier(self): + model_buf = self._create_model_buffer_with_wrong_identifier() + with self.assertRaises(ValueError) as error: + _metadata.MetadataPopulator.with_model_buffer(model_buf) + 
self.assertEqual( + "The model provided does not have the expected identifier, and " + "may not be a valid TFLite model.", str(error.exception)) + + def testSinglePopulateAssociatedFile(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + populator.load_associated_files([self._file1]) + populator.populate() + + packed_files = populator.get_packed_associated_file_list() + expected_packed_files = [os.path.basename(self._file1)] + self.assertEqual(set(packed_files), set(expected_packed_files)) + + def testRepeatedPopulateAssociatedFile(self): + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator.load_associated_files([self._file1, self._file2]) + # Loads file2 multiple times. + populator.load_associated_files([self._file2]) + populator.populate() + + packed_files = populator.get_packed_associated_file_list() + expected_packed_files = [ + os.path.basename(self._file1), + os.path.basename(self._file2) + ] + self.assertLen(packed_files, 2) + self.assertEqual(set(packed_files), set(expected_packed_files)) + + # Check if the model buffer read from file is the same as that read from + # get_model_buffer(). + model_buf_from_file = _read_file(self._model_file) + model_buf_from_getter = populator.get_model_buffer() + self.assertEqual(model_buf_from_file, model_buf_from_getter) + + def testPopulateInvalidAssociatedFile(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + with self.assertRaises(IOError) as error: + populator.load_associated_files([self._invalid_file]) + self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file), + str(error.exception)) + + def testPopulatePackedAssociatedFile(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + populator.load_associated_files([self._file1]) + populator.populate() + with self.assertRaises(ValueError) as error: + populator.load_associated_files([self._file1]) + populator.populate() + self.assertEqual( + "File, '{0}', has already been packed.".format( + os.path.basename(self._file1)), str(error.exception)) + + def testLoadAssociatedFileBuffers(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + file_buffer = _read_file(self._file1) + populator.load_associated_file_buffers({self._file1: file_buffer}) + populator.populate() + + packed_files = populator.get_packed_associated_file_list() + expected_packed_files = [os.path.basename(self._file1)] + self.assertEqual(set(packed_files), set(expected_packed_files)) + + def testRepeatedLoadAssociatedFileBuffers(self): + file_buffer1 = _read_file(self._file1) + file_buffer2 = _read_file(self._file2) + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + + populator.load_associated_file_buffers({ + self._file1: file_buffer1, + self._file2: file_buffer2 + }) + # Loads file2 multiple times. + populator.load_associated_file_buffers({self._file2: file_buffer2}) + populator.populate() + + packed_files = populator.get_packed_associated_file_list() + expected_packed_files = [ + os.path.basename(self._file1), + os.path.basename(self._file2) + ] + self.assertEqual(set(packed_files), set(expected_packed_files)) + + # Check if the model buffer read from file is the same as that read from + # get_model_buffer(). 
+    model_buf_from_file = _read_file(self._model_file)
+    model_buf_from_getter = populator.get_model_buffer()
+    self.assertEqual(model_buf_from_file, model_buf_from_getter)
+
+  def testLoadPackedAssociatedFileBuffersFails(self):
+    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
+    file_buffer = _read_file(self._file1)
+    populator.load_associated_file_buffers({self._file1: file_buffer})
+    populator.populate()
+
+    # Loading file1 again should fail.
+    with self.assertRaises(ValueError) as error:
+      populator.load_associated_file_buffers({self._file1: file_buffer})
+      populator.populate()
+    self.assertEqual(
+        "File, '{0}', has already been packed.".format(
+            os.path.basename(self._file1)), str(error.exception))
+
+  def testGetPackedAssociatedFileList(self):
+    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
+    packed_files = populator.get_packed_associated_file_list()
+    self.assertEqual(packed_files, [])
+
+  def testPopulateMetadataFileToEmptyModelFile(self):
+    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
+    populator.load_metadata_file(self._metadata_file)
+    populator.load_associated_files([self._file1, self._file2])
+    populator.populate()
+
+    model_buf_from_file = _read_file(self._model_file)
+    model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0)
+    # self._model_file already has two elements in the metadata field, so the
+    # populated TFLite metadata will be the third element.
+    metadata_field = model.Metadata(2)
+    self.assertEqual(
+        six.ensure_str(metadata_field.Name()),
+        six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME))
+
+    buffer_index = metadata_field.Buffer()
+    buffer_data = model.Buffers(buffer_index)
+    metadata_buf_np = buffer_data.DataAsNumpy()
+    metadata_buf = metadata_buf_np.tobytes()
+    expected_metadata_buf = bytearray(
+        _read_file(self._metadata_file_with_version))
+    self.assertEqual(metadata_buf, expected_metadata_buf)
+
+    recorded_files = populator.get_recorded_associated_file_list()
+    self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
+
+    # Up to now, we've proved the correctness of the model buffer read from
+    # file. Then we'll test if get_model_buffer() gives the same model buffer.
+    model_buf_from_getter = populator.get_model_buffer()
+    self.assertEqual(model_buf_from_file, model_buf_from_getter)
+
+  def testPopulateMetadataFileWithoutAssociatedFiles(self):
+    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
+    populator.load_metadata_file(self._metadata_file)
+    populator.load_associated_files([self._file1])
+    # self._file2 should also be populated, as it is recorded in the metadata.
+    with self.assertRaises(ValueError) as error:
+      populator.populate()
+    self.assertEqual(("File, '{0}', is recorded in the metadata, but has "
+                      "not been loaded into the populator.").format(
+                          os.path.basename(self._file2)), str(error.exception))
+
+  def testPopulateMetadataBufferWithWrongIdentifier(self):
+    metadata_buf = self._create_metadata_buffer_with_wrong_identifier()
+    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
+    with self.assertRaises(ValueError) as error:
+      populator.load_metadata_buffer(metadata_buf)
+    self.assertEqual(
+        "The metadata buffer does not have the expected identifier, and may not"
+        " be a valid TFLite Metadata.", str(error.exception))
+
+  def _assert_golden_metadata(self, model_file):
+    model_buf_from_file = _read_file(model_file)
+    model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0)
+    # There are two elements in model.Metadata array before the population.
+    # Metadata should be packed to the third element in the array.
+    metadata_field = model.Metadata(2)
+    self.assertEqual(
+        six.ensure_str(metadata_field.Name()),
+        six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME))
+
+    buffer_index = metadata_field.Buffer()
+    buffer_data = model.Buffers(buffer_index)
+    metadata_buf_np = buffer_data.DataAsNumpy()
+    metadata_buf = metadata_buf_np.tobytes()
+    expected_metadata_buf = bytearray(
+        _read_file(self._metadata_file_with_version))
+    self.assertEqual(metadata_buf, expected_metadata_buf)
+
+  def testPopulateMetadataFileToModelWithMetadataAndAssociatedFiles(self):
+    # First, creates a dummy metadata different from self._metadata_file. It
+    # needs to have the same input/output tensor numbers as self._model_file.
+    # Populates it and the associated files into the model.
+    input_meta = _metadata_fb.TensorMetadataT()
+    output_meta = _metadata_fb.TensorMetadataT()
+    subgraph = _metadata_fb.SubGraphMetadataT()
+    # Create a model with two inputs and one output.
+    subgraph.inputTensorMetadata = [input_meta, input_meta]
+    subgraph.outputTensorMetadata = [output_meta]
+    model_meta = _metadata_fb.ModelMetadataT()
+    model_meta.subgraphMetadata = [subgraph]
+    b = flatbuffers.Builder(0)
+    b.Finish(
+        model_meta.Pack(b),
+        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
+    metadata_buf = b.Output()
+
+    # Populate the metadata.
+    populator1 = _metadata.MetadataPopulator.with_model_file(self._model_file)
+    populator1.load_metadata_buffer(metadata_buf)
+    populator1.load_associated_files([self._file1, self._file2])
+    populator1.populate()
+
+    # Then, populate the metadata again.
+    populator2 = _metadata.MetadataPopulator.with_model_file(self._model_file)
+    populator2.load_metadata_file(self._metadata_file)
+    populator2.populate()
+
+    # Test if the metadata is populated correctly.
+    self._assert_golden_metadata(self._model_file)
+
+  def testPopulateMetadataFileToModelFileWithMetadataAndBufFields(self):
+    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
+    populator.load_metadata_file(self._metadata_file)
+    populator.load_associated_files([self._file1, self._file2])
+    populator.populate()
+
+    # Tests if the metadata is populated correctly.
+    self._assert_golden_metadata(self._model_file)
+
+    recorded_files = populator.get_recorded_associated_file_list()
+    self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
+
+    # Up to now, we've proved the correctness of the model buffer read from
+    # file. Then we'll test if get_model_buffer() gives the same model buffer.
+ model_buf_from_file = _read_file(self._model_file) + model_buf_from_getter = populator.get_model_buffer() + self.assertEqual(model_buf_from_file, model_buf_from_getter) + + def testPopulateInvalidMetadataFile(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + with self.assertRaises(IOError) as error: + populator.load_metadata_file(self._invalid_file) + self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file), + str(error.exception)) + + def testPopulateInvalidMetadataBuffer(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + with self.assertRaises(ValueError) as error: + populator.load_metadata_buffer([]) + self.assertEqual("The metadata to be populated is empty.", + str(error.exception)) + + def testGetModelBufferBeforePopulatingData(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + model_buf = populator.get_model_buffer() + expected_model_buf = self._model_buf + self.assertEqual(model_buf, expected_model_buf) + + def testLoadMetadataBufferWithNoSubgraphMetadataThrowsException(self): + # Create a dummy metadata without Subgraph. + model_meta = _metadata_fb.ModelMetadataT() + builder = flatbuffers.Builder(0) + builder.Finish( + model_meta.Pack(builder), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + meta_buf = builder.Output() + + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + with self.assertRaises(ValueError) as error: + populator.load_metadata_buffer(meta_buf) + self.assertEqual( + "The number of SubgraphMetadata should be exactly one, but got 0.", + str(error.exception)) + + def testLoadMetadataBufferWithWrongInputMetaNumberThrowsException(self): + # Create a dummy metadata with no input tensor metadata, while the expected + # number is 2. + output_meta = _metadata_fb.TensorMetadataT() + subgprah_meta = _metadata_fb.SubGraphMetadataT() + subgprah_meta.outputTensorMetadata = [output_meta] + model_meta = _metadata_fb.ModelMetadataT() + model_meta.subgraphMetadata = [subgprah_meta] + builder = flatbuffers.Builder(0) + builder.Finish( + model_meta.Pack(builder), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + meta_buf = builder.Output() + + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + with self.assertRaises(ValueError) as error: + populator.load_metadata_buffer(meta_buf) + self.assertEqual( + ("The number of input tensors (2) should match the number of " + "input tensor metadata (0)"), str(error.exception)) + + def testLoadMetadataBufferWithWrongOutputMetaNumberThrowsException(self): + # Create a dummy metadata with no output tensor metadata, while the expected + # number is 1. 
+ input_meta = _metadata_fb.TensorMetadataT() + subgprah_meta = _metadata_fb.SubGraphMetadataT() + subgprah_meta.inputTensorMetadata = [input_meta, input_meta] + model_meta = _metadata_fb.ModelMetadataT() + model_meta.subgraphMetadata = [subgprah_meta] + builder = flatbuffers.Builder(0) + builder.Finish( + model_meta.Pack(builder), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + meta_buf = builder.Output() + + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + with self.assertRaises(ValueError) as error: + populator.load_metadata_buffer(meta_buf) + self.assertEqual( + ("The number of output tensors (1) should match the number of " + "output tensor metadata (0)"), str(error.exception)) + + def testLoadMetadataAndAssociatedFilesShouldSucceeds(self): + # Create a src model with metadata and two associated files. + src_model_buf = self._create_model_buf() + populator_src = _metadata.MetadataPopulator.with_model_buffer(src_model_buf) + populator_src.load_metadata_file(self._metadata_file) + populator_src.load_associated_files([self._file1, self._file2]) + populator_src.populate() + + # Create a model to be populated with the metadata and files from + # src_model_buf. + dst_model_buf = self._create_model_buf() + populator_dst = _metadata.MetadataPopulator.with_model_buffer(dst_model_buf) + populator_dst.load_metadata_and_associated_files( + populator_src.get_model_buffer()) + populator_dst.populate() + + # Tests if the metadata and associated files are populated correctly. + dst_model_file = self.create_tempfile().full_path + with open(dst_model_file, "wb") as f: + f.write(populator_dst.get_model_buffer()) + self._assert_golden_metadata(dst_model_file) + + recorded_files = populator_dst.get_recorded_associated_file_list() + self.assertEqual(set(recorded_files), set(self.expected_recorded_files)) + + @parameterized.named_parameters( + { + "testcase_name": "InputTensorWithBert", + "tensor_type": TensorType.INPUT, + "tokenizer_type": Tokenizer.BERT_TOKENIZER + }, { + "testcase_name": "OutputTensorWithBert", + "tensor_type": TensorType.OUTPUT, + "tokenizer_type": Tokenizer.BERT_TOKENIZER + }, { + "testcase_name": "InputTensorWithSentencePiece", + "tensor_type": TensorType.INPUT, + "tokenizer_type": Tokenizer.SENTENCE_PIECE + }, { + "testcase_name": "OutputTensorWithSentencePiece", + "tensor_type": TensorType.OUTPUT, + "tokenizer_type": Tokenizer.SENTENCE_PIECE + }) + def testGetRecordedAssociatedFileListWithSubgraphTensor( + self, tensor_type, tokenizer_type): + # Creates a metadata with the tokenizer in the tensor process units. + tokenizer, expected_files = self._create_tokenizer(tokenizer_type) + + # Create the tensor with process units. + tensor = _metadata_fb.TensorMetadataT() + tensor.processUnits = [tokenizer] + + # Create the subgrah with the tensor. + subgraph = _metadata_fb.SubGraphMetadataT() + dummy_tensor_meta = _metadata_fb.TensorMetadataT() + subgraph.outputTensorMetadata = [dummy_tensor_meta] + if tensor_type is TensorType.INPUT: + subgraph.inputTensorMetadata = [tensor, dummy_tensor_meta] + subgraph.outputTensorMetadata = [dummy_tensor_meta] + elif tensor_type is TensorType.OUTPUT: + subgraph.inputTensorMetadata = [dummy_tensor_meta, dummy_tensor_meta] + subgraph.outputTensorMetadata = [tensor] + else: + raise ValueError( + "The tensor type, {0}, is unsupported.".format(tensor_type)) + + # Create a model metadata with the subgraph metadata + meta_buffer = self._create_model_meta_with_subgraph_meta(subgraph) + + # Creates the tempfiles. 
+ tempfiles = self._create_tempfiles(expected_files) + + # Creates the MetadataPopulator object. + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator.load_metadata_buffer(meta_buffer) + populator.load_associated_files(tempfiles) + populator.populate() + + recorded_files = populator.get_recorded_associated_file_list() + self.assertEqual(set(recorded_files), set(expected_files)) + + @parameterized.named_parameters( + { + "testcase_name": "InputTensorWithBert", + "tensor_type": TensorType.INPUT, + "tokenizer_type": Tokenizer.BERT_TOKENIZER + }, { + "testcase_name": "OutputTensorWithBert", + "tensor_type": TensorType.OUTPUT, + "tokenizer_type": Tokenizer.BERT_TOKENIZER + }, { + "testcase_name": "InputTensorWithSentencePiece", + "tensor_type": TensorType.INPUT, + "tokenizer_type": Tokenizer.SENTENCE_PIECE + }, { + "testcase_name": "OutputTensorWithSentencePiece", + "tensor_type": TensorType.OUTPUT, + "tokenizer_type": Tokenizer.SENTENCE_PIECE + }) + def testGetRecordedAssociatedFileListWithSubgraphProcessUnits( + self, tensor_type, tokenizer_type): + # Creates a metadata with the tokenizer in the subgraph process units. + tokenizer, expected_files = self._create_tokenizer(tokenizer_type) + + # Create the subgraph with process units. + subgraph = _metadata_fb.SubGraphMetadataT() + if tensor_type is TensorType.INPUT: + subgraph.inputProcessUnits = [tokenizer] + elif tensor_type is TensorType.OUTPUT: + subgraph.outputProcessUnits = [tokenizer] + else: + raise ValueError( + "The tensor type, {0}, is unsupported.".format(tensor_type)) + + # Creates the input and output tensor meta to match self._model_file. + dummy_tensor_meta = _metadata_fb.TensorMetadataT() + subgraph.inputTensorMetadata = [dummy_tensor_meta, dummy_tensor_meta] + subgraph.outputTensorMetadata = [dummy_tensor_meta] + + # Create a model metadata with the subgraph metadata + meta_buffer = self._create_model_meta_with_subgraph_meta(subgraph) + + # Creates the tempfiles. + tempfiles = self._create_tempfiles(expected_files) + + # Creates the MetadataPopulator object. + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator.load_metadata_buffer(meta_buffer) + populator.load_associated_files(tempfiles) + populator.populate() + + recorded_files = populator.get_recorded_associated_file_list() + self.assertEqual(set(recorded_files), set(expected_files)) + + def testPopulatedFullPathAssociatedFileShouldSucceed(self): + # Create AssociatedFileT using the full path file name. + associated_file = _metadata_fb.AssociatedFileT() + associated_file.name = self._file1 + + # Create model metadata with the associated file. + subgraph = _metadata_fb.SubGraphMetadataT() + subgraph.associatedFiles = [associated_file] + # Creates the input and output tensor metadata to match self._model_file. + dummy_tensor = _metadata_fb.TensorMetadataT() + subgraph.inputTensorMetadata = [dummy_tensor, dummy_tensor] + subgraph.outputTensorMetadata = [dummy_tensor] + md_buffer = self._create_model_meta_with_subgraph_meta(subgraph) + + # Populate the metadata to a model. + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator.load_metadata_buffer(md_buffer) + populator.load_associated_files([self._file1]) + populator.populate() + + # The recorded file name in metadata should only contain file basename; file + # directory should not be included. 
+ recorded_files = populator.get_recorded_associated_file_list() + self.assertEqual(set(recorded_files), set([os.path.basename(self._file1)])) + + +class MetadataDisplayerTest(MetadataTest): + + def setUp(self): + super(MetadataDisplayerTest, self).setUp() + self._model_with_meta_file = ( + self._create_model_with_metadata_and_associated_files()) + + def _create_model_with_metadata_and_associated_files(self): + model_buf = self._create_model_buf() + model_file = self.create_tempfile().full_path + with open(model_file, "wb") as f: + f.write(model_buf) + + populator = _metadata.MetadataPopulator.with_model_file(model_file) + populator.load_metadata_file(self._metadata_file) + populator.load_associated_files([self._file1, self._file2]) + populator.populate() + return model_file + + def testLoadModelBufferMetadataBufferWithWrongIdentifierThrowsException(self): + model_buf = self._create_model_buffer_with_wrong_identifier() + metadata_buf = self._create_metadata_buffer_with_wrong_identifier() + model_buf = self._populate_metadata_with_identifier( + model_buf, metadata_buf, + _metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER) + with self.assertRaises(ValueError) as error: + _metadata.MetadataDisplayer.with_model_buffer(model_buf) + self.assertEqual( + "The metadata buffer does not have the expected identifier, and may not" + " be a valid TFLite Metadata.", str(error.exception)) + + def testLoadModelBufferModelBufferWithWrongIdentifierThrowsException(self): + model_buf = self._create_model_buffer_with_wrong_identifier() + metadata_file = self._create_metadata_file() + wrong_identifier = b"widn" + metadata_buf = bytearray(_read_file(metadata_file)) + model_buf = self._populate_metadata_with_identifier(model_buf, metadata_buf, + wrong_identifier) + with self.assertRaises(ValueError) as error: + _metadata.MetadataDisplayer.with_model_buffer(model_buf) + self.assertEqual( + "The model provided does not have the expected identifier, and " + "may not be a valid TFLite model.", str(error.exception)) + + def testLoadModelFileInvalidModelFileThrowsException(self): + with self.assertRaises(IOError) as error: + _metadata.MetadataDisplayer.with_model_file(self._invalid_file) + self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file), + str(error.exception)) + + def testLoadModelFileModelWithoutMetadataThrowsException(self): + with self.assertRaises(ValueError) as error: + _metadata.MetadataDisplayer.with_model_file(self._model_file) + self.assertEqual("The model does not have metadata.", str(error.exception)) + + def testLoadModelFileModelWithMetadata(self): + displayer = _metadata.MetadataDisplayer.with_model_file( + self._model_with_meta_file) + self.assertIsInstance(displayer, _metadata.MetadataDisplayer) + + def testLoadModelBufferInvalidModelBufferThrowsException(self): + with self.assertRaises(ValueError) as error: + _metadata.MetadataDisplayer.with_model_buffer(_read_file(self._file1)) + self.assertEqual("model_buffer cannot be empty.", str(error.exception)) + + def testLoadModelBufferModelWithOutMetadataThrowsException(self): + with self.assertRaises(ValueError) as error: + _metadata.MetadataDisplayer.with_model_buffer(self._create_model_buf()) + self.assertEqual("The model does not have metadata.", str(error.exception)) + + def testLoadModelBufferModelWithMetadata(self): + displayer = _metadata.MetadataDisplayer.with_model_buffer( + _read_file(self._model_with_meta_file)) + self.assertIsInstance(displayer, _metadata.MetadataDisplayer) + + def 
testGetAssociatedFileBufferShouldSucceed(self): + # _model_with_meta_file contains file1 and file2. + displayer = _metadata.MetadataDisplayer.with_model_file( + self._model_with_meta_file) + + actual_content = displayer.get_associated_file_buffer("file2") + self.assertEqual(actual_content, self._file2_content) + + def testGetAssociatedFileBufferFailsWithNonExistentFile(self): + # _model_with_meta_file contains file1 and file2. + displayer = _metadata.MetadataDisplayer.with_model_file( + self._model_with_meta_file) + + non_existent_file = "non_existent_file" + with self.assertRaises(ValueError) as error: + displayer.get_associated_file_buffer(non_existent_file) + self.assertEqual( + "The file, {}, does not exist in the model.".format(non_existent_file), + str(error.exception)) + + def testGetMetadataBufferShouldSucceed(self): + displayer = _metadata.MetadataDisplayer.with_model_file( + self._model_with_meta_file) + actual_buffer = displayer.get_metadata_buffer() + actual_json = _metadata.convert_to_json(actual_buffer) + + # Verifies the generated json file. + golden_json_file_path = test_utils.get_test_data_path("golden_json.json") + with open(golden_json_file_path, "r") as f: + expected = f.read() + self.assertEqual(actual_json, expected) + + def testGetMetadataJsonModelWithMetadata(self): + displayer = _metadata.MetadataDisplayer.with_model_file( + self._model_with_meta_file) + actual = displayer.get_metadata_json() + + # Verifies the generated json file. + golden_json_file_path = test_utils.get_test_data_path("golden_json.json") + expected = _read_file(golden_json_file_path, "r") + self.assertEqual(actual, expected) + + def testGetPackedAssociatedFileListModelWithMetadata(self): + displayer = _metadata.MetadataDisplayer.with_model_file( + self._model_with_meta_file) + packed_files = displayer.get_packed_associated_file_list() + + expected_packed_files = [ + os.path.basename(self._file1), + os.path.basename(self._file2) + ] + self.assertLen( + packed_files, 2, + "The following two associated files packed to the model: {0}; {1}" + .format(expected_packed_files[0], expected_packed_files[1])) + self.assertEqual(set(packed_files), set(expected_packed_files)) + + +class MetadataUtilTest(MetadataTest): + + def test_convert_to_json_should_succeed(self): + metadata_buf = _read_file(self._metadata_file_with_version) + metadata_json = _metadata.convert_to_json(metadata_buf) + + # Verifies the generated json file. + golden_json_file_path = test_utils.get_test_data_path("golden_json.json") + expected = _read_file(golden_json_file_path, "r") + self.assertEqual(metadata_json, expected) + + +if __name__ == "__main__": + absltest.main() diff --git a/mediapipe/tasks/python/test/test_utils.py b/mediapipe/tasks/python/test/test_utils.py new file mode 100644 index 000000000..531a18f7a --- /dev/null +++ b/mediapipe/tasks/python/test/test_utils.py @@ -0,0 +1,45 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Test util for MediaPipe Tasks.""" + +import os + +from absl import flags + +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.python._framework_bindings import image_frame as image_frame_module + +FLAGS = flags.FLAGS +_Image = image_module.Image +_ImageFormat = image_frame_module.ImageFormat +_RGB_CHANNELS = 3 + + +def test_srcdir(): + """Returns the path where to look for test data files.""" + if "test_srcdir" in flags.FLAGS: + return flags.FLAGS["test_srcdir"].value + elif "TEST_SRCDIR" in os.environ: + return os.environ["TEST_SRCDIR"] + else: + raise RuntimeError("Missing TEST_SRCDIR environment.") + + +def get_test_data_path(file_or_dirname: str) -> str: + """Returns full test data path.""" + for (directory, subdirs, files) in os.walk(test_srcdir()): + for f in subdirs + files: + if f.endswith(file_or_dirname): + return os.path.join(directory, f) + raise ValueError("No %s in test directory" % file_or_dirname) diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 6980a12a0..290b665e7 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -31,7 +31,7 @@ py_test( "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:detections", "//mediapipe/tasks/python/core:base_options", - "//mediapipe/tasks/python/test:test_util", + "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:object_detector", "//mediapipe/tasks/python/vision/core:vision_task_running_mode", ], diff --git a/mediapipe/tasks/python/test/vision/object_detector_test.py b/mediapipe/tasks/python/test/vision/object_detector_test.py index 64184d5fe..95b6bf867 100644 --- a/mediapipe/tasks/python/test/vision/object_detector_test.py +++ b/mediapipe/tasks/python/test/vision/object_detector_test.py @@ -25,7 +25,7 @@ from mediapipe.tasks.python.components.containers import bounding_box as boundin from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import detections as detections_module from mediapipe.tasks.python.core import base_options as base_options_module -from mediapipe.tasks.python.test import test_util +from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import object_detector from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module @@ -99,8 +99,8 @@ class ObjectDetectorTest(parameterized.TestCase): def setUp(self): super().setUp() self.test_image = _Image.create_from_file( - test_util.get_test_data_path(_IMAGE_FILE)) - self.model_path = test_util.get_test_data_path(_MODEL_FILE) + test_utils.get_test_data_path(_IMAGE_FILE)) + self.model_path = test_utils.get_test_data_path(_MODEL_FILE) def test_create_from_file_succeeds_with_valid_model_path(self): # Creates with default option and valid model file successfully. 
diff --git a/mediapipe/tasks/testdata/metadata/BUILD b/mediapipe/tasks/testdata/metadata/BUILD index 1cf94a38f..8bda87ae2 100644 --- a/mediapipe/tasks/testdata/metadata/BUILD +++ b/mediapipe/tasks/testdata/metadata/BUILD @@ -28,9 +28,13 @@ mediapipe_files(srcs = [ "mobile_ica_8bit-without-model-metadata.tflite", "mobile_object_classifier_v0_2_3-metadata-no-name.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite", + "mobilenet_v2_1.0_224_quant.tflite", ]) -exports_files(["external_file"]) +exports_files([ + "external_file", + "golden_json.json", +]) filegroup( name = "model_files", @@ -40,10 +44,14 @@ filegroup( "mobile_ica_8bit-without-model-metadata.tflite", "mobile_object_classifier_v0_2_3-metadata-no-name.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite", + "mobilenet_v2_1.0_224_quant.tflite", ], ) filegroup( name = "data_files", - srcs = ["external_file"], + srcs = [ + "external_file", + "golden_json.json", + ], ) diff --git a/mediapipe/tasks/testdata/metadata/golden_json.json b/mediapipe/tasks/testdata/metadata/golden_json.json new file mode 100644 index 000000000..601a5976c --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/golden_json.json @@ -0,0 +1,28 @@ +{ + "name": "Mobilenet_quantized", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + }, + { + } + ], + "output_tensor_metadata": [ + { + "associated_files": [ + { + "name": "file2" + } + ] + } + ] + } + ], + "associated_files": [ + { + "name": "file1" + } + ], + "min_parser_version": "1.0.0" +} diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index e246bbd8d..cd291fc1e 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -166,6 +166,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/face_landmark_with_attention.tflite?generation=1661875751615925"], ) + http_file( + name = "com_google_mediapipe_golden_json_json", + sha256 = "55c0c88748d099aa379930504df62c6c8f1d8874ea52d2f8a925f352c4c7f09c", + urls = ["https://storage.googleapis.com/mediapipe-assets/golden_json.json?generation=1664340169675228"], + ) + http_file( name = "com_google_mediapipe_hair_segmentation_tflite", sha256 = "d2c940c4fd80edeaf38f5d7387d1b4235ee320ed120080df67c663e749e77633", @@ -316,6 +322,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v1_0.25_224_quant_without_subgraph_metadata.tflite?generation=1661875836078124"], ) + http_file( + name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_tflite", + sha256 = "f08d447cde49b4e0446428aa921aff0a14ea589fa9c5817b31f83128e9a43c1d", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant.tflite?generation=1664340173966530"], + ) + http_file( name = "com_google_mediapipe_mobilenet_v2_1_0_224_tflite", sha256 = "ff5cb7f9e62c92ebdad971f8a98aa6b3106d82a64587a7787c6a385c9e791339", From 1e5cccdc73b60a241ff4585412e2c35ff7168614 Mon Sep 17 00:00:00 2001 From: Yuqi Li Date: Sat, 1 Oct 2022 23:47:54 -0700 Subject: [PATCH 015/132] internal change. 
PiperOrigin-RevId: 478288749 --- mediapipe/tasks/python/metadata/BUILD | 4 +--- mediapipe/tasks/python/test/BUILD | 4 +--- mediapipe/tasks/python/test/metadata/BUILD | 4 +--- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/mediapipe/tasks/python/metadata/BUILD b/mediapipe/tasks/python/metadata/BUILD index 34ee63f5e..07805ec61 100644 --- a/mediapipe/tasks/python/metadata/BUILD +++ b/mediapipe/tasks/python/metadata/BUILD @@ -34,7 +34,5 @@ py_binary( visibility = [ "//visibility:public", ], - deps = [ - ":metadata", - ], + deps = [":metadata"], ) diff --git a/mediapipe/tasks/python/test/BUILD b/mediapipe/tasks/python/test/BUILD index e31c7af1c..d4ef3a35b 100644 --- a/mediapipe/tasks/python/test/BUILD +++ b/mediapipe/tasks/python/test/BUILD @@ -23,7 +23,5 @@ py_library( testonly = 1, srcs = ["test_utils.py"], srcs_version = "PY3", - deps = [ - "//mediapipe/python:_framework_bindings", - ], + deps = ["//mediapipe/python:_framework_bindings"], ) diff --git a/mediapipe/tasks/python/test/metadata/BUILD b/mediapipe/tasks/python/test/metadata/BUILD index c679d86cb..2cdc7e63a 100644 --- a/mediapipe/tasks/python/test/metadata/BUILD +++ b/mediapipe/tasks/python/test/metadata/BUILD @@ -25,7 +25,5 @@ py_test( srcs = ["metadata_parser_test.py"], python_version = "PY3", srcs_version = "PY2AND3", - deps = [ - "//mediapipe/tasks/python/metadata", - ], + deps = ["//mediapipe/tasks/python/metadata"], ) From 03c8ac3641a84a2dd03167ee23f99942d09ea40e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 3 Oct 2022 01:58:41 -0700 Subject: [PATCH 016/132] Refactor ClassificationResult and ClassificationPostprocessing. PiperOrigin-RevId: 478444264 --- .../tasks/cc/audio/audio_classifier/BUILD | 12 +- .../audio_classifier/audio_classifier.cc | 11 +- .../audio/audio_classifier/audio_classifier.h | 15 ++- .../audio_classifier_graph.cc | 20 +-- .../audio_classifier/audio_classifier_test.cc | 5 +- .../cc/audio/audio_classifier/proto/BUILD | 2 +- .../audio_classifier_graph_options.proto | 4 +- mediapipe/tasks/cc/components/BUILD | 59 --------- .../tasks/cc/components/calculators/BUILD | 6 +- .../classification_aggregation_calculator.cc | 8 +- .../calculators/end_loop_calculator.cc | 5 +- .../cc/components/containers/proto/BUILD | 23 +++- .../containers/{ => proto}/category.proto | 2 +- .../{ => proto}/classifications.proto | 4 +- .../tasks/cc/components/processors/BUILD | 64 ++++++++++ .../classification_postprocessing_graph.cc} | 54 ++++---- .../classification_postprocessing_graph.h} | 28 ++-- ...assification_postprocessing_graph_test.cc} | 120 +++++++++--------- .../{ => processors}/classifier_options.cc | 10 +- .../{ => processors}/classifier_options.h | 12 +- .../{containers => processors/proto}/BUILD | 14 +- ...cation_postprocessing_graph_options.proto} | 6 +- .../proto/classifier_options.proto | 2 +- mediapipe/tasks/cc/components/proto/BUILD | 5 - .../cc/vision/hand_gesture_recognizer/BUILD | 6 +- .../hand_gesture_recognizer_subgraph.cc | 20 +-- .../hand_gesture_recognizer/proto/BUILD | 4 +- ..._gesture_recognizer_subgraph_options.proto | 4 +- .../tasks/cc/vision/image_classifier/BUILD | 12 +- .../image_classifier/image_classifier.cc | 11 +- .../image_classifier/image_classifier.h | 18 +-- .../image_classifier_graph.cc | 20 +-- .../image_classifier/image_classifier_test.cc | 7 +- .../cc/vision/image_classifier/proto/BUILD | 2 +- .../image_classifier_graph_options.proto | 4 +- .../tasks/python/components/containers/BUILD | 2 +- .../python/components/containers/category.py | 2 +- 37 files changed, 329 
insertions(+), 274 deletions(-) rename mediapipe/tasks/cc/components/containers/{ => proto}/category.proto (96%) rename mediapipe/tasks/cc/components/containers/{ => proto}/classifications.proto (93%) create mode 100644 mediapipe/tasks/cc/components/processors/BUILD rename mediapipe/tasks/cc/components/{classification_postprocessing.cc => processors/classification_postprocessing_graph.cc} (92%) rename mediapipe/tasks/cc/components/{classification_postprocessing.h => processors/classification_postprocessing_graph.h} (59%) rename mediapipe/tasks/cc/components/{classification_postprocessing_test.cc => processors/classification_postprocessing_graph_test.cc} (88%) rename mediapipe/tasks/cc/components/{ => processors}/classifier_options.cc (81%) rename mediapipe/tasks/cc/components/{ => processors}/classifier_options.h (83%) rename mediapipe/tasks/cc/components/{containers => processors/proto}/BUILD (58%) rename mediapipe/tasks/cc/components/{classification_postprocessing_options.proto => processors/proto/classification_postprocessing_graph_options.proto} (91%) rename mediapipe/tasks/cc/components/{ => processors}/proto/classifier_options.proto (97%) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/BUILD index 20ccf68f0..ac238bfda 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/BUILD @@ -35,9 +35,10 @@ cc_library( "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto", "//mediapipe/tasks/cc/audio/utils:audio_tensor_specs", - "//mediapipe/tasks/cc/components:classification_postprocessing", - "//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", @@ -64,8 +65,9 @@ cc_library( "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", "//mediapipe/tasks/cc/audio/core:base_audio_task_api", "//mediapipe/tasks/cc/audio/core:running_mode", - "//mediapipe/tasks/cc/components:classifier_options", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc index 9a8075f77..702d802c5 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc @@ -24,8 +24,9 @@ limitations under the License. 
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h" #include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h" -#include "mediapipe/tasks/cc/components/classifier_options.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -37,6 +38,8 @@ namespace audio_classifier { namespace { +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; + constexpr char kAudioStreamName[] = "audio_in"; constexpr char kAudioTag[] = "AUDIO"; constexpr char kClassificationResultStreamName[] = "classification_result_out"; @@ -77,8 +80,8 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) { options_proto->mutable_base_options()->set_use_stream_mode( options->running_mode == core::RunningMode::AUDIO_STREAM); auto classifier_options_proto = - std::make_unique( - components::ConvertClassifierOptionsToProto( + std::make_unique( + components::processors::ConvertClassifierOptionsToProto( &(options->classifier_options))); options_proto->mutable_classifier_options()->Swap( classifier_options_proto.get()); diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h index bd8bd5e0c..200cffb8c 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h @@ -23,8 +23,8 @@ limitations under the License. #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h" #include "mediapipe/tasks/cc/audio/core/running_mode.h" -#include "mediapipe/tasks/cc/components/classifier_options.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/core/base_options.h" namespace mediapipe { @@ -40,7 +40,7 @@ struct AudioClassifierOptions { // Options for configuring the classifier behavior, such as score threshold, // number of results, etc. - components::ClassifierOptions classifier_options; + components::processors::ClassifierOptions classifier_options; // The running mode of the audio classifier. Default to the audio clips mode. // Audio classifier has two running modes: @@ -59,8 +59,9 @@ struct AudioClassifierOptions { // The user-defined result callback for processing audio stream data. // The result callback should only be specified when the running mode is set // to RunningMode::AUDIO_STREAM. - std::function)> result_callback = - nullptr; + std::function)> + result_callback = nullptr; }; // Performs audio classification on audio clips or audio stream. @@ -132,8 +133,8 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi { // framed audio clip. // TODO: Use `sample_rate` in AudioClassifierOptions by default // and makes `audio_sample_rate` optional. 
- absl::StatusOr Classify(mediapipe::Matrix audio_clip, - double audio_sample_rate); + absl::StatusOr Classify( + mediapipe::Matrix audio_clip, double audio_sample_rate); // Sends audio data (a block in a continuous audio stream) to perform audio // classification. Only use this method when the AudioClassifier is created diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc index 810fb2da5..12f8ce31a 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc @@ -31,9 +31,9 @@ limitations under the License. #include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h" #include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" @@ -53,6 +53,7 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; constexpr char kAtPrestreamTag[] = "AT_PRESTREAM"; constexpr char kAudioTag[] = "AUDIO"; @@ -238,11 +239,14 @@ class AudioClassifierGraph : public core::ModelTaskGraph { // Adds postprocessing calculators and connects them to the graph output. auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); - MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( - model_resources, task_options.classifier_options(), - &postprocessing.GetOptions< - tasks::components::ClassificationPostprocessingOptions>())); + "mediapipe.tasks.components.processors." + "ClassificationPostprocessingGraph"); + MP_RETURN_IF_ERROR( + components::processors::ConfigureClassificationPostprocessingGraph( + model_resources, task_options.classifier_options(), + &postprocessing + .GetOptions())); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); // Time aggregation is only needed for performing audio classification on diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc index 4e874b520..4b64d2231 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc @@ -37,8 +37,8 @@ limitations under the License. 
#include "mediapipe/tasks/cc/audio/core/running_mode.h" #include "mediapipe/tasks/cc/audio/utils/test_utils.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/containers/category.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" namespace mediapipe { @@ -49,6 +49,7 @@ namespace { using ::absl::StatusOr; using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::testing::HasSubstr; using ::testing::Optional; diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD index 033bb51ac..bfe37ec01 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD @@ -24,7 +24,7 @@ mediapipe_proto_library( deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto index 63b4b3293..16aa86aeb 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto @@ -18,7 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.audio.audio_classifier.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; +import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; message AudioClassifierGraphOptions { @@ -31,7 +31,7 @@ message AudioClassifierGraphOptions { // Options for configuring the classifier behavior, such as score threshold, // number of results, etc. - optional components.proto.ClassifierOptions classifier_options = 2; + optional components.processors.proto.ClassifierOptions classifier_options = 2; // The default sample rate of the input audio. Must be set when the // AudioClassifier is configured to process audio stream data. 
diff --git a/mediapipe/tasks/cc/components/BUILD b/mediapipe/tasks/cc/components/BUILD index 4de32ce9b..7939e4e39 100644 --- a/mediapipe/tasks/cc/components/BUILD +++ b/mediapipe/tasks/cc/components/BUILD @@ -58,65 +58,6 @@ cc_library( # TODO: Enable this test -cc_library( - name = "classifier_options", - srcs = ["classifier_options.cc"], - hdrs = ["classifier_options.h"], - deps = ["//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto"], -) - -mediapipe_proto_library( - name = "classification_postprocessing_options_proto", - srcs = ["classification_postprocessing_options.proto"], - deps = [ - "//mediapipe/calculators/tensor:tensors_to_classification_calculator_proto", - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto", - "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto", - ], -) - -cc_library( - name = "classification_postprocessing", - srcs = ["classification_postprocessing.cc"], - hdrs = ["classification_postprocessing.h"], - deps = [ - ":classification_postprocessing_options_cc_proto", - "//mediapipe/calculators/core:split_vector_calculator", - "//mediapipe/calculators/core:split_vector_calculator_cc_proto", - "//mediapipe/calculators/tensor:tensors_dequantization_calculator", - "//mediapipe/calculators/tensor:tensors_to_classification_calculator", - "//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:packet", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:tensor", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator", - "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto", - "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", - "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto", - "//mediapipe/tasks/cc/components/calculators:score_calibration_utils", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/components/utils:source_or_node_output", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/metadata:metadata_extractor", - "//mediapipe/tasks/metadata:metadata_schema_cc", - "//mediapipe/util:label_map_cc_proto", - "//mediapipe/util:label_map_util", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@org_tensorflow//tensorflow/lite/schema:schema_fbs", - ], - alwayslink = 1, -) - cc_library( name = "embedder_options", srcs = ["embedder_options.cc"], diff --git a/mediapipe/tasks/cc/components/calculators/BUILD b/mediapipe/tasks/cc/components/calculators/BUILD index 13ca6b496..7d01e4dfe 100644 --- a/mediapipe/tasks/cc/components/calculators/BUILD +++ b/mediapipe/tasks/cc/components/calculators/BUILD @@ -37,8 +37,8 @@ cc_library( "//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/tasks/cc/components/containers:category_cc_proto", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", 
+ "//mediapipe/tasks/cc/components/containers/proto:category_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "@com_google_absl//absl/status", ], alwayslink = 1, @@ -128,7 +128,7 @@ cc_library( "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", ], alwayslink = 1, ) diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc index b2848bc3f..e1f69e607 100644 --- a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc @@ -25,15 +25,15 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" -#include "mediapipe/tasks/cc/components/containers/category.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" namespace mediapipe { namespace api2 { using ::mediapipe::tasks::ClassificationAggregationCalculatorOptions; -using ::mediapipe::tasks::ClassificationResult; -using ::mediapipe::tasks::Classifications; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::mediapipe::tasks::components::containers::proto::Classifications; // Aggregates ClassificationLists into a single ClassificationResult that has // 3 dimensions: (classification head, classification timestamp, classification diff --git a/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc b/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc index b688cda91..10eb962dd 100644 --- a/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc @@ -17,12 +17,13 @@ limitations under the License. #include -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" // Specialized EndLoopCalculator for Tasks specific types. 
namespace mediapipe::tasks { -typedef EndLoopCalculator> +typedef EndLoopCalculator< + std::vector> EndLoopClassificationResultCalculator; REGISTER_CALCULATOR(::mediapipe::tasks::EndLoopClassificationResultCalculator); diff --git a/mediapipe/tasks/cc/components/containers/proto/BUILD b/mediapipe/tasks/cc/components/containers/proto/BUILD index 9c6402e64..633b5b369 100644 --- a/mediapipe/tasks/cc/components/containers/proto/BUILD +++ b/mediapipe/tasks/cc/components/containers/proto/BUILD @@ -18,6 +18,24 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +mediapipe_proto_library( + name = "category_proto", + srcs = ["category.proto"], +) + +mediapipe_proto_library( + name = "classifications_proto", + srcs = ["classifications.proto"], + deps = [ + ":category_proto", + ], +) + +mediapipe_proto_library( + name = "embeddings_proto", + srcs = ["embeddings.proto"], +) + mediapipe_proto_library( name = "landmarks_detection_result_proto", srcs = [ @@ -29,8 +47,3 @@ mediapipe_proto_library( "//mediapipe/framework/formats:rect_proto", ], ) - -mediapipe_proto_library( - name = "embeddings_proto", - srcs = ["embeddings.proto"], -) diff --git a/mediapipe/tasks/cc/components/containers/category.proto b/mediapipe/tasks/cc/components/containers/proto/category.proto similarity index 96% rename from mediapipe/tasks/cc/components/containers/category.proto rename to mediapipe/tasks/cc/components/containers/proto/category.proto index 47f38b75a..a44fb5b15 100644 --- a/mediapipe/tasks/cc/components/containers/category.proto +++ b/mediapipe/tasks/cc/components/containers/proto/category.proto @@ -15,7 +15,7 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks; +package mediapipe.tasks.components.containers.proto; // A single classification result. message Category { diff --git a/mediapipe/tasks/cc/components/containers/classifications.proto b/mediapipe/tasks/cc/components/containers/proto/classifications.proto similarity index 93% rename from mediapipe/tasks/cc/components/containers/classifications.proto rename to mediapipe/tasks/cc/components/containers/proto/classifications.proto index 469c67fc9..e0ccad7a1 100644 --- a/mediapipe/tasks/cc/components/containers/classifications.proto +++ b/mediapipe/tasks/cc/components/containers/proto/classifications.proto @@ -15,9 +15,9 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks; +package mediapipe.tasks.components.containers.proto; -import "mediapipe/tasks/cc/components/containers/category.proto"; +import "mediapipe/tasks/cc/components/containers/proto/category.proto"; // List of predicted categories with an optional timestamp. message ClassificationEntry { diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD new file mode 100644 index 000000000..62f04dcb7 --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -0,0 +1,64 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "classifier_options", + srcs = ["classifier_options.cc"], + hdrs = ["classifier_options.h"], + deps = ["//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto"], +) + +cc_library( + name = "classification_postprocessing_graph", + srcs = ["classification_postprocessing_graph.cc"], + hdrs = ["classification_postprocessing_graph.h"], + deps = [ + "//mediapipe/calculators/core:split_vector_calculator", + "//mediapipe/calculators/core:split_vector_calculator_cc_proto", + "//mediapipe/calculators/tensor:tensors_dequantization_calculator", + "//mediapipe/calculators/tensor:tensors_to_classification_calculator", + "//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:packet", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator", + "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto", + "//mediapipe/tasks/cc/components/calculators:score_calibration_utils", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/components/utils:source_or_node_output", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/metadata:metadata_schema_cc", + "//mediapipe/util:label_map_cc_proto", + "//mediapipe/util:label_map_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/components/classification_postprocessing.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc similarity index 92% rename from mediapipe/tasks/cc/components/classification_postprocessing.cc rename to mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc index 871476e8f..35adab687 100644 --- a/mediapipe/tasks/cc/components/classification_postprocessing.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/classification_postprocessing.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" #include @@ -37,9 +37,9 @@ limitations under the License. 
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" @@ -51,6 +51,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace components { +namespace processors { namespace { @@ -61,7 +62,7 @@ using ::mediapipe::api2::Timestamp; using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::components::proto::ClassifierOptions; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; using ::tflite::ProcessUnit; @@ -79,7 +80,8 @@ constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTimestampsTag[] = "TIMESTAMPS"; // Performs sanity checks on provided ClassifierOptions. -absl::Status SanityCheckClassifierOptions(const ClassifierOptions& options) { +absl::Status SanityCheckClassifierOptions( + const proto::ClassifierOptions& options) { if (options.max_results() == 0) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, @@ -203,7 +205,7 @@ absl::StatusOr GetScoreThreshold( // Gets the category allowlist or denylist (if any) as a set of indices. absl::StatusOr> GetAllowOrDenyCategoryIndicesIfAny( - const ClassifierOptions& options, const LabelItems& label_items) { + const proto::ClassifierOptions& options, const LabelItems& label_items) { absl::flat_hash_set category_indices; // Exit early if no denylist/allowlist. if (options.category_denylist_size() == 0 && @@ -239,7 +241,7 @@ absl::StatusOr> GetAllowOrDenyCategoryIndicesIfAny( absl::Status ConfigureScoreCalibrationIfAny( const ModelMetadataExtractor& metadata_extractor, int tensor_index, - ClassificationPostprocessingOptions* options) { + proto::ClassificationPostprocessingGraphOptions* options) { const auto* tensor_metadata = metadata_extractor.GetOutputTensorMetadata(tensor_index); if (tensor_metadata == nullptr) { @@ -283,7 +285,7 @@ absl::Status ConfigureScoreCalibrationIfAny( // Fills in the TensorsToClassificationCalculatorOptions based on the // classifier options and the (optional) output tensor metadata. 
absl::Status ConfigureTensorsToClassificationCalculator( - const ClassifierOptions& options, + const proto::ClassifierOptions& options, const ModelMetadataExtractor& metadata_extractor, int tensor_index, TensorsToClassificationCalculatorOptions* calculator_options) { const auto* tensor_metadata = @@ -345,10 +347,10 @@ void ConfigureClassificationAggregationCalculator( } // namespace -absl::Status ConfigureClassificationPostprocessing( +absl::Status ConfigureClassificationPostprocessingGraph( const ModelResources& model_resources, - const ClassifierOptions& classifier_options, - ClassificationPostprocessingOptions* options) { + const proto::ClassifierOptions& classifier_options, + proto::ClassificationPostprocessingGraphOptions* options) { MP_RETURN_IF_ERROR(SanityCheckClassifierOptions(classifier_options)); ASSIGN_OR_RETURN(const auto heads_properties, GetClassificationHeadsProperties(model_resources)); @@ -366,8 +368,8 @@ absl::Status ConfigureClassificationPostprocessing( return absl::OkStatus(); } -// A "mediapipe.tasks.components.ClassificationPostprocessingSubgraph" converts -// raw tensors into ClassificationResult objects. +// A "ClassificationPostprocessingGraph" converts raw tensors into +// ClassificationResult objects. // - Accepts CPU input tensors. // // Inputs: @@ -381,10 +383,10 @@ absl::Status ConfigureClassificationPostprocessing( // CLASSIFICATION_RESULT - ClassificationResult // The output aggregated classification results. // -// The recommended way of using this subgraph is through the GraphBuilder API -// using the 'ConfigureClassificationPostprocessing()' function. See header file -// for more details. -class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { +// The recommended way of using this graph is through the GraphBuilder API +// using the 'ConfigureClassificationPostprocessingGraph()' function. See header +// file for more details. +class ClassificationPostprocessingGraph : public mediapipe::Subgraph { public: absl::StatusOr GetConfig( mediapipe::SubgraphContext* sc) override { @@ -392,7 +394,7 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { ASSIGN_OR_RETURN( auto classification_result_out, BuildClassificationPostprocessing( - sc->Options(), + sc->Options(), graph[Input>(kTensorsTag)], graph[Input>(kTimestampsTag)], graph)); classification_result_out >> @@ -401,19 +403,19 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { } private: - // Adds an on-device classification postprocessing subgraph into the provided - // builder::Graph instance. The classification postprocessing subgraph takes + // Adds an on-device classification postprocessing graph into the provided + // builder::Graph instance. The classification postprocessing graph takes // tensors (std::vector) as input and returns one output // stream containing the output classification results (ClassificationResult). // - // options: the on-device ClassificationPostprocessingOptions. + // options: the on-device ClassificationPostprocessingGraphOptions. // tensors_in: (std::vector>) tensors to postprocess. // timestamps_in: (std::vector) optional collection of // timestamps that a single ClassificationResult should aggregate. // graph: the mediapipe builder::Graph instance to be updated. 
absl::StatusOr> BuildClassificationPostprocessing( - const ClassificationPostprocessingOptions& options, + const proto::ClassificationPostprocessingGraphOptions& options, Source> tensors_in, Source> timestamps_in, Graph& graph) { const int num_heads = options.tensors_to_classifications_options_size(); @@ -504,9 +506,11 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { kClassificationResultTag)]; } }; -REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::components::ClassificationPostprocessingSubgraph); +REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::components::processors:: + ClassificationPostprocessingGraph); // NOLINT + +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/classification_postprocessing.h b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h similarity index 59% rename from mediapipe/tasks/cc/components/classification_postprocessing.h rename to mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h index eb638bd60..8aedad46d 100644 --- a/mediapipe/tasks/cc/components/classification_postprocessing.h +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h @@ -13,32 +13,33 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_ +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_ #include "absl/status/status.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" namespace mediapipe { namespace tasks { namespace components { +namespace processors { -// Configures a ClassificationPostprocessing subgraph using the provided model +// Configures a ClassificationPostprocessingGraph using the provided model // resources and ClassifierOptions. // - Accepts CPU input tensors. // // Example usage: // // auto& postprocessing = -// graph.AddNode("mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); -// MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( +// graph.AddNode("mediapipe.tasks.components.processors.ClassificationPostprocessingGraph"); +// MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph( // model_resources, // classifier_options, -// &preprocessing.GetOptions())); +// &preprocessing.GetOptions())); // -// The resulting ClassificationPostprocessing subgraph has the following I/O: +// The resulting ClassificationPostprocessingGraph has the following I/O: // Inputs: // TENSORS - std::vector // The output tensors of an InferenceCalculator. @@ -49,13 +50,14 @@ namespace components { // Outputs: // CLASSIFICATION_RESULT - ClassificationResult // The output aggregated classification results. 
-absl::Status ConfigureClassificationPostprocessing( +absl::Status ConfigureClassificationPostprocessingGraph( const tasks::core::ModelResources& model_resources, - const tasks::components::proto::ClassifierOptions& classifier_options, - ClassificationPostprocessingOptions* options); + const proto::ClassifierOptions& classifier_options, + proto::ClassificationPostprocessingGraphOptions* options); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_ +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_ diff --git a/mediapipe/tasks/cc/components/classification_postprocessing_test.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc similarity index 88% rename from mediapipe/tasks/cc/components/classification_postprocessing_test.cc rename to mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc index 67223050f..bb03e2530 100644 --- a/mediapipe/tasks/cc/components/classification_postprocessing_test.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/classification_postprocessing.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" #include #include @@ -42,9 +42,9 @@ limitations under the License. #include "mediapipe/framework/timestamp.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/util/label_map.pb.h" @@ -53,6 +53,7 @@ limitations under the License. 
namespace mediapipe { namespace tasks { namespace components { +namespace processors { namespace { using ::mediapipe::api2::Input; @@ -60,7 +61,7 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::components::proto::ClassifierOptions; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::ModelResources; using ::testing::HasSubstr; using ::testing::proto::Approximately; @@ -101,12 +102,12 @@ TEST_F(ConfigureTest, FailsWithInvalidMaxResults) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.set_max_results(0); - ClassificationPostprocessingOptions options_out; - auto status = ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out); + proto::ClassificationPostprocessingGraphOptions options_out; + auto status = ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), HasSubstr("Invalid `max_results` option")); @@ -116,13 +117,13 @@ TEST_F(ConfigureTest, FailsWithBothAllowlistAndDenylist) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.add_category_allowlist("foo"); options_in.add_category_denylist("bar"); - ClassificationPostprocessingOptions options_out; - auto status = ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out); + proto::ClassificationPostprocessingGraphOptions options_out; + auto status = ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), HasSubstr("mutually exclusive options")); @@ -132,12 +133,12 @@ TEST_F(ConfigureTest, FailsWithAllowlistAndNoMetadata) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.add_category_allowlist("foo"); - ClassificationPostprocessingOptions options_out; - auto status = ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out); + proto::ClassificationPostprocessingGraphOptions options_out; + auto status = ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT( @@ -149,11 +150,11 @@ TEST_F(ConfigureTest, SucceedsWithoutMetadata) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); EXPECT_THAT(options_out, Approximately(EqualsProto( R"pb(score_calibration_options: [] @@ -171,12 +172,12 @@ TEST_F(ConfigureTest, 
SucceedsWithMaxResults) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.set_max_results(3); - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); EXPECT_THAT(options_out, Approximately(EqualsProto( R"pb(score_calibration_options: [] @@ -194,12 +195,12 @@ TEST_F(ConfigureTest, SucceedsWithScoreThreshold) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.set_score_threshold(0.5); - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); EXPECT_THAT(options_out, Approximately(EqualsProto( R"pb(score_calibration_options: [] @@ -217,11 +218,11 @@ TEST_F(ConfigureTest, SucceedsWithMetadata) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); // Check label map size and two first elements. EXPECT_EQ( @@ -254,12 +255,12 @@ TEST_F(ConfigureTest, SucceedsWithAllowlist) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.add_category_allowlist("tench"); - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); // Clear label map and compare the rest of the options. options_out.mutable_tensors_to_classifications_options(0) @@ -283,12 +284,12 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.add_category_denylist("background"); - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); // Clear label map and compare the rest of the options. 
options_out.mutable_tensors_to_classifications_options(0) @@ -313,11 +314,11 @@ TEST_F(ConfigureTest, SucceedsWithScoreCalibration) { auto model_resources, CreateModelResourcesForModel( kQuantizedImageClassifierWithDummyScoreCalibration)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); // Check label map size and two first elements. EXPECT_EQ( @@ -362,11 +363,11 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kFloatTwoHeadsAudioClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); // Check label maps sizes and first two elements. EXPECT_EQ( options_out.tensors_to_classifications_options(0).label_items_size(), @@ -414,17 +415,19 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) { class PostprocessingTest : public tflite_shims::testing::Test { protected: absl::StatusOr BuildGraph( - absl::string_view model_name, const ClassifierOptions& options, + absl::string_view model_name, const proto::ClassifierOptions& options, bool connect_timestamps = false) { ASSIGN_OR_RETURN(auto model_resources, CreateModelResourcesForModel(model_name)); Graph graph; auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); - MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( + "mediapipe.tasks.components.processors." + "ClassificationPostprocessingGraph"); + MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph( *model_resources, options, - &postprocessing.GetOptions())); + &postprocessing + .GetOptions())); graph[Input>(kTensorsTag)].SetName(kTensorsName) >> postprocessing.In(kTensorsTag); if (connect_timestamps) { @@ -495,7 +498,7 @@ class PostprocessingTest : public tflite_shims::testing::Test { TEST_F(PostprocessingTest, SucceedsWithoutMetadata) { // Build graph. - ClassifierOptions options; + proto::ClassifierOptions options; options.set_max_results(3); options.set_score_threshold(0.5); MP_ASSERT_OK_AND_ASSIGN( @@ -524,7 +527,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) { TEST_F(PostprocessingTest, SucceedsWithMetadata) { // Build graph. - ClassifierOptions options; + proto::ClassifierOptions options; options.set_max_results(3); MP_ASSERT_OK_AND_ASSIGN( auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options)); @@ -567,7 +570,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) { TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) { // Build graph. - ClassifierOptions options; + proto::ClassifierOptions options; options.set_max_results(3); MP_ASSERT_OK_AND_ASSIGN( auto poller, @@ -613,7 +616,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) { TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) { // Build graph. 
- ClassifierOptions options; + proto::ClassifierOptions options; options.set_max_results(2); MP_ASSERT_OK_AND_ASSIGN( auto poller, @@ -673,7 +676,7 @@ TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) { TEST_F(PostprocessingTest, SucceedsWithTimestamps) { // Build graph. - ClassifierOptions options; + proto::ClassifierOptions options; options.set_max_results(2); MP_ASSERT_OK_AND_ASSIGN( auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options, @@ -729,6 +732,7 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) { } } // namespace +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/classifier_options.cc b/mediapipe/tasks/cc/components/processors/classifier_options.cc similarity index 81% rename from mediapipe/tasks/cc/components/classifier_options.cc rename to mediapipe/tasks/cc/components/processors/classifier_options.cc index c54db5f88..349bb569d 100644 --- a/mediapipe/tasks/cc/components/classifier_options.cc +++ b/mediapipe/tasks/cc/components/processors/classifier_options.cc @@ -13,17 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/classifier_options.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" namespace mediapipe { namespace tasks { namespace components { +namespace processors { -tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto( +proto::ClassifierOptions ConvertClassifierOptionsToProto( ClassifierOptions* options) { - tasks::components::proto::ClassifierOptions options_proto; + proto::ClassifierOptions options_proto; options_proto.set_display_names_locale(options->display_names_locale); options_proto.set_max_results(options->max_results); options_proto.set_score_threshold(options->score_threshold); @@ -36,6 +37,7 @@ tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto( return options_proto; } +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/classifier_options.h b/mediapipe/tasks/cc/components/processors/classifier_options.h similarity index 83% rename from mediapipe/tasks/cc/components/classifier_options.h rename to mediapipe/tasks/cc/components/processors/classifier_options.h index e15bf5e69..189b42e60 100644 --- a/mediapipe/tasks/cc/components/classifier_options.h +++ b/mediapipe/tasks/cc/components/processors/classifier_options.h @@ -13,14 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_ +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_ -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" namespace mediapipe { namespace tasks { namespace components { +namespace processors { // Classifier options for MediaPipe C++ classification Tasks. struct ClassifierOptions { @@ -49,11 +50,12 @@ struct ClassifierOptions { }; // Converts a ClassifierOptions to a ClassifierOptionsProto. -tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto( +proto::ClassifierOptions ConvertClassifierOptionsToProto( ClassifierOptions* classifier_options); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_ +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_ diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD similarity index 58% rename from mediapipe/tasks/cc/components/containers/BUILD rename to mediapipe/tasks/cc/components/processors/proto/BUILD index 701f84824..d7cbe47ff 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/processors/proto/BUILD @@ -19,14 +19,18 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) mediapipe_proto_library( - name = "category_proto", - srcs = ["category.proto"], + name = "classifier_options_proto", + srcs = ["classifier_options.proto"], ) mediapipe_proto_library( - name = "classifications_proto", - srcs = ["classifications.proto"], + name = "classification_postprocessing_graph_options_proto", + srcs = ["classification_postprocessing_graph_options.proto"], deps = [ - ":category_proto", + "//mediapipe/calculators/tensor:tensors_to_classification_calculator_proto", + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto", ], ) diff --git a/mediapipe/tasks/cc/components/classification_postprocessing_options.proto b/mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.proto similarity index 91% rename from mediapipe/tasks/cc/components/classification_postprocessing_options.proto rename to mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.proto index 9b67e2f75..1de788eab 100644 --- a/mediapipe/tasks/cc/components/classification_postprocessing_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.proto @@ -15,16 +15,16 @@ limitations under the License. 
syntax = "proto2"; -package mediapipe.tasks.components; +package mediapipe.tasks.components.processors.proto; import "mediapipe/calculators/tensor/tensors_to_classification_calculator.proto"; import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto"; import "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto"; -message ClassificationPostprocessingOptions { +message ClassificationPostprocessingGraphOptions { extend mediapipe.CalculatorOptions { - optional ClassificationPostprocessingOptions ext = 460416950; + optional ClassificationPostprocessingGraphOptions ext = 460416950; } // Optional mapping between output tensor index and corresponding score diff --git a/mediapipe/tasks/cc/components/proto/classifier_options.proto b/mediapipe/tasks/cc/components/processors/proto/classifier_options.proto similarity index 97% rename from mediapipe/tasks/cc/components/proto/classifier_options.proto rename to mediapipe/tasks/cc/components/processors/proto/classifier_options.proto index ea1491bb8..7afbfc14e 100644 --- a/mediapipe/tasks/cc/components/proto/classifier_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/classifier_options.proto @@ -15,7 +15,7 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks.components.proto; +package mediapipe.tasks.components.processors.proto; // Shared options used by all classification tasks. message ClassifierOptions { diff --git a/mediapipe/tasks/cc/components/proto/BUILD b/mediapipe/tasks/cc/components/proto/BUILD index 8c4dcdad9..c11d6f95a 100644 --- a/mediapipe/tasks/cc/components/proto/BUILD +++ b/mediapipe/tasks/cc/components/proto/BUILD @@ -23,11 +23,6 @@ mediapipe_proto_library( srcs = ["segmenter_options.proto"], ) -mediapipe_proto_library( - name = "classifier_options_proto", - srcs = ["classifier_options.proto"], -) - mediapipe_proto_library( name = "embedder_options_proto", srcs = ["embedder_options.proto"], diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD index bb5b86212..9e2d9bd17 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD +++ b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD @@ -54,10 +54,10 @@ cc_library( "//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:classification_postprocessing", - "//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto", "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc index e124d3410..247d8453d 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc +++ b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc @@ -27,9 +27,9 @@ 
limitations under the License. #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" @@ -49,6 +49,7 @@ using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::vision::hand_gesture_recognizer::proto:: HandGestureRecognizerSubgraphOptions; using ::mediapipe::tasks::vision::proto::LandmarksToMatrixCalculatorOptions; @@ -218,11 +219,14 @@ class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph { auto inference_output_tensors = inference.Out(kTensorsTag); auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); - MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( - model_resources, graph_options.classifier_options(), - &postprocessing.GetOptions< - tasks::components::ClassificationPostprocessingOptions>())); + "mediapipe.tasks.components.processors." 
+ "ClassificationPostprocessingGraph"); + MP_RETURN_IF_ERROR( + components::processors::ConfigureClassificationPostprocessingGraph( + model_resources, graph_options.classifier_options(), + &postprocessing + .GetOptions())); inference_output_tensors >> postprocessing.In(kTensorsTag); auto classification_result = postprocessing[Output("CLASSIFICATION_RESULT")]; diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD index f3927727e..44ec611b2 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD +++ b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD @@ -26,7 +26,7 @@ mediapipe_proto_library( deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) @@ -37,7 +37,5 @@ mediapipe_proto_library( deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_proto", - "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto index f73443eaf..d8ee95037 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto @@ -18,7 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.hand_gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; +import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; message HandGestureRecognizerSubgraphOptions { @@ -31,7 +31,7 @@ message HandGestureRecognizerSubgraphOptions { // Options for configuring the gesture classifier behavior, such as score // threshold, number of results, etc. 
- optional components.proto.ClassifierOptions classifier_options = 2; + optional components.processors.proto.ClassifierOptions classifier_options = 2; // Minimum confidence value ([0.0, 1.0]) for the hand landmarks to be // considered tracked successfully diff --git a/mediapipe/tasks/cc/vision/image_classifier/BUILD b/mediapipe/tasks/cc/vision/image_classifier/BUILD index e7c8a6586..dfa77cb96 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/BUILD @@ -26,11 +26,11 @@ cc_library( "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:classification_postprocessing", - "//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto", "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", @@ -50,9 +50,9 @@ cc_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:classifier_options", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/cc/core:utils", diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc index 1e092e85a..0338b2ee2 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc @@ -26,9 +26,9 @@ limitations under the License. 
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/timestamp.h" -#include "mediapipe/tasks/cc/components/classifier_options.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" @@ -56,6 +56,7 @@ constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::PacketMap; // Builds a NormalizedRect covering the entire image. @@ -107,8 +108,8 @@ ConvertImageClassifierOptionsToProto(ImageClassifierOptions* options) { options_proto->mutable_base_options()->set_use_stream_mode( options->running_mode != core::RunningMode::IMAGE); auto classifier_options_proto = - std::make_unique( - components::ConvertClassifierOptionsToProto( + std::make_unique( + components::processors::ConvertClassifierOptionsToProto( &(options->classifier_options))); options_proto->mutable_classifier_options()->Swap( classifier_options_proto.get()); diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h index 8ff11413e..24f36017a 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h @@ -23,8 +23,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" -#include "mediapipe/tasks/cc/components/classifier_options.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" @@ -51,12 +51,14 @@ struct ImageClassifierOptions { // Options for configuring the classifier behavior, such as score threshold, // number of results, etc. - components::ClassifierOptions classifier_options; + components::processors::ClassifierOptions classifier_options; // The user-defined result callback for processing live stream data. // The result callback should only be specified when the running mode is set // to RunningMode::LIVE_STREAM. - std::function, const Image&, int64)> + std::function, + const Image&, int64)> result_callback = nullptr; }; @@ -113,7 +115,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // The image can be of any size with format RGB or RGBA. // TODO: describe exact preprocessing steps once // YUVToImageCalculator is integrated. 
- absl::StatusOr Classify( + absl::StatusOr Classify( mediapipe::Image image, std::optional roi = std::nullopt); @@ -127,9 +129,9 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // The image can be of any size with format RGB or RGBA. It's required to // provide the video frame's timestamp (in milliseconds). The input timestamps // must be monotonically increasing. - absl::StatusOr ClassifyForVideo( - mediapipe::Image image, int64 timestamp_ms, - std::optional roi = std::nullopt); + absl::StatusOr + ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms, + std::optional roi = std::nullopt); // Sends live image data to image classification, and the results will be // available via the "result_callback" provided in the ImageClassifierOptions. diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 0d7b60c99..9a0078c5c 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -22,11 +22,11 @@ limitations under the License. #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h" @@ -43,6 +43,7 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); @@ -152,11 +153,14 @@ class ImageClassifierGraph : public core::ModelTaskGraph { // Adds postprocessing calculators and connects them to the graph output. auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); - MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( - model_resources, task_options.classifier_options(), - &postprocessing.GetOptions< - tasks::components::ClassificationPostprocessingOptions>())); + "mediapipe.tasks.components.processors." 
+ "ClassificationPostprocessingGraph"); + MP_RETURN_IF_ERROR( + components::processors::ConfigureClassificationPostprocessingGraph( + model_resources, task_options.classifier_options(), + &postprocessing + .GetOptions())); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); // Outputs the aggregated classification result as the subgraph output diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc index edbb851c0..070a5a034 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc @@ -32,8 +32,8 @@ limitations under the License. #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/containers/category.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -48,6 +48,9 @@ namespace image_classifier { namespace { using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::proto::ClassificationEntry; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::mediapipe::tasks::components::containers::proto::Classifications; using ::testing::HasSubstr; using ::testing::Optional; diff --git a/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD b/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD index a6f5791e3..29638bebd 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD @@ -24,7 +24,7 @@ mediapipe_proto_library( deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto index 3da047110..b307a66b6 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto @@ -18,7 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_classifier.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; +import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; message ImageClassifierGraphOptions { @@ -31,5 +31,5 @@ message ImageClassifierGraphOptions { // Options for configuring the classifier behavior, such as score threshold, // number of results, etc. 
- optional components.proto.ClassifierOptions classifier_options = 2; + optional components.processors.proto.ClassifierOptions classifier_options = 2; } diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 2bc951220..8dd9fcd60 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -31,7 +31,7 @@ py_library( name = "category", srcs = ["category.py"], deps = [ - "//mediapipe/tasks/cc/components/containers:category_py_pb2", + "//mediapipe/tasks/cc/components/containers/proto:category_py_pb2", "//mediapipe/tasks/python/core:optional_dependencies", ], ) diff --git a/mediapipe/tasks/python/components/containers/category.py b/mediapipe/tasks/python/components/containers/category.py index 00f68e532..0b347fc10 100644 --- a/mediapipe/tasks/python/components/containers/category.py +++ b/mediapipe/tasks/python/components/containers/category.py @@ -16,7 +16,7 @@ import dataclasses from typing import Any -from mediapipe.tasks.cc.components.containers import category_pb2 +from mediapipe.tasks.cc.components.containers.proto import category_pb2 from mediapipe.tasks.python.core.optional_dependencies import doc_controls _CategoryProto = category_pb2.Category From 65c7fb9004573e4b421d952eb50b6519f94733c0 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 3 Oct 2022 04:48:10 -0700 Subject: [PATCH 017/132] Internal change PiperOrigin-RevId: 478470582 --- mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD index cb3ef9656..97f8dfd15 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD @@ -28,6 +28,7 @@ cc_library_with_tflite( ], tflite_deps = [ "//mediapipe/tasks/cc/core:model_resources_cache", + "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", ], deps = [ "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", From cfd0f3e79fa631692ac4e809f4619d6ff53d4421 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 3 Oct 2022 13:48:47 -0700 Subject: [PATCH 018/132] Add HandLandmarkerGraph which connects HandDetectorGraph and HandLandmarkerSubgraph with landmark tracking.
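
For reference, a minimal invocation sketch for the new graph, adapted from the documentation comment added in hand_landmarker_graph.cc below; the model file names are illustrative placeholders rather than assets shipped with this change:

  node {
    calculator: "mediapipe.tasks.vision.HandLandmarkerGraph"
    input_stream: "IMAGE:image_in"
    output_stream: "LANDMARKS:hand_landmarks"
    output_stream: "WORLD_LANDMARKS:world_hand_landmarks"
    output_stream: "HAND_RECT_NEXT_FRAME:hand_rect_next_frame"
    output_stream: "HANDEDNESS:handedness"
    output_stream: "IMAGE:image_out"
    options {
      [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarkerGraphOptions.ext] {
        hand_detector_graph_options {
          base_options { model_asset { file_name: "palm_detection.tflite" } }
          min_detection_confidence: 0.5
          num_hands: 2
        }
        hand_landmarker_subgraph_options {
          base_options { model_asset { file_name: "hand_landmark_lite.tflite" } }
          min_detection_confidence: 0.5
        }
        min_tracking_confidence: 0.5
      }
    }
  }
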
PiperOrigin-RevId: 478596004 --- mediapipe/tasks/cc/vision/hand_detector/BUILD | 2 +- .../hand_detector/hand_detector_graph.cc | 63 ++-- .../hand_detector/hand_detector_graph_test.cc | 17 +- .../tasks/cc/vision/hand_detector/proto/BUILD | 4 +- ...roto => hand_detector_graph_options.proto} | 14 +- .../tasks/cc/vision/hand_landmarker/BUILD | 39 +++ .../hand_landmarker/hand_landmarker_graph.cc | 284 ++++++++++++++++++ .../hand_landmarker_graph_test.cc | 167 ++++++++++ .../hand_landmarker_subgraph.cc | 68 ++--- .../hand_landmarker_subgraph_test.cc | 10 +- .../cc/vision/hand_landmarker/proto/BUILD | 6 +- ...to => hand_landmarker_graph_options.proto} | 19 +- .../hand_landmarker_subgraph_options.proto | 6 +- 13 files changed, 600 insertions(+), 99 deletions(-) rename mediapipe/tasks/cc/vision/hand_detector/proto/{hand_detector_options.proto => hand_detector_graph_options.proto} (76%) create mode 100644 mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc create mode 100644 mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc rename mediapipe/tasks/cc/vision/hand_landmarker/proto/{hand_landmarker_options.proto => hand_landmarker_graph_options.proto} (74%) diff --git a/mediapipe/tasks/cc/vision/hand_detector/BUILD b/mediapipe/tasks/cc/vision/hand_detector/BUILD index c87cc50a6..433a30471 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/BUILD +++ b/mediapipe/tasks/cc/vision/hand_detector/BUILD @@ -51,7 +51,7 @@ cc_library( "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc index 7ead21bad..7ef8d62f5 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc @@ -40,7 +40,7 @@ limitations under the License. 
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/utils.h" -#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" namespace mediapipe { @@ -53,18 +53,23 @@ using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::vision::hand_detector::proto::HandDetectorOptions; +using ::mediapipe::tasks::vision::hand_detector::proto:: + HandDetectorGraphOptions; constexpr char kImageTag[] = "IMAGE"; -constexpr char kDetectionsTag[] = "DETECTIONS"; -constexpr char kNormRectsTag[] = "NORM_RECTS"; +constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS"; +constexpr char kHandRectsTag[] = "HAND_RECTS"; +constexpr char kPalmRectsTag[] = "PALM_RECTS"; struct HandDetectionOuts { Source> palm_detections; Source> hand_rects; + Source> palm_rects; + Source image; }; void ConfigureTensorsToDetectionsCalculator( + const HandDetectorGraphOptions& tasks_options, mediapipe::TensorsToDetectionsCalculatorOptions* options) { // TODO use metadata to configure these fields. options->set_num_classes(1); @@ -77,7 +82,7 @@ void ConfigureTensorsToDetectionsCalculator( options->set_sigmoid_score(true); options->set_score_clipping_thresh(100.0); options->set_reverse_output_order(true); - options->set_min_score_thresh(0.5); + options->set_min_score_thresh(tasks_options.min_detection_confidence()); options->set_x_scale(192.0); options->set_y_scale(192.0); options->set_w_scale(192.0); @@ -144,19 +149,26 @@ void ConfigureRectTransformationCalculator( // Image to perform detection on. // // Outputs: -// DETECTIONS - std::vector +// PALM_DETECTIONS - std::vector // Detected palms with maximum `num_hands` specified in options. -// NORM_RECTS - std::vector +// HAND_RECTS - std::vector // Detected hand bounding boxes in normalized coordinates. +// PLAM_RECTS - std::vector +// Detected palm bounding boxes in normalized coordinates. +// IMAGE - Image +// The input image that the hand detector runs on and has the pixel data +// stored on the target storage (CPU vs GPU). 
// // Example: // node { // calculator: "mediapipe.tasks.vision.HandDetectorGraph" // input_stream: "IMAGE:image" -// output_stream: "DETECTIONS:palm_detections" -// output_stream: "NORM_RECTS:hand_rects_from_palm_detections" +// output_stream: "PALM_DETECTIONS:palm_detections" +// output_stream: "HAND_RECTS:hand_rects_from_palm_detections" +// output_stream: "PALM_RECTS:palm_rects" +// output_stream: "IMAGE:image_out" // options { -// [mediapipe.tasks.hand_detector.proto.HandDetectorOptions.ext] { +// [mediapipe.tasks.hand_detector.proto.HandDetectorGraphOptions.ext] { // base_options { // model_asset { // file_name: "palm_detection.tflite" @@ -173,16 +185,20 @@ class HandDetectorGraph : public core::ModelTaskGraph { absl::StatusOr GetConfig( SubgraphContext* sc) override { ASSIGN_OR_RETURN(const auto* model_resources, - CreateModelResources(sc)); + CreateModelResources(sc)); Graph graph; - ASSIGN_OR_RETURN(auto hand_detection_outs, - BuildHandDetectionSubgraph( - sc->Options(), *model_resources, - graph[Input(kImageTag)], graph)); + ASSIGN_OR_RETURN( + auto hand_detection_outs, + BuildHandDetectionSubgraph(sc->Options(), + *model_resources, + graph[Input(kImageTag)], graph)); hand_detection_outs.palm_detections >> - graph[Output>(kDetectionsTag)]; + graph[Output>(kPalmDetectionsTag)]; hand_detection_outs.hand_rects >> - graph[Output>(kNormRectsTag)]; + graph[Output>(kHandRectsTag)]; + hand_detection_outs.palm_rects >> + graph[Output>(kPalmRectsTag)]; + hand_detection_outs.image >> graph[Output(kImageTag)]; return graph.GetConfig(); } @@ -196,7 +212,7 @@ class HandDetectorGraph : public core::ModelTaskGraph { // image_in: image stream to run hand detection on. // graph: the mediapipe builder::Graph instance to be updated. absl::StatusOr BuildHandDetectionSubgraph( - const HandDetectorOptions& subgraph_options, + const HandDetectorGraphOptions& subgraph_options, const core::ModelResources& model_resources, Source image_in, Graph& graph) { // Add image preprocessing subgraph. The model expects aspect ratio @@ -235,6 +251,7 @@ class HandDetectorGraph : public core::ModelTaskGraph { auto& tensors_to_detections = graph.AddNode("TensorsToDetectionsCalculator"); ConfigureTensorsToDetectionsCalculator( + subgraph_options, &tensors_to_detections .GetOptions()); model_output_tensors >> tensors_to_detections.In("TENSORS"); @@ -281,7 +298,8 @@ class HandDetectorGraph : public core::ModelTaskGraph { .GetOptions()); palm_detections >> detections_to_rects.In("DETECTIONS"); image_size >> detections_to_rects.In("IMAGE_SIZE"); - auto palm_rects = detections_to_rects.Out("NORM_RECTS"); + auto palm_rects = + detections_to_rects[Output>("NORM_RECTS")]; // Expands and shifts the rectangle that contains the palm so that it's // likely to cover the entire hand. 
@@ -308,8 +326,11 @@ class HandDetectorGraph : public core::ModelTaskGraph { clip_normalized_rect_vector_size[Output>( "")]; - return HandDetectionOuts{.palm_detections = palm_detections, - .hand_rects = clipped_hand_rects}; + return HandDetectionOuts{ + /* palm_detections= */ palm_detections, + /* hand_rects= */ clipped_hand_rects, + /* palm_rects= */ palm_rects, + /* image= */ preprocessing[Output(kImageTag)]}; } }; diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc index 3fa97664e..850ff2732 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc @@ -40,7 +40,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" -#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" @@ -60,7 +60,8 @@ using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::core::TaskRunner; using ::mediapipe::tasks::core::proto::ExternalFile; using ::mediapipe::tasks::vision::DecodeImageFromFile; -using ::mediapipe::tasks::vision::hand_detector::proto::HandDetectorOptions; +using ::mediapipe::tasks::vision::hand_detector::proto:: + HandDetectorGraphOptions; using ::mediapipe::tasks::vision::hand_detector::proto::HandDetectorResult; using ::testing::EqualsProto; using ::testing::TestParamInfo; @@ -80,9 +81,9 @@ constexpr char kTwoHandsResultFile[] = "hand_detector_result_two_hands.pbtxt"; constexpr char kImageTag[] = "IMAGE"; constexpr char kImageName[] = "image"; -constexpr char kPalmDetectionsTag[] = "DETECTIONS"; +constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS"; constexpr char kPalmDetectionsName[] = "palm_detections"; -constexpr char kHandNormRectsTag[] = "NORM_RECTS"; +constexpr char kHandRectsTag[] = "HAND_RECTS"; constexpr char kHandNormRectsName[] = "hand_norm_rects"; constexpr float kPalmDetectionBboxMaxDiff = 0.01; @@ -106,20 +107,20 @@ absl::StatusOr> CreateTaskRunner( auto& hand_detection = graph.AddNode("mediapipe.tasks.vision.HandDetectorGraph"); - auto options = std::make_unique(); + auto options = std::make_unique(); options->mutable_base_options()->mutable_model_asset()->set_file_name( JoinPath("./", kTestDataDirectory, model_name)); options->set_min_detection_confidence(0.5); options->set_num_hands(num_hands); - hand_detection.GetOptions().Swap(options.get()); + hand_detection.GetOptions().Swap(options.get()); graph[Input(kImageTag)].SetName(kImageName) >> hand_detection.In(kImageTag); hand_detection.Out(kPalmDetectionsTag).SetName(kPalmDetectionsName) >> graph[Output>(kPalmDetectionsTag)]; - hand_detection.Out(kHandNormRectsTag).SetName(kHandNormRectsName) >> - graph[Output>(kHandNormRectsTag)]; + hand_detection.Out(kHandRectsTag).SetName(kHandNormRectsName) >> + graph[Output>(kHandRectsTag)]; return TaskRunner::Create( graph.GetConfig(), std::make_unique()); diff --git a/mediapipe/tasks/cc/vision/hand_detector/proto/BUILD b/mediapipe/tasks/cc/vision/hand_detector/proto/BUILD index 2d22aab10..77f3b2649 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/proto/BUILD +++ 
b/mediapipe/tasks/cc/vision/hand_detector/proto/BUILD @@ -21,8 +21,8 @@ package(default_visibility = [ licenses(["notice"]) mediapipe_proto_library( - name = "hand_detector_options_proto", - srcs = ["hand_detector_options.proto"], + name = "hand_detector_graph_options_proto", + srcs = ["hand_detector_graph_options.proto"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", diff --git a/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.proto b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto similarity index 76% rename from mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.proto rename to mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto index ae22c7991..be20583d0 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.proto +++ b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto @@ -21,24 +21,20 @@ import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.handdetector"; -option java_outer_classname = "HandDetectorOptionsProto"; +option java_outer_classname = "HandDetectorGraphOptionsProto"; -message HandDetectorOptions { +message HandDetectorGraphOptions { extend mediapipe.CalculatorOptions { - optional HandDetectorOptions ext = 464864288; + optional HandDetectorGraphOptions ext = 464864288; } // Base options for configuring Task library, such as specifying the TfLite // model file with metadata, accelerator options, etc. optional core.proto.BaseOptions base_options = 1; - // The locale to use for display names specified through the TFLite Model - // Metadata, if any. Defaults to English. - optional string display_names_locale = 2 [default = "en"]; - // Minimum confidence value ([0.0, 1.0]) for confidence score to be considered // successfully detecting a hand in the image. - optional float min_detection_confidence = 3 [default = 0.5]; + optional float min_detection_confidence = 2 [default = 0.5]; // The maximum number of hands output by the detector. - optional int32 num_hands = 4; + optional int32 num_hands = 3; } diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index 653976b96..c968c17fa 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -51,6 +51,7 @@ cc_library( # TODO: move calculators in modules/hand_landmark/calculators to tasks dir. 
"//mediapipe/modules/hand_landmark/calculators:hand_landmarks_to_rect_calculator", "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/utils:gate", "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", @@ -66,3 +67,41 @@ cc_library( ) # TODO: Enable this test + +cc_library( + name = "hand_landmarker_graph", + srcs = ["hand_landmarker_graph.cc"], + deps = [ + ":hand_landmarker_subgraph", + "//mediapipe/calculators/core:begin_loop_calculator", + "//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto", + "//mediapipe/calculators/core:end_loop_calculator", + "//mediapipe/calculators/core:gate_calculator", + "//mediapipe/calculators/core:gate_calculator_cc_proto", + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/calculators/core:previous_loopback_calculator", + "//mediapipe/calculators/util:collection_has_min_size_calculator", + "//mediapipe/calculators/util:collection_has_min_size_calculator_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/utils:gate", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator", + "//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_subgraph_options_cc_proto", + ], + alwayslink = 1, +) + +# TODO: Enable this test diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc new file mode 100644 index 000000000..6041d528f --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc @@ -0,0 +1,284 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h" +#include "mediapipe/calculators/core/gate_calculator.pb.h" +#include "mediapipe/calculators/util/collection_has_min_size_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/utils/gate.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace hand_landmarker { + +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::utils::DisallowIf; +using ::mediapipe::tasks::vision::hand_detector::proto:: + HandDetectorGraphOptions; +using ::mediapipe::tasks::vision::hand_landmarker::proto:: + HandLandmarkerGraphOptions; +using ::mediapipe::tasks::vision::hand_landmarker::proto:: + HandLandmarkerSubgraphOptions; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kLandmarksTag[] = "LANDMARKS"; +constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; +constexpr char kHandRectNextFrameTag[] = "HAND_RECT_NEXT_FRAME"; +constexpr char kHandednessTag[] = "HANDEDNESS"; +constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS"; +constexpr char kPalmRectsTag[] = "PALM_RECTS"; +constexpr char kPreviousLoopbackCalculatorName[] = "PreviousLoopbackCalculator"; + +struct HandLandmarkerOutputs { + Source> landmark_lists; + Source> world_landmark_lists; + Source> hand_rects_next_frame; + Source> handednesses; + Source> palm_rects; + Source> palm_detections; + Source image; +}; + +} // namespace + +// A "mediapipe.tasks.vision.HandLandmarkerGraph" performs hand +// landmarks detection. The HandLandmarkerGraph consists of two subgraphs: +// HandDetectorGraph and HandLandmarkerSubgraph. HandLandmarkerSubgraph detects +// landmarks from bounding boxes produced by HandDetectorGraph. +// HandLandmarkerGraph tracks the landmarks over time, and skips the +// HandDetectorGraph. If the tracking is lost or the detectd hands are +// less than configured max number hands, HandDetectorGraph would be triggered +// to detect hands. +// +// Accepts CPU input images and outputs Landmarks on CPU. +// +// Inputs: +// IMAGE - Image +// Image to perform hand landmarks detection on. +// +// Outputs: +// LANDMARKS: - std::vector +// Vector of detected hand landmarks. +// WORLD_LANDMARKS - std::vector +// Vector of detected hand landmarks in world coordinates. 
+// HAND_RECT_NEXT_FRAME - std::vector +// Vector of the predicted rects enclosing the same hand RoI for landmark +// detection on the next frame. +// HANDEDNESS - std::vector +// Vector of classification of handedness. +// PALM_RECTS - std::vector +// Detected palm bounding boxes in normalized coordinates. +// PALM_DETECTIONS - std::vector +// Detected palms with maximum `num_hands` specified in options. +// IMAGE - Image +// The input image that the hand landmarker runs on and has the pixel data +// stored on the target storage (CPU vs GPU). +// +// Example: +// node { +// calculator: "mediapipe.tasks.vision.HandLandmarkerGraph" +// input_stream: "IMAGE:image_in" +// output_stream: "LANDMARKS:hand_landmarks" +// output_stream: "WORLD_LANDMARKS:world_hand_landmarks" +// output_stream: "HAND_RECT_NEXT_FRAME:hand_rect_next_frame" +// output_stream: "HANDEDNESS:handedness" +// output_stream: "PALM_RECTS:palm_rects" +// output_stream: "PALM_DETECTIONS:palm_detections" +// output_stream: "IMAGE:image_out" +// options { +// [mediapipe.tasks.hand_landmarker.proto.HandLandmarkerGraphOptions.ext] { +// base_options { +// model_asset { +// file_name: "hand_landmarker.task" +// } +// } +// hand_detector_graph_options { +// base_options { +// model_asset { +// file_name: "palm_detection.tflite" +// } +// } +// min_detection_confidence: 0.5 +// num_hands: 2 +// } +// hand_landmarker_subgraph_options { +// base_options { +// model_asset { +// file_name: "hand_landmark_lite.tflite" +// } +// } +// min_detection_confidence: 0.5 +// } +// } +// } +// } +class HandLandmarkerGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + Graph graph; + ASSIGN_OR_RETURN( + auto hand_landmarker_outputs, + BuildHandLandmarkerGraph(sc->Options(), + graph[Input(kImageTag)], graph)); + hand_landmarker_outputs.landmark_lists >> + graph[Output>(kLandmarksTag)]; + hand_landmarker_outputs.world_landmark_lists >> + graph[Output>(kWorldLandmarksTag)]; + hand_landmarker_outputs.hand_rects_next_frame >> + graph[Output>(kHandRectNextFrameTag)]; + hand_landmarker_outputs.handednesses >> + graph[Output>(kHandednessTag)]; + hand_landmarker_outputs.palm_rects >> + graph[Output>(kPalmRectsTag)]; + hand_landmarker_outputs.palm_detections >> + graph[Output>(kPalmDetectionsTag)]; + hand_landmarker_outputs.image >> graph[Output(kImageTag)]; + + // TODO remove when support is fixed. + // As mediapipe GraphBuilder currently doesn't support configuring + // InputStreamInfo, modifying the CalculatorGraphConfig proto directly. + CalculatorGraphConfig config = graph.GetConfig(); + for (int i = 0; i < config.node_size(); ++i) { + if (config.node(i).calculator() == kPreviousLoopbackCalculatorName) { + auto* info = config.mutable_node(i)->add_input_stream_info(); + info->set_tag_index("LOOP"); + info->set_back_edge(true); + break; + } + } + return config; + } + + private: + // Adds a mediapipe hand landmark detection graph into the provided + // builder::Graph instance. + // + // tasks_options: the mediapipe tasks module HandLandmarkerGraphOptions. + // image_in: (mediapipe::Image) stream to run hand landmark detection on. + // graph: the mediapipe graph instance to be updated. 
+ absl::StatusOr BuildHandLandmarkerGraph( + const HandLandmarkerGraphOptions& tasks_options, Source image_in, + Graph& graph) { + const int max_num_hands = + tasks_options.hand_detector_graph_options().num_hands(); + + auto& previous_loopback = graph.AddNode(kPreviousLoopbackCalculatorName); + image_in >> previous_loopback.In("MAIN"); + auto prev_hand_rects_from_landmarks = + previous_loopback[Output>("PREV_LOOP")]; + + auto& min_size_node = + graph.AddNode("NormalizedRectVectorHasMinSizeCalculator"); + prev_hand_rects_from_landmarks >> min_size_node.In("ITERABLE"); + min_size_node.GetOptions() + .set_min_size(max_num_hands); + auto has_enough_hands = min_size_node.Out("").Cast(); + + auto image_for_hand_detector = + DisallowIf(image_in, has_enough_hands, graph); + + auto& hand_detector = + graph.AddNode("mediapipe.tasks.vision.HandDetectorGraph"); + hand_detector.GetOptions().CopyFrom( + tasks_options.hand_detector_graph_options()); + image_for_hand_detector >> hand_detector.In("IMAGE"); + auto hand_rects_from_hand_detector = hand_detector.Out("HAND_RECTS"); + + auto& hand_association = graph.AddNode("HandAssociationCalculator"); + hand_association.GetOptions() + .set_min_similarity_threshold(tasks_options.min_tracking_confidence()); + prev_hand_rects_from_landmarks >> + hand_association[Input>::Multiple("")][0]; + hand_rects_from_hand_detector >> + hand_association[Input>::Multiple("")][1]; + auto hand_rects = hand_association.Out(""); + + auto& clip_hand_rects = + graph.AddNode("ClipNormalizedRectVectorSizeCalculator"); + clip_hand_rects.GetOptions() + .set_max_vec_size(max_num_hands); + hand_rects >> clip_hand_rects.In(""); + auto clipped_hand_rects = clip_hand_rects.Out(""); + + auto& hand_landmarker_subgraph = graph.AddNode( + "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerSubgraph"); + hand_landmarker_subgraph.GetOptions() + .CopyFrom(tasks_options.hand_landmarker_subgraph_options()); + image_in >> hand_landmarker_subgraph.In("IMAGE"); + clipped_hand_rects >> hand_landmarker_subgraph.In("HAND_RECT"); + + auto hand_rects_for_next_frame = + hand_landmarker_subgraph[Output>( + kHandRectNextFrameTag)]; + // Back edge. + hand_rects_for_next_frame >> previous_loopback.In("LOOP"); + + // TODO: Replace PassThroughCalculator with a calculator that + // converts the pixel data to be stored on the target storage (CPU vs GPU). + auto& pass_through = graph.AddNode("PassThroughCalculator"); + image_in >> pass_through.In(""); + + return {{ + /* landmark_lists= */ hand_landmarker_subgraph + [Output>(kLandmarksTag)], + /* world_landmark_lists= */ + hand_landmarker_subgraph[Output>( + kWorldLandmarksTag)], + /* hand_rects_next_frame= */ hand_rects_for_next_frame, + hand_landmarker_subgraph[Output>( + kHandednessTag)], + /* palm_rects= */ + hand_detector[Output>(kPalmRectsTag)], + /* palm_detections */ + hand_detector[Output>(kPalmDetectionsTag)], + /* image */ + pass_through[Output("")], + }}; + } +}; + +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::hand_landmarker::HandLandmarkerGraph); + +} // namespace hand_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc new file mode 100644 index 000000000..413af68ff --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc @@ -0,0 +1,167 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.pb.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace hand_landmarker { + +namespace { + +using ::file::Defaults; +using ::file::GetTextProto; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::core::TaskRunner; +using ::mediapipe::tasks::vision::hand_landmarker::proto:: + HandLandmarkerGraphOptions; +using ::testing::EqualsProto; +using ::testing::proto::Approximately; +using ::testing::proto::Partially; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kPalmDetectionModel[] = "palm_detection_full.tflite"; +constexpr char kHandLandmarkerFullModel[] = "hand_landmark_full.tflite"; +constexpr char kLeftHandsImage[] = "left_hands.jpg"; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageName[] = "image_in"; +constexpr char kLandmarksTag[] = "LANDMARKS"; +constexpr char kLandmarksName[] = "landmarks"; +constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; +constexpr char kWorldLandmarksName[] = "world_landmarks"; +constexpr char kHandRectNextFrameTag[] = "HAND_RECT_NEXT_FRAME"; +constexpr char kHandRectNextFrameName[] = "hand_rect_next_frame"; +constexpr char kHandednessTag[] = "HANDEDNESS"; +constexpr char kHandednessName[] = "handedness"; + +// Expected hand landmarks positions, in text proto format. 
+constexpr char kExpectedLeftUpHandLandmarksFilename[] = + "expected_left_up_hand_landmarks.prototxt"; +constexpr char kExpectedLeftDownHandLandmarksFilename[] = + "expected_left_down_hand_landmarks.prototxt"; + +constexpr float kFullModelFractionDiff = 0.03; // percentage +constexpr float kAbsMargin = 0.03; +constexpr int kMaxNumHands = 2; +constexpr float kMinTrackingConfidence = 0.5; + +NormalizedLandmarkList GetExpectedLandmarkList(absl::string_view filename) { + NormalizedLandmarkList expected_landmark_list; + MP_EXPECT_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, filename), + &expected_landmark_list, Defaults())); + return expected_landmark_list; +} + +// Helper function to create a Hand Landmarker TaskRunner. +absl::StatusOr> CreateTaskRunner() { + Graph graph; + auto& hand_landmarker_graph = graph.AddNode( + "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph"); + auto& options = + hand_landmarker_graph.GetOptions(); + options.mutable_hand_detector_graph_options() + ->mutable_base_options() + ->mutable_model_asset() + ->set_file_name(JoinPath("./", kTestDataDirectory, kPalmDetectionModel)); + options.mutable_hand_detector_graph_options()->mutable_base_options(); + options.mutable_hand_detector_graph_options()->set_num_hands(kMaxNumHands); + options.mutable_hand_landmarker_subgraph_options() + ->mutable_base_options() + ->mutable_model_asset() + ->set_file_name( + JoinPath("./", kTestDataDirectory, kHandLandmarkerFullModel)); + options.set_min_tracking_confidence(kMinTrackingConfidence); + + graph[Input(kImageTag)].SetName(kImageName) >> + hand_landmarker_graph.In(kImageTag); + hand_landmarker_graph.Out(kLandmarksTag).SetName(kLandmarksName) >> + graph[Output>(kLandmarksTag)]; + hand_landmarker_graph.Out(kWorldLandmarksTag).SetName(kWorldLandmarksName) >> + graph[Output>(kWorldLandmarksTag)]; + hand_landmarker_graph.Out(kHandednessTag).SetName(kHandednessName) >> + graph[Output>(kHandednessTag)]; + hand_landmarker_graph.Out(kHandRectNextFrameTag) + .SetName(kHandRectNextFrameName) >> + graph[Output>(kHandRectNextFrameTag)]; + return TaskRunner::Create( + graph.GetConfig(), absl::make_unique()); +} + +class HandLandmarkerTest : public tflite_shims::testing::Test {}; + +TEST_F(HandLandmarkerTest, Succeeds) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kLeftHandsImage))); + MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner()); + auto output_packets = + task_runner->Process({{kImageName, MakePacket(std::move(image))}}); + const auto& landmarks = (*output_packets)[kLandmarksName] + .Get>(); + ASSERT_EQ(landmarks.size(), kMaxNumHands); + std::vector expected_landmarks = { + GetExpectedLandmarkList(kExpectedLeftUpHandLandmarksFilename), + GetExpectedLandmarkList(kExpectedLeftDownHandLandmarksFilename)}; + + EXPECT_THAT(landmarks[0], + Approximately(Partially(EqualsProto(expected_landmarks[0])), + /*margin=*/kAbsMargin, + /*fraction=*/kFullModelFractionDiff)); + EXPECT_THAT(landmarks[1], + Approximately(Partially(EqualsProto(expected_landmarks[1])), + /*margin=*/kAbsMargin, + /*fraction=*/kFullModelFractionDiff)); +} + +} // namespace + +} // namespace hand_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph.cc index fff4ae0d4..0ac4686b7 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph.cc +++ 
b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/utils/gate.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" @@ -48,6 +49,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace vision { +namespace hand_landmarker { namespace { @@ -55,6 +57,7 @@ using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::utils::AllowIf; using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::vision::hand_landmarker::proto:: HandLandmarkerSubgraphOptions; @@ -82,7 +85,6 @@ struct SingleHandLandmarkerOutputs { Source hand_presence; Source hand_presence_score; Source handedness; - Source> image_size; }; struct HandLandmarkerOutputs { @@ -92,7 +94,6 @@ struct HandLandmarkerOutputs { Source> presences; Source> presence_scores; Source> handednesses; - Source> image_size; }; absl::Status SanityCheckOptions(const HandLandmarkerSubgraphOptions& options) { @@ -208,8 +209,6 @@ void ConfigureHandRectTransformationCalculator( // Float value indicates the probability that the hand is present. // HANDEDNESS - ClassificationList // Classification of handedness. -// IMAGE_SIZE - std::vector -// The size of input image. // // Example: // node { @@ -221,8 +220,6 @@ void ConfigureHandRectTransformationCalculator( // output_stream: "HAND_RECT_NEXT_FRAME:hand_rect_next_frame" // output_stream: "PRESENCE:hand_presence" // output_stream: "PRESENCE_SCORE:hand_presence_score" -// output_stream: "HANDEDNESS:handedness" -// output_stream: "IMAGE_SIZE:image_size" // options { // [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarkerSubgraphOptions.ext] // { @@ -259,8 +256,6 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { graph[Output(kPresenceScoreTag)]; hand_landmark_detection_outs.handedness >> graph[Output(kHandednessTag)]; - hand_landmark_detection_outs.image_size >> - graph[Output>(kImageSizeTag)]; return graph.GetConfig(); } @@ -332,18 +327,7 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { // score of hand presence. auto& tensors_to_hand_presence = graph.AddNode("TensorsToFloatsCalculator"); hand_flag_tensors >> tensors_to_hand_presence.In("TENSORS"); - - // Converts the handedness tensor into a float that represents the - // classification score of handedness. - auto& tensors_to_handedness = - graph.AddNode("TensorsToClassificationCalculator"); - ConfigureTensorsToHandednessCalculator( - &tensors_to_handedness.GetOptions< - mediapipe::TensorsToClassificationCalculatorOptions>()); - handedness_tensors >> tensors_to_handedness.In("TENSORS"); auto hand_presence_score = tensors_to_hand_presence[Output("FLOAT")]; - auto handedness = - tensors_to_handedness[Output("CLASSIFICATIONS")]; // Applies a threshold to the confidence score to determine whether a // hand is present. 
@@ -354,6 +338,18 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { hand_presence_score >> hand_presence_thresholding.In("FLOAT"); auto hand_presence = hand_presence_thresholding[Output("FLAG")]; + // Converts the handedness tensor into a float that represents the + // classification score of handedness. + auto& tensors_to_handedness = + graph.AddNode("TensorsToClassificationCalculator"); + ConfigureTensorsToHandednessCalculator( + &tensors_to_handedness.GetOptions< + mediapipe::TensorsToClassificationCalculatorOptions>()); + handedness_tensors >> tensors_to_handedness.In("TENSORS"); + auto handedness = AllowIf( + tensors_to_handedness[Output("CLASSIFICATIONS")], + hand_presence, graph); + // Adjusts landmarks (already normalized to [0.f, 1.f]) on the letterboxed // hand image (after image transformation with the FIT scale mode) to the // corresponding locations on the same image with the letterbox removed @@ -371,8 +367,9 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { landmark_letterbox_removal.Out("LANDMARKS") >> landmark_projection.In("NORM_LANDMARKS"); hand_rect >> landmark_projection.In("NORM_RECT"); - auto projected_landmarks = - landmark_projection[Output("NORM_LANDMARKS")]; + auto projected_landmarks = AllowIf( + landmark_projection[Output("NORM_LANDMARKS")], + hand_presence, graph); // Projects the world landmarks from the cropped hand image to the // corresponding locations on the full image before cropping (input to the @@ -383,7 +380,8 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { world_landmark_projection.In("LANDMARKS"); hand_rect >> world_landmark_projection.In("NORM_RECT"); auto projected_world_landmarks = - world_landmark_projection[Output("LANDMARKS")]; + AllowIf(world_landmark_projection[Output("LANDMARKS")], + hand_presence, graph); // Converts the hand landmarks into a rectangle (normalized by image size) // that encloses the hand. @@ -403,7 +401,8 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { hand_landmarks_to_rect.Out("NORM_RECT") >> hand_rect_transformation.In("NORM_RECT"); auto hand_rect_next_frame = - hand_rect_transformation[Output("")]; + AllowIf(hand_rect_transformation[Output("")], + hand_presence, graph); return {{ /* hand_landmarks= */ projected_landmarks, @@ -412,16 +411,15 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { /* hand_presence= */ hand_presence, /* hand_presence_score= */ hand_presence_score, /* handedness= */ handedness, - /* image_size= */ image_size, }}; } }; REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::vision::SingleHandLandmarkerSubgraph); + ::mediapipe::tasks::vision::hand_landmarker::SingleHandLandmarkerSubgraph); -// A "mediapipe.tasks.vision.HandLandmarkerSubgraph" performs multi -// hand landmark detection. +// A "mediapipe.tasks.vision.HandLandmarkerSubgraph" performs multi hand +// landmark detection. // - Accepts CPU input image and a vector of hand rect RoIs to detect the // multiple hands landmarks enclosed by the RoIs. Output vectors of // hand landmarks related results, where each element in the vectors @@ -449,8 +447,6 @@ REGISTER_MEDIAPIPE_GRAPH( // Vector of float value indicates the probability that the hand is present. // HANDEDNESS - std::vector // Vector of classification of handedness. -// IMAGE_SIZE - std::vector -// The size of input image. 
// // Example: // node { @@ -463,7 +459,6 @@ REGISTER_MEDIAPIPE_GRAPH( // output_stream: "PRESENCE:hand_presence" // output_stream: "PRESENCE_SCORE:hand_presence_score" // output_stream: "HANDEDNESS:handedness" -// output_stream: "IMAGE_SIZE:image_size" // options { // [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarkerSubgraphOptions.ext] // { @@ -499,8 +494,6 @@ class HandLandmarkerSubgraph : public core::ModelTaskGraph { graph[Output>(kPresenceScoreTag)]; hand_landmark_detection_outputs.handednesses >> graph[Output>(kHandednessTag)]; - hand_landmark_detection_outputs.image_size >> - graph[Output>(kImageSizeTag)]; return graph.GetConfig(); } @@ -510,8 +503,8 @@ class HandLandmarkerSubgraph : public core::ModelTaskGraph { const HandLandmarkerSubgraphOptions& subgraph_options, Source image_in, Source> multi_hand_rects, Graph& graph) { - auto& hand_landmark_subgraph = - graph.AddNode("mediapipe.tasks.vision.SingleHandLandmarkerSubgraph"); + auto& hand_landmark_subgraph = graph.AddNode( + "mediapipe.tasks.vision.hand_landmarker.SingleHandLandmarkerSubgraph"); hand_landmark_subgraph.GetOptions().CopyFrom( subgraph_options); @@ -533,8 +526,6 @@ class HandLandmarkerSubgraph : public core::ModelTaskGraph { hand_landmark_subgraph.Out("HAND_RECT_NEXT_FRAME"); auto landmarks = hand_landmark_subgraph.Out("LANDMARKS"); auto world_landmarks = hand_landmark_subgraph.Out("WORLD_LANDMARKS"); - auto image_size = - hand_landmark_subgraph[Output>("IMAGE_SIZE")]; auto& end_loop_handedness = graph.AddNode("EndLoopClassificationListCalculator"); @@ -585,13 +576,14 @@ class HandLandmarkerSubgraph : public core::ModelTaskGraph { /* presences= */ presences, /* presence_scores= */ presence_scores, /* handednesses= */ handednesses, - /* image_size= */ image_size, }}; } }; -REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::HandLandmarkerSubgraph); +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::hand_landmarker::HandLandmarkerSubgraph); +} // namespace hand_landmarker } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph_test.cc index 1c2bc6da7..7d91dc3c7 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph_test.cc @@ -45,6 +45,7 @@ limitations under the License. 
namespace mediapipe { namespace tasks { namespace vision { +namespace hand_landmarker { namespace { using ::file::Defaults; @@ -112,8 +113,8 @@ absl::StatusOr> CreateSingleHandTaskRunner( absl::string_view model_name) { Graph graph; - auto& hand_landmark_detection = - graph.AddNode("mediapipe.tasks.vision.SingleHandLandmarkerSubgraph"); + auto& hand_landmark_detection = graph.AddNode( + "mediapipe.tasks.vision.hand_landmarker.SingleHandLandmarkerSubgraph"); auto options = std::make_unique(); options->mutable_base_options()->mutable_model_asset()->set_file_name( @@ -151,8 +152,8 @@ absl::StatusOr> CreateMultiHandTaskRunner( absl::string_view model_name) { Graph graph; - auto& multi_hand_landmark_detection = - graph.AddNode("mediapipe.tasks.vision.HandLandmarkerSubgraph"); + auto& multi_hand_landmark_detection = graph.AddNode( + "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerSubgraph"); auto options = std::make_unique(); options->mutable_base_options()->mutable_model_asset()->set_file_name( @@ -462,6 +463,7 @@ INSTANTIATE_TEST_SUITE_P( }); } // namespace +} // namespace hand_landmarker } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD index 8cc984c47..9d1ba6f90 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD @@ -31,13 +31,13 @@ mediapipe_proto_library( ) mediapipe_proto_library( - name = "hand_landmarker_options_proto", - srcs = ["hand_landmarker_options.proto"], + name = "hand_landmarker_graph_options_proto", + srcs = ["hand_landmarker_graph_options.proto"], deps = [ ":hand_landmarker_subgraph_options_proto", "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", - "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_options_proto", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto similarity index 74% rename from mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_options.proto rename to mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto index b3d82eda4..13849ec5e 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto @@ -19,22 +19,25 @@ package mediapipe.tasks.vision.hand_landmarker.proto; import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; -import "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.proto"; +import "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto"; import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto"; -message HandLandmarkerOptions { +message HandLandmarkerGraphOptions { extend mediapipe.CalculatorOptions { - optional HandLandmarkerOptions ext = 462713202; + optional HandLandmarkerGraphOptions ext = 462713202; } // Base options for configuring MediaPipe Tasks, such as specifying the TfLite // model file with metadata, accelerator options, etc. 
optional core.proto.BaseOptions base_options = 1; - // The locale to use for display names specified through the TFLite Model - // Metadata, if any. Defaults to English. - optional string display_names_locale = 2 [default = "en"]; + // Options for hand detector graph. + optional hand_detector.proto.HandDetectorGraphOptions + hand_detector_graph_options = 2; - optional hand_detector.proto.HandDetectorOptions hand_detector_options = 3; + // Options for hand landmarker subgraph. + optional HandLandmarkerSubgraphOptions hand_landmarker_subgraph_options = 3; - optional HandLandmarkerSubgraphOptions hand_landmarker_subgraph_options = 4; + // Minimum confidence for hand landmarks tracking to be considered + // successfully. + optional float min_tracking_confidence = 4 [default = 0.5]; } diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto index 9e93384d6..02d18e8ab 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto @@ -28,11 +28,7 @@ message HandLandmarkerSubgraphOptions { // model file with metadata, accelerator options, etc. optional core.proto.BaseOptions base_options = 1; - // The locale to use for display names specified through the TFLite Model - // Metadata, if any. Defaults to English. - optional string display_names_locale = 2 [default = "en"]; - // Minimum confidence value ([0.0, 1.0]) for hand presence score to be // considered successfully detecting a hand in the image. - optional float min_detection_confidence = 3 [default = 0.5]; + optional float min_detection_confidence = 2 [default = 0.5]; } From f7fa3dc9bea39a621a139fc312ba1d6695958b17 Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Mon, 3 Oct 2022 22:04:29 -0700 Subject: [PATCH 019/132] Explaining "Graph Options" in the MediaPipe user guide. PiperOrigin-RevId: 478688026 --- docs/framework_concepts/graphs.md | 92 +++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/docs/framework_concepts/graphs.md b/docs/framework_concepts/graphs.md index f951b506d..b20a87467 100644 --- a/docs/framework_concepts/graphs.md +++ b/docs/framework_concepts/graphs.md @@ -143,6 +143,98 @@ Below is an example of how to create a subgraph named `TwoPassThroughSubgraph`. } ``` +## Graph Options + +It is possible to specify a "graph options" protobuf for a MediaPipe graph +similar to the [`Calculator Options`](calculators.md#calculator-options) +protobuf specified for a MediaPipe calculator. These "graph options" can be +specified where a graph is invoked, and used to populate calculator options and +subgraph options within the graph. 
+ +In a CalculatorGraphConfig, graph options can be specified for a subgraph +exactly like calculator options, as shown below: + +``` +node { + calculator: "FlowLimiterCalculator" + input_stream: "image" + output_stream: "throttled_image" + node_options: { + [type.googleapis.com/mediapipe.FlowLimiterCalculatorOptions] { + max_in_flight: 1 + } + } +} + +node { + calculator: "FaceDetectionSubgraph" + input_stream: "IMAGE:throttled_image" + node_options: { + [type.googleapis.com/mediapipe.FaceDetectionOptions] { + tensor_width: 192 + tensor_height: 192 + } + } +} +``` + +In a CalculatorGraphConfig, graph options can be accepted and used to populate +calculator options, as shown below: + +``` +graph_options: { + [type.googleapis.com/mediapipe.FaceDetectionOptions] {} +} + +node: { + calculator: "ImageToTensorCalculator" + input_stream: "IMAGE:multi_backend_image" + node_options: { + [type.googleapis.com/mediapipe.ImageToTensorCalculatorOptions] { + keep_aspect_ratio: true + border_mode: BORDER_ZERO + } + } + option_value: "output_tensor_width:options/tensor_width" + option_value: "output_tensor_height:options/tensor_height" +} + +node { + calculator: "InferenceCalculator" + node_options: { + [type.googleapis.com/mediapipe.InferenceCalculatorOptions] {} + } + option_value: "delegate:options/delegate" + option_value: "model_path:options/model_path" +} +``` + +In this example, the `FaceDetectionSubgraph` accepts graph option protobuf +`FaceDetectionOptions`. The `FaceDetectionOptions` is used to define some field +values in the calculator options `ImageToTensorCalculatorOptions` and some field +values in the subgraph options `InferenceCalculatorOptions`. The field values +are defined using the `option_value:` syntax. + +In the `CalculatorGraphConfig::Node` protobuf, the fields `node_options:` and +`option_value:` together define the option values for a calculator such as +`ImageToTensorCalculator`. The `node_options:` field defines a set of literal +constant values using the text protobuf syntax. Each `option_value:` field +defines the value for one protobuf field using information from the enclosing +graph, specifically from field values of the graph options of the enclosing +graph. In the example above, the `option_value:` +`"output_tensor_width:options/tensor_width"` defines the field +`ImageToTensorCalculatorOptions.output_tensor_width` using the value of +`FaceDetectionOptions.tensor_width`. + +The syntax of `option_value:` is similar to the syntax of `input_stream:`. The +syntax is `option_value: "LHS:RHS"`. The LHS identifies a calculator option +field and the RHS identifies a graph option field. More specifically, the LHS +and RHS each consists of a series of protobuf field names identifying nested +protobuf messages and fields separated by '/'. This is known as the "ProtoPath" +syntax. Nested messages that are referenced in the LHS or RHS must already be +defined in the enclosing protobuf in order to be traversed using +`option_value:`. + ## Cycles From 2cb9ebb5e341c2419c9580ba01aff8d28a508f01 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 3 Oct 2022 23:43:13 -0700 Subject: [PATCH 020/132] Rename HandGestureRecognizer to GestureRecognizer and update namespace for Tasks C++ conventions. 
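
For downstream code, the main effect is a set of directory, target, and proto renames; representative mappings taken from the file moves in this change:

  mediapipe/tasks/cc/vision/hand_gesture_recognizer/ -> mediapipe/tasks/cc/vision/gesture_recognizer/
  cc_library "hand_gesture_recognizer_subgraph" -> "hand_gesture_recognizer_graph"
  hand_gesture_recognizer_subgraph_options.proto -> hand_gesture_recognizer_graph_options.proto
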
PiperOrigin-RevId: 478700907 --- .../BUILD | 12 +-- .../calculators/BUILD | 16 ++- .../handedness_to_matrix_calculator.cc | 13 ++- .../handedness_to_matrix_calculator_test.cc | 4 - .../landmarks_to_matrix_calculator.cc | 13 +-- .../landmarks_to_matrix_calculator.proto | 2 +- .../landmarks_to_matrix_calculator_test.cc | 10 +- .../hand_gesture_recognizer_graph.cc} | 100 +++++++++--------- .../handedness_util.cc | 4 +- .../handedness_util.h | 8 +- .../handedness_util_test.cc | 4 +- .../proto/BUILD | 13 +-- ...nd_gesture_recognizer_graph_options.proto} | 6 +- .../hand_detector/hand_detector_graph.cc | 16 +-- .../hand_detector/hand_detector_graph_test.cc | 4 +- .../hand_landmarker/hand_landmarker_graph.cc | 2 +- 16 files changed, 114 insertions(+), 113 deletions(-) rename mediapipe/tasks/cc/vision/{hand_gesture_recognizer => gesture_recognizer}/BUILD (83%) rename mediapipe/tasks/cc/vision/{hand_gesture_recognizer => gesture_recognizer}/calculators/BUILD (84%) rename mediapipe/tasks/cc/vision/{hand_gesture_recognizer => gesture_recognizer}/calculators/handedness_to_matrix_calculator.cc (90%) rename mediapipe/tasks/cc/vision/{hand_gesture_recognizer => gesture_recognizer}/calculators/handedness_to_matrix_calculator_test.cc (97%) rename mediapipe/tasks/cc/vision/{hand_gesture_recognizer => gesture_recognizer}/calculators/landmarks_to_matrix_calculator.cc (96%) rename mediapipe/tasks/cc/vision/{hand_gesture_recognizer/proto => gesture_recognizer/calculators}/landmarks_to_matrix_calculator.proto (97%) rename mediapipe/tasks/cc/vision/{hand_gesture_recognizer => gesture_recognizer}/calculators/landmarks_to_matrix_calculator_test.cc (96%) rename mediapipe/tasks/cc/vision/{hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc => gesture_recognizer/hand_gesture_recognizer_graph.cc} (80%) rename mediapipe/tasks/cc/vision/{hand_gesture_recognizer => gesture_recognizer}/handedness_util.cc (93%) rename mediapipe/tasks/cc/vision/{hand_gesture_recognizer => gesture_recognizer}/handedness_util.h (79%) rename mediapipe/tasks/cc/vision/{hand_gesture_recognizer => gesture_recognizer}/handedness_util_test.cc (94%) rename mediapipe/tasks/cc/vision/{hand_gesture_recognizer => gesture_recognizer}/proto/BUILD (73%) rename mediapipe/tasks/cc/vision/{hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto => gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto} (89%) diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD similarity index 83% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD rename to mediapipe/tasks/cc/vision/gesture_recognizer/BUILD index 9e2d9bd17..cb392873e 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD @@ -41,8 +41,8 @@ cc_test( ) cc_library( - name = "hand_gesture_recognizer_subgraph", - srcs = ["hand_gesture_recognizer_subgraph.cc"], + name = "hand_gesture_recognizer_graph", + srcs = ["hand_gesture_recognizer_graph.cc"], deps = [ "//mediapipe/calculators/core:concatenate_vector_calculator", "//mediapipe/calculators/tensor:tensor_converter_calculator", @@ -62,10 +62,10 @@ cc_library( "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators:handedness_to_matrix_calculator", - 
"//mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators:landmarks_to_matrix_calculator", - "//mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto:hand_gesture_recognizer_subgraph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto:landmarks_to_matrix_calculator_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:handedness_to_matrix_calculator", + "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator", + "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarker_subgraph", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", "//mediapipe/tasks/metadata:metadata_schema_cc", diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD similarity index 84% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/BUILD rename to mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD index 4863c8682..a6de4f950 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD @@ -12,11 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + package(default_visibility = [ "//mediapipe/app/xeno:__subpackages__", "//mediapipe/tasks:internal", ]) +mediapipe_proto_library( + name = "landmarks_to_matrix_calculator_proto", + srcs = ["landmarks_to_matrix_calculator.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + ], +) + cc_library( name = "handedness_to_matrix_calculator", srcs = ["handedness_to_matrix_calculator.cc"], @@ -25,7 +37,7 @@ cc_library( "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:ret_check", - "//mediapipe/tasks/cc/vision/hand_gesture_recognizer:handedness_util", + "//mediapipe/tasks/cc/vision/gesture_recognizer:handedness_util", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -53,11 +65,11 @@ cc_library( name = "landmarks_to_matrix_calculator", srcs = ["landmarks_to_matrix_calculator.cc"], deps = [ + ":landmarks_to_matrix_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:ret_check", - "//mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto:landmarks_to_matrix_calculator_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/handedness_to_matrix_calculator.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator.cc similarity index 90% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/handedness_to_matrix_calculator.cc rename to mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator.cc index 746293d21..b6c973a1b 100644 --- 
a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/handedness_to_matrix_calculator.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator.cc @@ -26,14 +26,16 @@ limitations under the License. #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/port/ret_check.h" -#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h" +// TODO Update to use API2 namespace mediapipe { -namespace tasks { -namespace vision { +namespace api2 { namespace { +using ::mediapipe::tasks::vision::gesture_recognizer::GetLeftHandScore; + constexpr char kHandednessTag[] = "HANDEDNESS"; constexpr char kHandednessMatrixTag[] = "HANDEDNESS_MATRIX"; @@ -71,6 +73,8 @@ class HandednessToMatrixCalculator : public CalculatorBase { return absl::OkStatus(); } + // TODO remove this after change to API2, because Setting offset + // to 0 is the default in API2 absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); return absl::OkStatus(); @@ -95,6 +99,5 @@ absl::Status HandednessToMatrixCalculator::Process(CalculatorContext* cc) { return absl::OkStatus(); } -} // namespace vision -} // namespace tasks +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/handedness_to_matrix_calculator_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator_test.cc similarity index 97% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/handedness_to_matrix_calculator_test.cc rename to mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator_test.cc index c93c48ac5..17b16bf80 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/handedness_to_matrix_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator_test.cc @@ -28,8 +28,6 @@ limitations under the License. #include "mediapipe/framework/port/status_matchers.h" namespace mediapipe { -namespace tasks { -namespace vision { namespace { @@ -95,6 +93,4 @@ INSTANTIATE_TEST_CASE_P( } // namespace -} // namespace vision -} // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc similarity index 96% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc rename to mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc index 990e99920..b70689eaf 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc @@ -27,13 +27,11 @@ limitations under the License. 
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/port/ret_check.h" -#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h" +// TODO Update to use API2 namespace mediapipe { -namespace tasks { -namespace vision { - -using proto::LandmarksToMatrixCalculatorOptions; +namespace api2 { namespace { @@ -175,7 +173,7 @@ absl::Status ProcessLandmarks(LandmarkListT landmarks, CalculatorContext* cc) { // input_stream: "IMAGE_SIZE:image_size" // output_stream: "LANDMARKS_MATRIX:landmarks_matrix" // options { -// [mediapipe.tasks.vision.proto.LandmarksToMatrixCalculatorOptions.ext] { +// [mediapipe.LandmarksToMatrixCalculatorOptions.ext] { // object_normalization: true // object_normalization_origin_offset: 0 // } @@ -221,6 +219,5 @@ absl::Status LandmarksToMatrixCalculator::Process(CalculatorContext* cc) { return absl::OkStatus(); } -} // namespace vision -} // namespace tasks +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.proto similarity index 97% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.proto rename to mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.proto index 6b004e203..10b034447 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.proto @@ -15,7 +15,7 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks.vision.proto; +package mediapipe; import "mediapipe/framework/calculator.proto"; diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc similarity index 96% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc rename to mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc index 05d238f66..8a68d8dae 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc @@ -28,8 +28,6 @@ limitations under the License. 
#include "mediapipe/framework/port/status_matchers.h" namespace mediapipe { -namespace tasks { -namespace vision { namespace { @@ -72,8 +70,7 @@ TEST_P(Landmarks2dToMatrixCalculatorTest, OutputsCorrectResult) { input_stream: "IMAGE_SIZE:image_size" output_stream: "LANDMARKS_MATRIX:landmarks_matrix" options { - [mediapipe.tasks.vision.proto.LandmarksToMatrixCalculatorOptions - .ext] { + [mediapipe.LandmarksToMatrixCalculatorOptions.ext] { object_normalization: $0 object_normalization_origin_offset: $1 } @@ -145,8 +142,7 @@ TEST_P(LandmarksWorld3dToMatrixCalculatorTest, OutputsCorrectResult) { input_stream: "IMAGE_SIZE:image_size" output_stream: "LANDMARKS_MATRIX:landmarks_matrix" options { - [mediapipe.tasks.vision.proto.LandmarksToMatrixCalculatorOptions - .ext] { + [mediapipe.LandmarksToMatrixCalculatorOptions.ext] { object_normalization: $0 object_normalization_origin_offset: $1 } @@ -202,6 +198,4 @@ INSTANTIATE_TEST_CASE_P( } // namespace -} // namespace vision -} // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc similarity index 80% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc rename to mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc index 247d8453d..05bc607ae 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc @@ -34,14 +34,15 @@ limitations under the License. #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/utils.h" -#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.pb.h" -#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" namespace mediapipe { namespace tasks { namespace vision { +namespace gesture_recognizer { namespace { @@ -50,9 +51,8 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; -using ::mediapipe::tasks::vision::hand_gesture_recognizer::proto:: - HandGestureRecognizerSubgraphOptions; -using ::mediapipe::tasks::vision::proto::LandmarksToMatrixCalculatorOptions; +using ::mediapipe::tasks::vision::gesture_recognizer::proto:: + HandGestureRecognizerGraphOptions; constexpr char kHandednessTag[] = "HANDEDNESS"; constexpr char kLandmarksTag[] = "LANDMARKS"; @@ -70,18 +70,6 @@ constexpr char kIndexTag[] = "INDEX"; constexpr char kIterableTag[] = "ITERABLE"; constexpr char kBatchEndTag[] = "BATCH_END"; -absl::Status SanityCheckOptions( - const HandGestureRecognizerSubgraphOptions& options) { - if (options.min_tracking_confidence() < 0 || - options.min_tracking_confidence() > 1) { - return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, - "Invalid `min_tracking_confidence` option: " - "value must be in the 
range [0.0, 1.0]", - MediaPipeTasksStatus::kInvalidArgumentError); - } - return absl::OkStatus(); -} - Source> ConvertMatrixToTensor(Source matrix, Graph& graph) { auto& node = graph.AddNode("TensorConverterCalculator"); @@ -91,9 +79,10 @@ Source> ConvertMatrixToTensor(Source matrix, } // namespace -// A "mediapipe.tasks.vision.SingleHandGestureRecognizerSubgraph" performs -// single hand gesture recognition. This graph is used as a building block for -// mediapipe.tasks.vision.HandGestureRecognizerGraph. +// A +// "mediapipe.tasks.vision.gesture_recognizer.SingleHandGestureRecognizerGraph" +// performs single hand gesture recognition. This graph is used as a building +// block for mediapipe.tasks.vision.GestureRecognizerGraph. // // Inputs: // HANDEDNESS - ClassificationList @@ -113,14 +102,15 @@ Source> ConvertMatrixToTensor(Source matrix, // // Example: // node { -// calculator: "mediapipe.tasks.vision.SingleHandGestureRecognizerSubgraph" +// calculator: +// "mediapipe.tasks.vision.gesture_recognizer.SingleHandGestureRecognizerGraph" // input_stream: "HANDEDNESS:handedness" // input_stream: "LANDMARKS:landmarks" // input_stream: "WORLD_LANDMARKS:world_landmarks" // input_stream: "IMAGE_SIZE:image_size" // output_stream: "HAND_GESTURES:hand_gestures" // options { -// [mediapipe.tasks.vision.hand_gesture_recognizer.proto.HandGestureRecognizerSubgraphOptions.ext] +// [mediapipe.tasks.vision.gesture_recognizer.proto.HandGestureRecognizerGraphOptions.ext] // { // base_options { // model_asset { @@ -130,19 +120,19 @@ Source> ConvertMatrixToTensor(Source matrix, // } // } // } -class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph { +class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { ASSIGN_OR_RETURN( const auto* model_resources, - CreateModelResources(sc)); + CreateModelResources(sc)); Graph graph; ASSIGN_OR_RETURN( auto hand_gestures, - BuildHandGestureRecognizerGraph( - sc->Options(), - *model_resources, graph[Input(kHandednessTag)], + BuildGestureRecognizerGraph( + sc->Options(), *model_resources, + graph[Input(kHandednessTag)], graph[Input(kLandmarksTag)], graph[Input(kWorldLandmarksTag)], graph[Input>(kImageSizeTag)], graph)); @@ -151,15 +141,13 @@ class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph { } private: - absl::StatusOr> BuildHandGestureRecognizerGraph( - const HandGestureRecognizerSubgraphOptions& graph_options, + absl::StatusOr> BuildGestureRecognizerGraph( + const HandGestureRecognizerGraphOptions& graph_options, const core::ModelResources& model_resources, Source handedness, Source hand_landmarks, Source hand_world_landmarks, Source> image_size, Graph& graph) { - MP_RETURN_IF_ERROR(SanityCheckOptions(graph_options)); - // Converts the ClassificationList to a matrix. auto& handedness_to_matrix = graph.AddNode("HandednessToMatrixCalculator"); handedness >> handedness_to_matrix.In(kHandednessTag); @@ -235,12 +223,15 @@ class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph { } }; +// clang-format off REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::vision::SingleHandGestureRecognizerSubgraph); + ::mediapipe::tasks::vision::gesture_recognizer::SingleHandGestureRecognizerGraph); // NOLINT +// clang-format on -// A "mediapipe.tasks.vision.HandGestureRecognizerSubgraph" performs multi -// hand gesture recognition. This graph is used as a building block for -// mediapipe.tasks.vision.HandGestureRecognizerGraph. 
+// A +// "mediapipe.tasks.vision.gesture_recognizer.MultipleHandGestureRecognizerGraph" +// performs multi hand gesture recognition. This graph is used as a building +// block for mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph. // // Inputs: // HANDEDNESS - std::vector @@ -263,7 +254,8 @@ REGISTER_MEDIAPIPE_GRAPH( // // Example: // node { -// calculator: "mediapipe.tasks.vision.HandGestureRecognizerSubgraph" +// calculator: +// "mediapipe.tasks.vision.gesture_recognizer.MultipleHandGestureRecognizerGraph" // input_stream: "HANDEDNESS:handedness" // input_stream: "LANDMARKS:landmarks" // input_stream: "WORLD_LANDMARKS:world_landmarks" @@ -271,7 +263,7 @@ REGISTER_MEDIAPIPE_GRAPH( // input_stream: "HAND_TRACKING_IDS:hand_tracking_ids" // output_stream: "HAND_GESTURES:hand_gestures" // options { -// [mediapipe.tasks.vision.hand_gesture_recognizer.proto.HandGestureRecognizerSubgraph.ext] +// [mediapipe.tasks.vision.gesture_recognizer.proto.MultipleHandGestureRecognizerGraph.ext] // { // base_options { // model_asset { @@ -281,15 +273,15 @@ REGISTER_MEDIAPIPE_GRAPH( // } // } // } -class HandGestureRecognizerSubgraph : public core::ModelTaskGraph { +class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { Graph graph; ASSIGN_OR_RETURN( auto multi_hand_gestures, - BuildMultiHandGestureRecognizerSubraph( - sc->Options(), + BuildMultiGestureRecognizerSubraph( + sc->Options(), graph[Input>(kHandednessTag)], graph[Input>(kLandmarksTag)], graph[Input>(kWorldLandmarksTag)], @@ -302,8 +294,8 @@ class HandGestureRecognizerSubgraph : public core::ModelTaskGraph { private: absl::StatusOr>> - BuildMultiHandGestureRecognizerSubraph( - const HandGestureRecognizerSubgraphOptions& graph_options, + BuildMultiGestureRecognizerSubraph( + const HandGestureRecognizerGraphOptions& graph_options, Source> multi_handedness, Source> multi_hand_landmarks, Source> multi_hand_world_landmarks, @@ -341,17 +333,18 @@ class HandGestureRecognizerSubgraph : public core::ModelTaskGraph { hand_tracking_id >> get_world_landmarks_at_index.In(kIndexTag); auto hand_world_landmarks = get_world_landmarks_at_index.Out(kItemTag); - auto& hand_gesture_recognizer_subgraph = graph.AddNode( - "mediapipe.tasks.vision.SingleHandGestureRecognizerSubgraph"); - hand_gesture_recognizer_subgraph - .GetOptions() + auto& hand_gesture_recognizer_graph = graph.AddNode( + "mediapipe.tasks.vision.gesture_recognizer." 
+ "SingleHandGestureRecognizerGraph"); + hand_gesture_recognizer_graph + .GetOptions() .CopyFrom(graph_options); - handedness >> hand_gesture_recognizer_subgraph.In(kHandednessTag); - hand_landmarks >> hand_gesture_recognizer_subgraph.In(kLandmarksTag); + handedness >> hand_gesture_recognizer_graph.In(kHandednessTag); + hand_landmarks >> hand_gesture_recognizer_graph.In(kLandmarksTag); hand_world_landmarks >> - hand_gesture_recognizer_subgraph.In(kWorldLandmarksTag); - image_size_clone >> hand_gesture_recognizer_subgraph.In(kImageSizeTag); - auto hand_gestures = hand_gesture_recognizer_subgraph.Out(kHandGesturesTag); + hand_gesture_recognizer_graph.In(kWorldLandmarksTag); + image_size_clone >> hand_gesture_recognizer_graph.In(kImageSizeTag); + auto hand_gestures = hand_gesture_recognizer_graph.Out(kHandGesturesTag); auto& end_loop_classification_results = graph.AddNode("mediapipe.tasks.EndLoopClassificationResultCalculator"); @@ -364,9 +357,12 @@ class HandGestureRecognizerSubgraph : public core::ModelTaskGraph { } }; +// clang-format off REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::vision::HandGestureRecognizerSubgraph); + ::mediapipe::tasks::vision::gesture_recognizer::MultipleHandGestureRecognizerGraph); // NOLINT +// clang-format on +} // namespace gesture_recognizer } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.cc similarity index 93% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.cc rename to mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.cc index 00e19cdb5..60ccae92c 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h" #include @@ -25,6 +25,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace vision { +namespace gesture_recognizer { namespace {} // namespace @@ -58,6 +59,7 @@ absl::StatusOr GetLeftHandScore( } } +} // namespace gesture_recognizer } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h similarity index 79% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h rename to mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h index 74e04b8cc..ae4137d0f 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_ -#define MEDIAPIPE_TASKS_CC_VISION_HAND_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_ +#ifndef MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_ +#define MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_ #include "absl/status/statusor.h" #include "mediapipe/framework/formats/classification.pb.h" @@ -22,6 +22,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace vision { +namespace gesture_recognizer { bool IsLeftHand(const mediapipe::Classification& c); @@ -30,8 +31,9 @@ bool IsRightHand(const mediapipe::Classification& c); absl::StatusOr GetLeftHandScore( const mediapipe::ClassificationList& classification_list); +} // namespace gesture_recognizer } // namespace vision } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_ +#endif // MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_ diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util_test.cc similarity index 94% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util_test.cc rename to mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util_test.cc index 51dfb5dea..40a201ae8 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util_test.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h" #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/port/gmock.h" @@ -23,6 +23,7 @@ limitations under the License. 
namespace mediapipe { namespace tasks { namespace vision { +namespace gesture_recognizer { namespace { TEST(GetLeftHandScore, SingleLeftHandClassification) { @@ -72,6 +73,7 @@ TEST(GetLeftHandScore, LeftAndRightLowerCaseHandClassification) { } } // namespace +} // namespace gesture_recognizer } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD similarity index 73% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD rename to mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD index 44ec611b2..cb6ec8289 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD @@ -21,8 +21,8 @@ package(default_visibility = [ licenses(["notice"]) mediapipe_proto_library( - name = "hand_gesture_recognizer_subgraph_options_proto", - srcs = ["hand_gesture_recognizer_subgraph_options.proto"], + name = "hand_gesture_recognizer_graph_options_proto", + srcs = ["hand_gesture_recognizer_graph_options.proto"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -30,12 +30,3 @@ mediapipe_proto_library( "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) - -mediapipe_proto_library( - name = "landmarks_to_matrix_calculator_proto", - srcs = ["landmarks_to_matrix_calculator.proto"], - deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", - ], -) diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto similarity index 89% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto rename to mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto index d8ee95037..ac8cda15c 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto @@ -15,15 +15,15 @@ limitations under the License. // TODO Refactor naming and class structure of hand related Tasks. syntax = "proto2"; -package mediapipe.tasks.vision.hand_gesture_recognizer.proto; +package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; -message HandGestureRecognizerSubgraphOptions { +message HandGestureRecognizerGraphOptions { extend mediapipe.CalculatorOptions { - optional HandGestureRecognizerSubgraphOptions ext = 463370452; + optional HandGestureRecognizerGraphOptions ext = 463370452; } // Base options for configuring hand gesture recognition subgraph, such as // specifying the TfLite model file with metadata, accelerator options, etc. diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc index 7ef8d62f5..8573d718f 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc @@ -46,6 +46,7 @@ limitations under the License. 
namespace mediapipe { namespace tasks { namespace vision { +namespace hand_detector { namespace { @@ -139,9 +140,9 @@ void ConfigureRectTransformationCalculator( } // namespace -// A "mediapipe.tasks.vision.HandDetectorGraph" performs hand detection. The -// Hand Detection Graph is based on palm detection model, and scale the detected -// palm bounding box to enclose the detected whole hand. +// A "mediapipe.tasks.vision.hand_detector.HandDetectorGraph" performs hand +// detection. The Hand Detection Graph is based on palm detection model, and +// scale the detected palm bounding box to enclose the detected whole hand. // Accepts CPU input images and outputs Landmark on CPU. // // Inputs: @@ -161,14 +162,15 @@ void ConfigureRectTransformationCalculator( // // Example: // node { -// calculator: "mediapipe.tasks.vision.HandDetectorGraph" +// calculator: "mediapipe.tasks.vision.hand_detector.HandDetectorGraph" // input_stream: "IMAGE:image" // output_stream: "PALM_DETECTIONS:palm_detections" // output_stream: "HAND_RECTS:hand_rects_from_palm_detections" // output_stream: "PALM_RECTS:palm_rects" // output_stream: "IMAGE:image_out" // options { -// [mediapipe.tasks.hand_detector.proto.HandDetectorGraphOptions.ext] { +// [mediapipe.tasks.vision.hand_detector.proto.HandDetectorGraphOptions.ext] +// { // base_options { // model_asset { // file_name: "palm_detection.tflite" @@ -334,8 +336,10 @@ class HandDetectorGraph : public core::ModelTaskGraph { } }; -REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::HandDetectorGraph); +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::hand_detector::HandDetectorGraph); +} // namespace hand_detector } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc index 850ff2732..11cfc3026 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc @@ -47,6 +47,7 @@ limitations under the License. 
namespace mediapipe { namespace tasks { namespace vision { +namespace hand_detector { namespace { using ::file::Defaults; @@ -105,7 +106,7 @@ absl::StatusOr> CreateTaskRunner( Graph graph; auto& hand_detection = - graph.AddNode("mediapipe.tasks.vision.HandDetectorGraph"); + graph.AddNode("mediapipe.tasks.vision.hand_detector.HandDetectorGraph"); auto options = std::make_unique(); options->mutable_base_options()->mutable_model_asset()->set_file_name( @@ -201,6 +202,7 @@ INSTANTIATE_TEST_SUITE_P( }); } // namespace +} // namespace hand_detector } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc index 6041d528f..ab3403d53 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc @@ -216,7 +216,7 @@ class HandLandmarkerGraph : public core::ModelTaskGraph { DisallowIf(image_in, has_enough_hands, graph); auto& hand_detector = - graph.AddNode("mediapipe.tasks.vision.HandDetectorGraph"); + graph.AddNode("mediapipe.tasks.vision.hand_detector.HandDetectorGraph"); hand_detector.GetOptions().CopyFrom( tasks_options.hand_detector_graph_options()); image_for_hand_detector >> hand_detector.In("IMAGE"); From 25e424baaf2a0399584fee2a00d1468c1b479154 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 4 Oct 2022 00:22:35 -0700 Subject: [PATCH 021/132] Rename hand landmarker related graphs. PiperOrigin-RevId: 478706652 --- .../tasks/cc/vision/gesture_recognizer/BUILD | 2 +- .../tasks/cc/vision/hand_landmarker/BUILD | 10 +-- .../hand_landmarker/hand_landmarker_graph.cc | 44 ++++++------ .../hand_landmarker_graph_test.cc | 4 +- ...ph.cc => hand_landmarks_detector_graph.cc} | 71 +++++++++++-------- ... 
=> hand_landmarks_detector_graph_test.cc} | 18 ++--- .../cc/vision/hand_landmarker/proto/BUILD | 6 +- .../proto/hand_landmarker_graph_options.proto | 5 +- ...nd_landmarks_detector_graph_options.proto} | 4 +- 9 files changed, 90 insertions(+), 74 deletions(-) rename mediapipe/tasks/cc/vision/hand_landmarker/{hand_landmarker_subgraph.cc => hand_landmarks_detector_graph.cc} (91%) rename mediapipe/tasks/cc/vision/hand_landmarker/{hand_landmarker_subgraph_test.cc => hand_landmarks_detector_graph_test.cc} (97%) rename mediapipe/tasks/cc/vision/hand_landmarker/proto/{hand_landmarker_subgraph_options.proto => hand_landmarks_detector_graph_options.proto} (92%) diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD index cb392873e..c9319e946 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD @@ -66,7 +66,7 @@ cc_library( "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator", "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarker_subgraph", + "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", "//mediapipe/tasks/metadata:metadata_schema_cc", "@com_google_absl//absl/status", diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index c968c17fa..a2bb458db 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -19,10 +19,10 @@ package(default_visibility = [ licenses(["notice"]) cc_library( - name = "hand_landmarker_subgraph", - srcs = ["hand_landmarker_subgraph.cc"], + name = "hand_landmarks_detector_graph", + srcs = ["hand_landmarks_detector_graph.cc"], deps = [ - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_subgraph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "//mediapipe/calculators/core:split_vector_calculator", @@ -72,7 +72,7 @@ cc_library( name = "hand_landmarker_graph", srcs = ["hand_landmarker_graph.cc"], deps = [ - ":hand_landmarker_subgraph", + ":hand_landmarks_detector_graph", "//mediapipe/calculators/core:begin_loop_calculator", "//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto", "//mediapipe/calculators/core:end_loop_calculator", @@ -99,7 +99,7 @@ cc_library( "//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator", "//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator_cc_proto", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_subgraph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", ], alwayslink = 1, ) diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc index ab3403d53..949c06520 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc +++ 
b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc @@ -36,7 +36,7 @@ limitations under the License. #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.pb.h" #include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h" -#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" namespace mediapipe { namespace tasks { @@ -55,7 +55,7 @@ using ::mediapipe::tasks::vision::hand_detector::proto:: using ::mediapipe::tasks::vision::hand_landmarker::proto:: HandLandmarkerGraphOptions; using ::mediapipe::tasks::vision::hand_landmarker::proto:: - HandLandmarkerSubgraphOptions; + HandLandmarksDetectorGraphOptions; constexpr char kImageTag[] = "IMAGE"; constexpr char kLandmarksTag[] = "LANDMARKS"; @@ -78,14 +78,14 @@ struct HandLandmarkerOutputs { } // namespace -// A "mediapipe.tasks.vision.HandLandmarkerGraph" performs hand +// A "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph" performs hand // landmarks detection. The HandLandmarkerGraph consists of two subgraphs: -// HandDetectorGraph and HandLandmarkerSubgraph. HandLandmarkerSubgraph detects -// landmarks from bounding boxes produced by HandDetectorGraph. -// HandLandmarkerGraph tracks the landmarks over time, and skips the -// HandDetectorGraph. If the tracking is lost or the detectd hands are -// less than configured max number hands, HandDetectorGraph would be triggered -// to detect hands. +// HandDetectorGraph and MultipleHandLandmarksDetectorGraph. +// MultipleHandLandmarksDetectorGraph detects landmarks from bounding boxes +// produced by HandDetectorGraph. HandLandmarkerGraph tracks the landmarks over +// time, and skips the HandDetectorGraph. If the tracking is lost or the detectd +// hands are less than configured max number hands, HandDetectorGraph would be +// triggered to detect hands. // // Accepts CPU input images and outputs Landmarks on CPU. // @@ -113,7 +113,7 @@ struct HandLandmarkerOutputs { // // Example: // node { -// calculator: "mediapipe.tasks.vision.HandLandmarkerGraph" +// calculator: "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph" // input_stream: "IMAGE:image_in" // output_stream: "LANDMARKS:hand_landmarks" // output_stream: "WORLD_LANDMARKS:world_hand_landmarks" @@ -138,7 +138,7 @@ struct HandLandmarkerOutputs { // min_detection_confidence: 0.5 // num_hands: 2 // } -// hand_landmarker_subgraph_options { +// hand_landmarks_detector_graph_options { // base_options { // model_asset { // file_name: "hand_landmark_lite.tflite" @@ -238,15 +238,17 @@ class HandLandmarkerGraph : public core::ModelTaskGraph { hand_rects >> clip_hand_rects.In(""); auto clipped_hand_rects = clip_hand_rects.Out(""); - auto& hand_landmarker_subgraph = graph.AddNode( - "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerSubgraph"); - hand_landmarker_subgraph.GetOptions() - .CopyFrom(tasks_options.hand_landmarker_subgraph_options()); - image_in >> hand_landmarker_subgraph.In("IMAGE"); - clipped_hand_rects >> hand_landmarker_subgraph.In("HAND_RECT"); + auto& hand_landmarks_detector_graph = graph.AddNode( + "mediapipe.tasks.vision.hand_landmarker." 
+ "MultipleHandLandmarksDetectorGraph"); + hand_landmarks_detector_graph + .GetOptions() + .CopyFrom(tasks_options.hand_landmarks_detector_graph_options()); + image_in >> hand_landmarks_detector_graph.In("IMAGE"); + clipped_hand_rects >> hand_landmarks_detector_graph.In("HAND_RECT"); auto hand_rects_for_next_frame = - hand_landmarker_subgraph[Output>( + hand_landmarks_detector_graph[Output>( kHandRectNextFrameTag)]; // Back edge. hand_rects_for_next_frame >> previous_loopback.In("LOOP"); @@ -257,13 +259,13 @@ class HandLandmarkerGraph : public core::ModelTaskGraph { image_in >> pass_through.In(""); return {{ - /* landmark_lists= */ hand_landmarker_subgraph + /* landmark_lists= */ hand_landmarks_detector_graph [Output>(kLandmarksTag)], /* world_landmark_lists= */ - hand_landmarker_subgraph[Output>( + hand_landmarks_detector_graph[Output>( kWorldLandmarksTag)], /* hand_rects_next_frame= */ hand_rects_for_next_frame, - hand_landmarker_subgraph[Output>( + hand_landmarks_detector_graph[Output>( kHandednessTag)], /* palm_rects= */ hand_detector[Output>(kPalmRectsTag)], diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc index 413af68ff..bce5613ff 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc @@ -38,7 +38,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/task_runner.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h" -#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" @@ -111,7 +111,7 @@ absl::StatusOr> CreateTaskRunner() { ->set_file_name(JoinPath("./", kTestDataDirectory, kPalmDetectionModel)); options.mutable_hand_detector_graph_options()->mutable_base_options(); options.mutable_hand_detector_graph_options()->set_num_hands(kMaxNumHands); - options.mutable_hand_landmarker_subgraph_options() + options.mutable_hand_landmarks_detector_graph_options() ->mutable_base_options() ->mutable_model_asset() ->set_file_name( diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc similarity index 91% rename from mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph.cc rename to mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc index 0ac4686b7..23521790d 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc @@ -40,7 +40,7 @@ limitations under the License. 
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" -#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/util/label_map.pb.h" @@ -60,7 +60,7 @@ using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::components::utils::AllowIf; using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::vision::hand_landmarker::proto:: - HandLandmarkerSubgraphOptions; + HandLandmarksDetectorGraphOptions; using LabelItems = mediapipe::proto_ns::Map; constexpr char kImageTag[] = "IMAGE"; @@ -96,7 +96,8 @@ struct HandLandmarkerOutputs { Source> handednesses; }; -absl::Status SanityCheckOptions(const HandLandmarkerSubgraphOptions& options) { +absl::Status SanityCheckOptions( + const HandLandmarksDetectorGraphOptions& options) { if (options.min_detection_confidence() < 0 || options.min_detection_confidence() > 1) { return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, @@ -183,8 +184,8 @@ void ConfigureHandRectTransformationCalculator( } // namespace -// A "mediapipe.tasks.vision.SingleHandLandmarkerSubgraph" performs hand -// landmark detection. +// A "mediapipe.tasks.vision.hand_landmarker.SingleHandLandmarksDetectorGraph" +// performs hand landmarks detection. // - Accepts CPU input images and outputs Landmark on CPU. // // Inputs: @@ -212,7 +213,8 @@ void ConfigureHandRectTransformationCalculator( // // Example: // node { -// calculator: "mediapipe.tasks.vision.SingleHandLandmarkerSubgraph" +// calculator: +// "mediapipe.tasks.vision.hand_landmarker.SingleHandLandmarksDetectorGraph" // input_stream: "IMAGE:input_image" // input_stream: "HAND_RECT:hand_rect" // output_stream: "LANDMARKS:hand_landmarks" @@ -221,7 +223,7 @@ void ConfigureHandRectTransformationCalculator( // output_stream: "PRESENCE:hand_presence" // output_stream: "PRESENCE_SCORE:hand_presence_score" // options { -// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarkerSubgraphOptions.ext] +// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarksDetectorGraphOptions.ext] // { // base_options { // model_asset { @@ -232,16 +234,17 @@ void ConfigureHandRectTransformationCalculator( // } // } // } -class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { +class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { - ASSIGN_OR_RETURN(const auto* model_resources, - CreateModelResources(sc)); + ASSIGN_OR_RETURN( + const auto* model_resources, + CreateModelResources(sc)); Graph graph; ASSIGN_OR_RETURN(auto hand_landmark_detection_outs, - BuildSingleHandLandmarkerSubgraph( - sc->Options(), + BuildSingleHandLandmarksDetectorGraph( + sc->Options(), *model_resources, graph[Input(kImageTag)], graph[Input(kHandRectTag)], graph)); hand_landmark_detection_outs.hand_landmarks >> @@ -264,14 +267,16 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { // Adds a mediapipe hand landmark detection graph into the provided // builder::Graph instance. // - // subgraph_options: the mediapipe tasks module HandLandmarkerSubgraphOptions. 
- // model_resources: the ModelSources object initialized from a hand landmark + // subgraph_options: the mediapipe tasks module + // HandLandmarksDetectorGraphOptions. model_resources: the ModelSources object + // initialized from a hand landmark // detection model file with model metadata. // image_in: (mediapipe::Image) stream to run hand landmark detection on. // rect: (NormalizedRect) stream to run on the RoI of image. // graph: the mediapipe graph instance to be updated. - absl::StatusOr BuildSingleHandLandmarkerSubgraph( - const HandLandmarkerSubgraphOptions& subgraph_options, + absl::StatusOr + BuildSingleHandLandmarksDetectorGraph( + const HandLandmarksDetectorGraphOptions& subgraph_options, const core::ModelResources& model_resources, Source image_in, Source hand_rect, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options)); @@ -415,11 +420,13 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { } }; +// clang-format off REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::vision::hand_landmarker::SingleHandLandmarkerSubgraph); + ::mediapipe::tasks::vision::hand_landmarker::SingleHandLandmarksDetectorGraph); // NOLINT +// clang-format on -// A "mediapipe.tasks.vision.HandLandmarkerSubgraph" performs multi hand -// landmark detection. +// A "mediapipe.tasks.vision.hand_landmarker.MultipleHandLandmarksDetectorGraph" +// performs multi hand landmark detection. // - Accepts CPU input image and a vector of hand rect RoIs to detect the // multiple hands landmarks enclosed by the RoIs. Output vectors of // hand landmarks related results, where each element in the vectors @@ -450,7 +457,8 @@ REGISTER_MEDIAPIPE_GRAPH( // // Example: // node { -// calculator: "mediapipe.tasks.vision.HandLandmarkerSubgraph" +// calculator: +// "mediapipe.tasks.vision.hand_landmarker.MultipleHandLandmarksDetectorGraph" // input_stream: "IMAGE:input_image" // input_stream: "HAND_RECT:hand_rect" // output_stream: "LANDMARKS:hand_landmarks" @@ -460,7 +468,7 @@ REGISTER_MEDIAPIPE_GRAPH( // output_stream: "PRESENCE_SCORE:hand_presence_score" // output_stream: "HANDEDNESS:handedness" // options { -// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarkerSubgraphOptions.ext] +// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarksDetectorGraphOptions.ext] // { // base_options { // model_asset { @@ -471,15 +479,15 @@ REGISTER_MEDIAPIPE_GRAPH( // } // } // } -class HandLandmarkerSubgraph : public core::ModelTaskGraph { +class MultipleHandLandmarksDetectorGraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { Graph graph; ASSIGN_OR_RETURN( auto hand_landmark_detection_outputs, - BuildHandLandmarkerSubgraph( - sc->Options(), + BuildHandLandmarksDetectorGraph( + sc->Options(), graph[Input(kImageTag)], graph[Input>(kHandRectTag)], graph)); hand_landmark_detection_outputs.landmark_lists >> @@ -499,14 +507,15 @@ class HandLandmarkerSubgraph : public core::ModelTaskGraph { } private: - absl::StatusOr BuildHandLandmarkerSubgraph( - const HandLandmarkerSubgraphOptions& subgraph_options, + absl::StatusOr BuildHandLandmarksDetectorGraph( + const HandLandmarksDetectorGraphOptions& subgraph_options, Source image_in, Source> multi_hand_rects, Graph& graph) { auto& hand_landmark_subgraph = graph.AddNode( - "mediapipe.tasks.vision.hand_landmarker.SingleHandLandmarkerSubgraph"); - hand_landmark_subgraph.GetOptions().CopyFrom( - subgraph_options); + "mediapipe.tasks.vision.hand_landmarker." 
+ "SingleHandLandmarksDetectorGraph"); + hand_landmark_subgraph.GetOptions() + .CopyFrom(subgraph_options); auto& begin_loop_multi_hand_rects = graph.AddNode("BeginLoopNormalizedRectCalculator"); @@ -580,8 +589,10 @@ class HandLandmarkerSubgraph : public core::ModelTaskGraph { } }; +// clang-format off REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::vision::hand_landmarker::HandLandmarkerSubgraph); + ::mediapipe::tasks::vision::hand_landmarker::MultipleHandLandmarksDetectorGraph); // NOLINT +// clang-format on } // namespace hand_landmarker } // namespace vision diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc similarity index 97% rename from mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph_test.cc rename to mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc index 7d91dc3c7..d1e928ce7 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc @@ -39,7 +39,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" -#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" namespace mediapipe { @@ -58,7 +58,7 @@ using ::mediapipe::file::JoinPath; using ::mediapipe::tasks::core::TaskRunner; using ::mediapipe::tasks::vision::DecodeImageFromFile; using ::mediapipe::tasks::vision::hand_landmarker::proto:: - HandLandmarkerSubgraphOptions; + HandLandmarksDetectorGraphOptions; using ::testing::ElementsAreArray; using ::testing::EqualsProto; using ::testing::Pointwise; @@ -114,12 +114,13 @@ absl::StatusOr> CreateSingleHandTaskRunner( Graph graph; auto& hand_landmark_detection = graph.AddNode( - "mediapipe.tasks.vision.hand_landmarker.SingleHandLandmarkerSubgraph"); + "mediapipe.tasks.vision.hand_landmarker." + "SingleHandLandmarksDetectorGraph"); - auto options = std::make_unique(); + auto options = std::make_unique(); options->mutable_base_options()->mutable_model_asset()->set_file_name( JoinPath("./", kTestDataDirectory, model_name)); - hand_landmark_detection.GetOptions().Swap( + hand_landmark_detection.GetOptions().Swap( options.get()); graph[Input(kImageTag)].SetName(kImageName) >> @@ -153,12 +154,13 @@ absl::StatusOr> CreateMultiHandTaskRunner( Graph graph; auto& multi_hand_landmark_detection = graph.AddNode( - "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerSubgraph"); + "mediapipe.tasks.vision.hand_landmarker." 
+ "MultipleHandLandmarksDetectorGraph"); - auto options = std::make_unique(); + auto options = std::make_unique(); options->mutable_base_options()->mutable_model_asset()->set_file_name( JoinPath("./", kTestDataDirectory, model_name)); - multi_hand_landmark_detection.GetOptions() + multi_hand_landmark_detection.GetOptions() .Swap(options.get()); graph[Input(kImageTag)].SetName(kImageName) >> diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD index 9d1ba6f90..945b12f3e 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD @@ -21,8 +21,8 @@ package(default_visibility = [ licenses(["notice"]) mediapipe_proto_library( - name = "hand_landmarker_subgraph_options_proto", - srcs = ["hand_landmarker_subgraph_options.proto"], + name = "hand_landmarks_detector_graph_options_proto", + srcs = ["hand_landmarks_detector_graph_options.proto"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -34,7 +34,7 @@ mediapipe_proto_library( name = "hand_landmarker_graph_options_proto", srcs = ["hand_landmarker_graph_options.proto"], deps = [ - ":hand_landmarker_subgraph_options_proto", + ":hand_landmarks_detector_graph_options_proto", "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto index 13849ec5e..7f3536b09 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto @@ -20,7 +20,7 @@ package mediapipe.tasks.vision.hand_landmarker.proto; import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto"; -import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto"; +import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto"; message HandLandmarkerGraphOptions { extend mediapipe.CalculatorOptions { @@ -35,7 +35,8 @@ message HandLandmarkerGraphOptions { hand_detector_graph_options = 2; // Options for hand landmarker subgraph. - optional HandLandmarkerSubgraphOptions hand_landmarker_subgraph_options = 3; + optional HandLandmarksDetectorGraphOptions + hand_landmarks_detector_graph_options = 3; // Minimum confidence for hand landmarks tracking to be considered // successfully. 
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto similarity index 92% rename from mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto rename to mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto index 02d18e8ab..8c0fc66f2 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto @@ -20,9 +20,9 @@ package mediapipe.tasks.vision.hand_landmarker.proto; import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; -message HandLandmarkerSubgraphOptions { +message HandLandmarksDetectorGraphOptions { extend mediapipe.CalculatorOptions { - optional HandLandmarkerSubgraphOptions ext = 474472470; + optional HandLandmarksDetectorGraphOptions ext = 474472470; } // Base options for configuring MediaPipe Tasks, such as specifying the TfLite // model file with metadata, accelerator options, etc. From 14eb6fe62220eb38886c141db72b677b1c281c98 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 4 Oct 2022 02:32:11 -0700 Subject: [PATCH 022/132] Ensure that the REGISTER_DRISHTI_GRAPH argument is fit on one line in the OSS version. PiperOrigin-RevId: 478729958 --- .../processors/classification_postprocessing_graph.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc index 35adab687..cd5933ee6 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc @@ -507,8 +507,11 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { } }; -REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::components::processors:: - ClassificationPostprocessingGraph); // NOLINT +// REGISTER_MEDIAPIPE_GRAPH argument has to fit on one line to work properly. +// clang-format off +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::components::processors::ClassificationPostprocessingGraph); // NOLINT +// clang-format on } // namespace processors } // namespace components From 05209a43923001ca04acee9ded9d4fc7a593f9ee Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 4 Oct 2022 04:36:56 -0700 Subject: [PATCH 023/132] Refactor mediapipe_aar.bzl to expose `mediapipe_java_proto_srcs`, `mediapipe_logging_java_proto_srcs`, and `mediapipe_java_proto_src_extractor`. 
PiperOrigin-RevId: 478750184 --- .../com/google/mediapipe/mediapipe_aar.bzl | 194 +++++++++--------- 1 file changed, 92 insertions(+), 102 deletions(-) diff --git a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl index ed1686954..7f2cb146c 100644 --- a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl +++ b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl @@ -89,10 +89,6 @@ def mediapipe_aar( calculators = calculators, ) - _mediapipe_proto( - name = name + "_proto", - ) - native.genrule( name = name + "_aar_manifest_generator", outs = ["AndroidManifest.xml"], @@ -115,19 +111,10 @@ EOF "//mediapipe/java/com/google/mediapipe/components:java_src", "//mediapipe/java/com/google/mediapipe/framework:java_src", "//mediapipe/java/com/google/mediapipe/glutil:java_src", - "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java", - "com/google/mediapipe/formats/proto/ClassificationProto.java", - "com/google/mediapipe/formats/proto/DetectionProto.java", - "com/google/mediapipe/formats/proto/LandmarkProto.java", - "com/google/mediapipe/formats/proto/LocationDataProto.java", - "com/google/mediapipe/proto/CalculatorProto.java", - ] + + ] + mediapipe_java_proto_srcs() + select({ "//conditions:default": [], - "enable_stats_logging": [ - "com/google/mediapipe/proto/MediaPipeLoggingProto.java", - "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java", - ], + "enable_stats_logging": mediapipe_logging_java_proto_srcs(), }), manifest = "AndroidManifest.xml", proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"], @@ -179,93 +166,6 @@ EOF _aar_with_jni(name, name + "_android_lib") -def _mediapipe_proto(name): - """Generates MediaPipe java proto libraries. - - Args: - name: the name of the target. 
- """ - _proto_java_src_generator( - name = "mediapipe_log_extension_proto", - proto_src = "mediapipe/util/analytics/mediapipe_log_extension.proto", - java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java", - srcs = ["//mediapipe/util/analytics:protos_src"], - ) - - _proto_java_src_generator( - name = "mediapipe_logging_enums_proto", - proto_src = "mediapipe/util/analytics/mediapipe_logging_enums.proto", - java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java", - srcs = ["//mediapipe/util/analytics:protos_src"], - ) - - _proto_java_src_generator( - name = "calculator_proto", - proto_src = "mediapipe/framework/calculator.proto", - java_lite_out = "com/google/mediapipe/proto/CalculatorProto.java", - srcs = ["//mediapipe/framework:protos_src"], - ) - - _proto_java_src_generator( - name = "landmark_proto", - proto_src = "mediapipe/framework/formats/landmark.proto", - java_lite_out = "com/google/mediapipe/formats/proto/LandmarkProto.java", - srcs = ["//mediapipe/framework/formats:protos_src"], - ) - - _proto_java_src_generator( - name = "rasterization_proto", - proto_src = "mediapipe/framework/formats/annotation/rasterization.proto", - java_lite_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java", - srcs = ["//mediapipe/framework/formats/annotation:protos_src"], - ) - - _proto_java_src_generator( - name = "location_data_proto", - proto_src = "mediapipe/framework/formats/location_data.proto", - java_lite_out = "com/google/mediapipe/formats/proto/LocationDataProto.java", - srcs = [ - "//mediapipe/framework/formats:protos_src", - "//mediapipe/framework/formats/annotation:protos_src", - ], - ) - - _proto_java_src_generator( - name = "detection_proto", - proto_src = "mediapipe/framework/formats/detection.proto", - java_lite_out = "com/google/mediapipe/formats/proto/DetectionProto.java", - srcs = [ - "//mediapipe/framework/formats:protos_src", - "//mediapipe/framework/formats/annotation:protos_src", - ], - ) - - _proto_java_src_generator( - name = "classification_proto", - proto_src = "mediapipe/framework/formats/classification.proto", - java_lite_out = "com/google/mediapipe/formats/proto/ClassificationProto.java", - srcs = [ - "//mediapipe/framework/formats:protos_src", - ], - ) - -def _proto_java_src_generator(name, proto_src, java_lite_out, srcs = []): - native.genrule( - name = name + "_proto_java_src_generator", - srcs = srcs + [ - "@com_google_protobuf//:lite_well_known_protos", - ], - outs = [java_lite_out], - cmd = "$(location @com_google_protobuf//:protoc) " + - "--proto_path=. --proto_path=$(GENDIR) " + - "--proto_path=$$(pwd)/external/com_google_protobuf/src " + - "--java_out=lite:$(GENDIR) " + proto_src + " && " + - "mv $(GENDIR)/" + java_lite_out + " $$(dirname $(location " + java_lite_out + "))", - tools = [ - "@com_google_protobuf//:protoc", - ], - ) - def _mediapipe_jni(name, gen_libmediapipe, calculators = []): """Generates MediaPipe jni library. @@ -345,3 +245,93 @@ cp -r lib jni zip -r $$origdir/$(location :{}.aar) jni/*/*.so """.format(android_library, name, name, name, name), ) + +def mediapipe_java_proto_src_extractor(target, src_out, name = ""): + """Extracts the generated MediaPipe java proto source code from the target. + + Args: + target: The java proto lite target to be built and extracted. + src_out: The output java proto src code path. + name: The optional bazel target name. + + Returns: + The output java proto src code path. 
+ """ + + if not name: + name = target.split(":")[-1] + "_proto_java_src_extractor" + src_jar = target.replace("_java_proto_lite", "_proto-lite-src.jar").replace(":", "/").replace("//", "") + native.genrule( + name = name + "_proto_java_src_extractor", + srcs = [target], + outs = [src_out], + cmd = "unzip $(GENDIR)/" + src_jar + " -d $(GENDIR) && mv $(GENDIR)/" + + src_out + " $$(dirname $(location " + src_out + "))", + ) + return src_out + +def mediapipe_java_proto_srcs(name = ""): + """Extracts the generated MediaPipe framework java proto source code. + + Args: + name: The optional bazel target name. + + Returns: + The list of the extrated MediaPipe java proto source code. + """ + + proto_src_list = [] + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework:calculator_java_proto_lite", + src_out = "com/google/mediapipe/proto/CalculatorProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework/formats:landmark_java_proto_lite", + src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework/formats/annotation:rasterization_java_proto_lite", + src_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework/formats:location_data_java_proto_lite", + src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework/formats:detection_java_proto_lite", + src_out = "com/google/mediapipe/formats/proto/DetectionProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework/formats:classification_java_proto_lite", + src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java", + )) + return proto_src_list + +def mediapipe_logging_java_proto_srcs(name = ""): + """Extracts the generated logging-related MediaPipe java proto source code. + + Args: + name: The optional bazel target name. + + Returns: + The list of the extrated MediaPipe logging-related java proto source code. + """ + + proto_src_list = [] + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/util/analytics:mediapipe_log_extension_java_proto_lite", + src_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/util/analytics:mediapipe_logging_enums_java_proto_lite", + src_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java", + )) + return proto_src_list From 8d5cf9bbedb2ee97651f57feeb273f0da48e4f78 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 4 Oct 2022 09:40:03 -0700 Subject: [PATCH 024/132] Open source MediaPipe object detector task reference app prototype. 
PiperOrigin-RevId: 478811683 --- .../instantmotiontracking/GIFEditText.java | 2 +- mediapipe/tasks/examples/android/BUILD | 21 ++ .../src/main/AndroidManifest.xml | 37 +++ .../android/objectdetector/src/main/BUILD | 48 ++++ .../examples/objectdetector/MainActivity.java | 236 ++++++++++++++++++ .../ObjectDetectionResultImageView.java | 77 ++++++ .../drawable-v24/ic_launcher_foreground.xml | 34 +++ .../res/drawable/ic_launcher_background.xml | 74 ++++++ .../android/res/layout/activity_main.xml | 40 +++ .../res/mipmap-anydpi-v26/ic_launcher.xml | 5 + .../mipmap-anydpi-v26/ic_launcher_round.xml | 5 + .../android/res/mipmap-hdpi/ic_launcher.png | Bin 0 -> 1354 bytes .../mipmap-hdpi/ic_launcher_foreground.png | Bin 0 -> 2257 bytes .../res/mipmap-hdpi/ic_launcher_round.png | Bin 0 -> 3246 bytes .../android/res/mipmap-mdpi/ic_launcher.png | Bin 0 -> 959 bytes .../mipmap-mdpi/ic_launcher_foreground.png | Bin 0 -> 900 bytes .../res/mipmap-mdpi/ic_launcher_round.png | Bin 0 -> 1955 bytes .../android/res/mipmap-xhdpi/ic_launcher.png | Bin 0 -> 1971 bytes .../mipmap-xhdpi/ic_launcher_foreground.png | Bin 0 -> 1845 bytes .../res/mipmap-xhdpi/ic_launcher_round.png | Bin 0 -> 4658 bytes .../android/res/mipmap-xxhdpi/ic_launcher.png | Bin 0 -> 3562 bytes .../mipmap-xxhdpi/ic_launcher_foreground.png | Bin 0 -> 5655 bytes .../res/mipmap-xxhdpi/ic_launcher_round.png | Bin 0 -> 7745 bytes .../res/mipmap-xxxhdpi/ic_launcher.png | Bin 0 -> 5004 bytes .../mipmap-xxxhdpi/ic_launcher_foreground.png | Bin 0 -> 8278 bytes .../res/mipmap-xxxhdpi/ic_launcher_round.png | Bin 0 -> 11062 bytes .../examples/android/res/values/colors.xml | 6 + .../examples/android/res/values/strings.xml | 6 + .../examples/android/res/values/styles.xml | 11 + 29 files changed, 601 insertions(+), 1 deletion(-) create mode 100644 mediapipe/tasks/examples/android/BUILD create mode 100644 mediapipe/tasks/examples/android/objectdetector/src/main/AndroidManifest.xml create mode 100644 mediapipe/tasks/examples/android/objectdetector/src/main/BUILD create mode 100644 mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java create mode 100644 mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/ObjectDetectionResultImageView.java create mode 100644 mediapipe/tasks/examples/android/res/drawable-v24/ic_launcher_foreground.xml create mode 100644 mediapipe/tasks/examples/android/res/drawable/ic_launcher_background.xml create mode 100644 mediapipe/tasks/examples/android/res/layout/activity_main.xml create mode 100644 mediapipe/tasks/examples/android/res/mipmap-anydpi-v26/ic_launcher.xml create mode 100644 mediapipe/tasks/examples/android/res/mipmap-anydpi-v26/ic_launcher_round.xml create mode 100644 mediapipe/tasks/examples/android/res/mipmap-hdpi/ic_launcher.png create mode 100644 mediapipe/tasks/examples/android/res/mipmap-hdpi/ic_launcher_foreground.png create mode 100644 mediapipe/tasks/examples/android/res/mipmap-hdpi/ic_launcher_round.png create mode 100644 mediapipe/tasks/examples/android/res/mipmap-mdpi/ic_launcher.png create mode 100644 mediapipe/tasks/examples/android/res/mipmap-mdpi/ic_launcher_foreground.png create mode 100644 mediapipe/tasks/examples/android/res/mipmap-mdpi/ic_launcher_round.png create mode 100644 mediapipe/tasks/examples/android/res/mipmap-xhdpi/ic_launcher.png create mode 100644 mediapipe/tasks/examples/android/res/mipmap-xhdpi/ic_launcher_foreground.png create mode 100644 
mediapipe/tasks/examples/android/res/mipmap-xhdpi/ic_launcher_round.png create mode 100644 mediapipe/tasks/examples/android/res/mipmap-xxhdpi/ic_launcher.png create mode 100644 mediapipe/tasks/examples/android/res/mipmap-xxhdpi/ic_launcher_foreground.png create mode 100644 mediapipe/tasks/examples/android/res/mipmap-xxhdpi/ic_launcher_round.png create mode 100644 mediapipe/tasks/examples/android/res/mipmap-xxxhdpi/ic_launcher.png create mode 100644 mediapipe/tasks/examples/android/res/mipmap-xxxhdpi/ic_launcher_foreground.png create mode 100644 mediapipe/tasks/examples/android/res/mipmap-xxxhdpi/ic_launcher_round.png create mode 100644 mediapipe/tasks/examples/android/res/values/colors.xml create mode 100644 mediapipe/tasks/examples/android/res/values/strings.xml create mode 100644 mediapipe/tasks/examples/android/res/values/styles.xml diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java index 1b733ed82..10e6422ba 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java @@ -18,7 +18,7 @@ import android.content.ClipDescription; import android.content.Context; import android.net.Uri; import android.os.Bundle; -import androidx.appcompat.widget.AppCompatEditText; +import android.support.v7.widget.AppCompatEditText; import android.util.AttributeSet; import android.util.Log; import android.view.inputmethod.EditorInfo; diff --git a/mediapipe/tasks/examples/android/BUILD b/mediapipe/tasks/examples/android/BUILD new file mode 100644 index 000000000..c07af2d2c --- /dev/null +++ b/mediapipe/tasks/examples/android/BUILD @@ -0,0 +1,21 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +filegroup( + name = "resource_files", + srcs = glob(["res/**"]), + visibility = ["//mediapipe/tasks/examples/android:__subpackages__"], +) diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/AndroidManifest.xml b/mediapipe/tasks/examples/android/objectdetector/src/main/AndroidManifest.xml new file mode 100644 index 000000000..5c53dc269 --- /dev/null +++ b/mediapipe/tasks/examples/android/objectdetector/src/main/AndroidManifest.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD b/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD new file mode 100644 index 000000000..65b98d647 --- /dev/null +++ b/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD @@ -0,0 +1,48 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +package(default_visibility = ["//visibility:private"]) + +android_binary( + name = "objectdetector", + srcs = glob(["**/*.java"]), + assets = [ + "//mediapipe/tasks/testdata/vision:test_models", + ], + assets_dir = "", + custom_package = "com.google.mediapipe.tasks.examples.objectdetector", + manifest = "AndroidManifest.xml", + manifest_values = { + "applicationId": "com.google.mediapipe.tasks.examples.objectdetector", + }, + multidex = "native", + resource_files = ["//mediapipe/tasks/examples/android:resource_files"], + deps = [ + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector", + "//third_party:androidx_appcompat", + "//third_party:androidx_constraint_layout", + "//third_party:opencv", + "@maven//:androidx_activity_activity", + "@maven//:androidx_concurrent_concurrent_futures", + "@maven//:androidx_exifinterface_exifinterface", + "@maven//:androidx_fragment_fragment", + "@maven//:com_google_guava_guava", + ], +) diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java new file mode 100644 index 000000000..7f7ec1389 --- /dev/null +++ b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java @@ -0,0 +1,236 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.examples.objectdetector; + +import android.content.Intent; +import android.graphics.Bitmap; +import android.graphics.Matrix; +import android.media.MediaMetadataRetriever; +import android.os.Bundle; +import android.provider.MediaStore; +import androidx.appcompat.app.AppCompatActivity; +import android.util.Log; +import android.view.View; +import android.widget.Button; +import android.widget.FrameLayout; +import androidx.activity.result.ActivityResultLauncher; +import androidx.activity.result.contract.ActivityResultContracts; +import androidx.exifinterface.media.ExifInterface; +// ContentResolver dependency +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult; +import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector; +import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions; +import java.io.IOException; +import java.io.InputStream; + +/** Main activity of MediaPipe Task Object Detector reference app. */ +public class MainActivity extends AppCompatActivity { + private static final String TAG = "MainActivity"; + private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite"; + + private ObjectDetector objectDetector; + + private enum InputSource { + UNKNOWN, + IMAGE, + VIDEO, + CAMERA, + } + + private InputSource inputSource = InputSource.UNKNOWN; + + // Image mode demo component. + private ActivityResultLauncher imageGetter; + // Video mode demo component. + private ActivityResultLauncher videoGetter; + private ObjectDetectionResultImageView imageView; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + setupImageModeDemo(); + setupVideoModeDemo(); + // TODO: Adds live camera demo. + } + + /** Sets up the image mode demo. */ + private void setupImageModeDemo() { + imageView = new ObjectDetectionResultImageView(this); + // The Intent to access gallery and read images as bitmap. + imageGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + Bitmap bitmap = null; + try { + bitmap = + downscaleBitmap( + MediaStore.Images.Media.getBitmap( + this.getContentResolver(), resultIntent.getData())); + } catch (IOException e) { + Log.e(TAG, "Bitmap reading error:" + e); + } + try { + InputStream imageData = + this.getContentResolver().openInputStream(resultIntent.getData()); + bitmap = rotateBitmap(bitmap, imageData); + } catch (IOException e) { + Log.e(TAG, "Bitmap rotation error:" + e); + } + if (bitmap != null) { + Image image = new BitmapImageBuilder(bitmap).build(); + ObjectDetectionResult detectionResult = objectDetector.detect(image); + imageView.setData(image, detectionResult); + runOnUiThread(() -> imageView.update()); + } + } + } + }); + Button loadImageButton = findViewById(R.id.button_load_picture); + loadImageButton.setOnClickListener( + v -> { + if (inputSource != InputSource.IMAGE) { + createObjectDetector(RunningMode.IMAGE); + this.inputSource = InputSource.IMAGE; + updateLayout(); + } + // Reads images from gallery. 
+ Intent pickImageIntent = new Intent(Intent.ACTION_PICK); + pickImageIntent.setDataAndType(MediaStore.Images.Media.INTERNAL_CONTENT_URI, "image/*"); + imageGetter.launch(pickImageIntent); + }); + } + + /** Sets up the video mode demo. */ + private void setupVideoModeDemo() { + imageView = new ObjectDetectionResultImageView(this); + // The Intent to access gallery and read a video file. + videoGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + MediaMetadataRetriever metaRetriever = new MediaMetadataRetriever(); + metaRetriever.setDataSource(this, resultIntent.getData()); + long duration = + Long.parseLong( + metaRetriever.extractMetadata( + MediaMetadataRetriever.METADATA_KEY_DURATION)); + int numFrames = + Integer.parseInt( + metaRetriever.extractMetadata( + MediaMetadataRetriever.METADATA_KEY_VIDEO_FRAME_COUNT)); + long frameIntervalMs = duration / numFrames; + for (int i = 0; i < numFrames; ++i) { + Image image = new BitmapImageBuilder(metaRetriever.getFrameAtIndex(i)).build(); + ObjectDetectionResult detectionResult = + objectDetector.detectForVideo(image, frameIntervalMs * i); + // Currently only annotates the detection result on the first video frame and + // display it to verify the correctness. + // TODO: Annotates the detection result on every frame, save the + // annotated frames as a video file, and play back the video afterwards. + if (i == 0) { + imageView.setData(image, detectionResult); + runOnUiThread(() -> imageView.update()); + } + } + } + } + }); + Button loadVideoButton = findViewById(R.id.button_load_video); + loadVideoButton.setOnClickListener( + v -> { + createObjectDetector(RunningMode.VIDEO); + updateLayout(); + this.inputSource = InputSource.VIDEO; + + // Reads a video from gallery. + Intent pickVideoIntent = new Intent(Intent.ACTION_PICK); + pickVideoIntent.setDataAndType(MediaStore.Video.Media.INTERNAL_CONTENT_URI, "video/*"); + videoGetter.launch(pickVideoIntent); + }); + } + + private void createObjectDetector(RunningMode mode) { + if (objectDetector != null) { + objectDetector.close(); + } + // Initializes a new MediaPipe ObjectDetector instance + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setScoreThreshold(0.5f) + .setMaxResults(5) + .setRunningMode(mode) + .build(); + objectDetector = ObjectDetector.createFromOptions(this, options); + } + + private void updateLayout() { + // Updates the preview layout. 
+ FrameLayout frameLayout = findViewById(R.id.preview_display_layout); + frameLayout.removeAllViewsInLayout(); + imageView.setImageDrawable(null); + frameLayout.addView(imageView); + imageView.setVisibility(View.VISIBLE); + } + + private Bitmap downscaleBitmap(Bitmap originalBitmap) { + double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight(); + int width = imageView.getWidth(); + int height = imageView.getHeight(); + if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) { + width = (int) (height * aspectRatio); + } else { + height = (int) (width / aspectRatio); + } + return Bitmap.createScaledBitmap(originalBitmap, width, height, false); + } + + private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException { + int orientation = + new ExifInterface(imageData) + .getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL); + if (orientation == ExifInterface.ORIENTATION_NORMAL) { + return inputBitmap; + } + Matrix matrix = new Matrix(); + switch (orientation) { + case ExifInterface.ORIENTATION_ROTATE_90: + matrix.postRotate(90); + break; + case ExifInterface.ORIENTATION_ROTATE_180: + matrix.postRotate(180); + break; + case ExifInterface.ORIENTATION_ROTATE_270: + matrix.postRotate(270); + break; + default: + matrix.postRotate(0); + } + return Bitmap.createBitmap( + inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true); + } +} diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/ObjectDetectionResultImageView.java b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/ObjectDetectionResultImageView.java new file mode 100644 index 000000000..94a4a90dc --- /dev/null +++ b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/ObjectDetectionResultImageView.java @@ -0,0 +1,77 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.examples.objectdetector; + +import android.content.Context; +import android.graphics.Bitmap; +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Matrix; +import android.graphics.Paint; +import androidx.appcompat.widget.AppCompatImageView; +import com.google.mediapipe.framework.image.BitmapExtractor; +import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.tasks.components.containers.Detection; +import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult; + +/** An ImageView implementation for displaying {@link ObjectDetectionResult}. 
*/ +public class ObjectDetectionResultImageView extends AppCompatImageView { + private static final String TAG = "ObjectDetectionResultImageView"; + + private static final int BBOX_COLOR = Color.GREEN; + private static final int BBOX_THICKNESS = 5; // Pixels + private Bitmap latest; + + public ObjectDetectionResultImageView(Context context) { + super(context); + setScaleType(AppCompatImageView.ScaleType.FIT_CENTER); + } + + /** + * Sets an {@link Image} and an {@link ObjectDetectionResult} to render. + * + * @param image an {@link Image} object for annotation. + * @param result an {@link ObjectDetectionResult} object that contains the detection result. + */ + public void setData(Image image, ObjectDetectionResult result) { + if (image == null || result == null) { + return; + } + latest = BitmapExtractor.extract(image); + Canvas canvas = new Canvas(latest); + canvas.drawBitmap(latest, new Matrix(), null); + for (int i = 0; i < result.detections().size(); ++i) { + drawDetectionOnCanvas(result.detections().get(i), canvas); + } + } + + /** Updates the image view with the latest {@link ObjectDetectionResult}. */ + public void update() { + postInvalidate(); + if (latest != null) { + setImageBitmap(latest); + } + } + + private void drawDetectionOnCanvas(Detection detection, Canvas canvas) { + // TODO: Draws the category and the score per bounding box. + // Draws bounding box. + Paint bboxPaint = new Paint(); + bboxPaint.setColor(BBOX_COLOR); + bboxPaint.setStyle(Paint.Style.STROKE); + bboxPaint.setStrokeWidth(BBOX_THICKNESS); + canvas.drawRect(detection.boundingBox(), bboxPaint); + } +} diff --git a/mediapipe/tasks/examples/android/res/drawable-v24/ic_launcher_foreground.xml b/mediapipe/tasks/examples/android/res/drawable-v24/ic_launcher_foreground.xml new file mode 100644 index 000000000..c7bd21dbd --- /dev/null +++ b/mediapipe/tasks/examples/android/res/drawable-v24/ic_launcher_foreground.xml @@ -0,0 +1,34 @@ + + + + + + + + + + + diff --git a/mediapipe/tasks/examples/android/res/drawable/ic_launcher_background.xml b/mediapipe/tasks/examples/android/res/drawable/ic_launcher_background.xml new file mode 100644 index 000000000..01f0af0ad --- /dev/null +++ b/mediapipe/tasks/examples/android/res/drawable/ic_launcher_background.xml @@ -0,0 +1,74 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/examples/android/res/layout/activity_main.xml b/mediapipe/tasks/examples/android/res/layout/activity_main.xml new file mode 100644 index 000000000..834e9a3e6 --- /dev/null +++ b/mediapipe/tasks/examples/android/res/layout/activity_main.xml @@ -0,0 +1,40 @@ + + + +
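The ObjectDetectionResultImageView added above leaves per-box labels as a TODO ("Draws the category and the score per bounding box"). A minimal sketch of such a helper is shown below; it assumes the Detection container exposes categories() whose entries provide categoryName() and score(), and the method name drawLabelOnCanvas is illustrative only, not part of this patch. It would be called from setData() alongside drawDetectionOnCanvas().

    // Hypothetical helper (not part of this patch): draws the top category name
    // and score just above the detection's bounding box.
    private void drawLabelOnCanvas(Detection detection, Canvas canvas) {
      if (detection.categories().isEmpty()) {
        return;
      }
      // Assumes the category container exposes categoryName() and score();
      // adjust to the actual Tasks container API if it differs.
      String label =
          String.format(
              "%s (%.2f)",
              detection.categories().get(0).categoryName(),
              detection.categories().get(0).score());
      Paint textPaint = new Paint();
      textPaint.setColor(BBOX_COLOR);
      textPaint.setTextSize(40.0f); // Text height in pixels.
      // Keep the label on screen even when the box touches the top edge.
      float y = Math.max(textPaint.getTextSize(), detection.boundingBox().top - BBOX_THICKNESS);
      canvas.drawText(label, detection.boundingBox().left, y, textPaint);
    }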