From 35bb18945f21856f62cd99027f7702b92411dfc5 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 5 Dec 2022 07:22:51 -0800 Subject: [PATCH 01/64] Better handling of empty packets in vector calculators. PiperOrigin-RevId: 493000695 --- .../core/get_vector_item_calculator.h | 9 +++-- .../core/get_vector_item_calculator.proto | 3 ++ .../core/get_vector_item_calculator_test.cc | 34 ++++++++++++++----- .../core/merge_to_vector_calculator.cc | 4 +++ .../core/merge_to_vector_calculator.h | 15 ++++++-- 5 files changed, 51 insertions(+), 14 deletions(-) diff --git a/mediapipe/calculators/core/get_vector_item_calculator.h b/mediapipe/calculators/core/get_vector_item_calculator.h index dc98ccfe7..25d90bfe6 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator.h +++ b/mediapipe/calculators/core/get_vector_item_calculator.h @@ -65,6 +65,7 @@ class GetVectorItemCalculator : public Node { MEDIAPIPE_NODE_CONTRACT(kIn, kIdx, kOut); absl::Status Open(CalculatorContext* cc) final { + cc->SetOffset(mediapipe::TimestampDiff(0)); auto& options = cc->Options(); RET_CHECK(kIdx(cc).IsConnected() || options.has_item_index()); return absl::OkStatus(); @@ -90,8 +91,12 @@ class GetVectorItemCalculator : public Node { return absl::OkStatus(); } - RET_CHECK(idx >= 0 && idx < items.size()); - kOut(cc).Send(items[idx]); + RET_CHECK(idx >= 0); + RET_CHECK(options.output_empty_on_oob() || idx < items.size()); + + if (idx < items.size()) { + kOut(cc).Send(items[idx]); + } return absl::OkStatus(); } diff --git a/mediapipe/calculators/core/get_vector_item_calculator.proto b/mediapipe/calculators/core/get_vector_item_calculator.proto index c406283e4..9cfb579e4 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator.proto +++ b/mediapipe/calculators/core/get_vector_item_calculator.proto @@ -26,4 +26,7 @@ message GetVectorItemCalculatorOptions { // Index of vector item to get. INDEX input stream can be used instead, or to // override. 
optional int32 item_index = 1; + + // Set to true to output an empty packet when the index is out of bounds. + optional bool output_empty_on_oob = 2; } diff --git a/mediapipe/calculators/core/get_vector_item_calculator_test.cc b/mediapipe/calculators/core/get_vector_item_calculator_test.cc index c148aa9d1..c2974e20a 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator_test.cc +++ b/mediapipe/calculators/core/get_vector_item_calculator_test.cc @@ -32,18 +32,21 @@ CalculatorRunner MakeRunnerWithStream() { )"); } -CalculatorRunner MakeRunnerWithOptions(int set_index) { - return CalculatorRunner(absl::StrFormat(R"( +CalculatorRunner MakeRunnerWithOptions(int set_index, + bool output_empty_on_oob = false) { + return CalculatorRunner( + absl::StrFormat(R"( calculator: "TestGetIntVectorItemCalculator" input_stream: "VECTOR:vector_stream" output_stream: "ITEM:item_stream" options { [mediapipe.GetVectorItemCalculatorOptions.ext] { item_index: %d + output_empty_on_oob: %s } } )", - set_index)); + set_index, output_empty_on_oob ? 
"true" : "false")); } void AddInputVector(CalculatorRunner& runner, const std::vector& inputs, @@ -140,8 +143,7 @@ TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail1) { absl::Status status = runner.Run(); ASSERT_FALSE(status.ok()); - EXPECT_THAT(status.message(), - testing::HasSubstr("idx >= 0 && idx < items.size()")); + EXPECT_THAT(status.message(), testing::HasSubstr("idx >= 0")); } TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail2) { @@ -155,7 +157,8 @@ TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail2) { absl::Status status = runner.Run(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.message(), - testing::HasSubstr("idx >= 0 && idx < items.size()")); + testing::HasSubstr( + "options.output_empty_on_oob() || idx < items.size()")); } TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail1) { @@ -167,8 +170,7 @@ TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail1) { absl::Status status = runner.Run(); ASSERT_FALSE(status.ok()); - EXPECT_THAT(status.message(), - testing::HasSubstr("idx >= 0 && idx < items.size()")); + EXPECT_THAT(status.message(), testing::HasSubstr("idx >= 0")); } TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail2) { @@ -181,7 +183,21 @@ TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail2) { absl::Status status = runner.Run(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.message(), - testing::HasSubstr("idx >= 0 && idx < items.size()")); + testing::HasSubstr( + "options.output_empty_on_oob() || idx < items.size()")); +} + +TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail3) { + const int try_index = 3; + CalculatorRunner runner = MakeRunnerWithOptions(try_index, true); + const std::vector inputs = {1, 2, 3}; + + AddInputVector(runner, inputs, 1); + + MP_ASSERT_OK(runner.Run()); + + const std::vector& outputs = runner.Outputs().Tag("ITEM").packets; + EXPECT_THAT(outputs, testing::ElementsAre()); } 
TEST(TestGetIntVectorItemCalculatorTest, IndexStreamTwoTimestamps) { diff --git a/mediapipe/calculators/core/merge_to_vector_calculator.cc b/mediapipe/calculators/core/merge_to_vector_calculator.cc index cca64bc9a..5f05ad725 100644 --- a/mediapipe/calculators/core/merge_to_vector_calculator.cc +++ b/mediapipe/calculators/core/merge_to_vector_calculator.cc @@ -23,5 +23,9 @@ namespace api2 { typedef MergeToVectorCalculator MergeImagesToVectorCalculator; MEDIAPIPE_REGISTER_NODE(MergeImagesToVectorCalculator); +typedef MergeToVectorCalculator + MergeGpuBuffersToVectorCalculator; +MEDIAPIPE_REGISTER_NODE(MergeGpuBuffersToVectorCalculator); + } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/merge_to_vector_calculator.h b/mediapipe/calculators/core/merge_to_vector_calculator.h index bed616695..f63d86ee4 100644 --- a/mediapipe/calculators/core/merge_to_vector_calculator.h +++ b/mediapipe/calculators/core/merge_to_vector_calculator.h @@ -42,11 +42,20 @@ class MergeToVectorCalculator : public Node { return absl::OkStatus(); } + absl::Status Open(::mediapipe::CalculatorContext* cc) { + cc->SetOffset(::mediapipe::TimestampDiff(0)); + return absl::OkStatus(); + } + absl::Status Process(CalculatorContext* cc) { const int input_num = kIn(cc).Count(); - std::vector output_vector(input_num); - std::transform(kIn(cc).begin(), kIn(cc).end(), output_vector.begin(), - [](const auto& elem) -> T { return elem.Get(); }); + std::vector output_vector; + for (auto it = kIn(cc).begin(); it != kIn(cc).end(); it++) { + const auto& elem = *it; + if (!elem.IsEmpty()) { + output_vector.push_back(elem.Get()); + } + } kOut(cc).Send(output_vector); return absl::OkStatus(); } From 4f8eaee20f5c02d932b8bacecd1afb0655d84130 Mon Sep 17 00:00:00 2001 From: Adam Cozzette Date: Mon, 5 Dec 2022 11:33:21 -0800 Subject: [PATCH 02/64] Internal change PiperOrigin-RevId: 493065632 --- mediapipe/graphs/iris_tracking/calculators/BUILD | 1 - 
mediapipe/java/com/google/mediapipe/framework/jni/BUILD | 7 +++---- mediapipe/modules/hand_landmark/calculators/BUILD | 1 - mediapipe/modules/objectron/calculators/BUILD | 4 ---- mediapipe/util/tracking/BUILD | 1 - 5 files changed, 3 insertions(+), 11 deletions(-) diff --git a/mediapipe/graphs/iris_tracking/calculators/BUILD b/mediapipe/graphs/iris_tracking/calculators/BUILD index 3a3d57a0f..f5124b464 100644 --- a/mediapipe/graphs/iris_tracking/calculators/BUILD +++ b/mediapipe/graphs/iris_tracking/calculators/BUILD @@ -97,7 +97,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:image_file_properties_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD index 4926e2f3c..4540f63a6 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD @@ -84,12 +84,11 @@ cc_library( deps = [ ":class_registry", ":jni_util", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:calculator_profile_cc_proto", - "//mediapipe/framework/tool:calculator_graph_template_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", - "//mediapipe/framework/formats:matrix_data_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_profile_cc_proto", + "//mediapipe/framework:calculator_framework", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", diff --git a/mediapipe/modules/hand_landmark/calculators/BUILD b/mediapipe/modules/hand_landmark/calculators/BUILD index b2a8efe37..b42ec94de 100644 --- 
a/mediapipe/modules/hand_landmark/calculators/BUILD +++ b/mediapipe/modules/hand_landmark/calculators/BUILD @@ -24,7 +24,6 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", - "//mediapipe/framework/formats:location_data_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", diff --git a/mediapipe/modules/objectron/calculators/BUILD b/mediapipe/modules/objectron/calculators/BUILD index eeeaee5f4..14cea526f 100644 --- a/mediapipe/modules/objectron/calculators/BUILD +++ b/mediapipe/modules/objectron/calculators/BUILD @@ -275,7 +275,6 @@ cc_library( ":tflite_tensors_to_objects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:ret_check", "@com_google_absl//absl/memory", @@ -299,7 +298,6 @@ cc_library( ":tensors_to_objects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:ret_check", "@com_google_absl//absl/memory", @@ -316,13 +314,11 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":annotation_cc_proto", - ":belief_decoder_config_cc_proto", ":decoder", ":lift_2d_frame_annotation_to_3d_calculator_cc_proto", ":tensor_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:ret_check", "@com_google_absl//absl/memory", diff --git a/mediapipe/util/tracking/BUILD b/mediapipe/util/tracking/BUILD index 6bca24446..816af2533 100644 --- 
a/mediapipe/util/tracking/BUILD +++ b/mediapipe/util/tracking/BUILD @@ -282,7 +282,6 @@ cc_library( srcs = ["motion_models_cv.cc"], hdrs = ["motion_models_cv.h"], deps = [ - ":camera_motion_cc_proto", ":motion_models", ":motion_models_cc_proto", "//mediapipe/framework/port:opencv_core", From 69b27b246a3f11e775791805eb2c2b4858ed9412 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 5 Dec 2022 14:14:13 -0800 Subject: [PATCH 03/64] Adds a public function for creating TaskRunner instances. PiperOrigin-RevId: 493109736 --- mediapipe/tasks/web/core/task_runner.ts | 46 ++++++++++++++++--------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index d769139bc..e2ab42e31 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -32,6 +32,34 @@ const GraphRunnerImageLibType = /** An implementation of the GraphRunner that supports image operations */ export class GraphRunnerImageLib extends GraphRunnerImageLibType {} +/** + * Creates a new instance of a Mediapipe Task. Determines if SIMD is + * supported and loads the relevant WASM binary. + * @return A fully instantiated instance of `T`. + */ +export async function +createTaskRunner, O extends TaskRunnerOptions>( + type: WasmMediaPipeConstructor, initializeCanvas: boolean, + fileset: WasmFileset, options: O): Promise { + const fileLocator: FileLocator = { + locateFile() { + // The only file loaded with this mechanism is the Wasm binary + return fileset.wasmBinaryPath.toString(); + } + }; + + // Initialize a canvas if requested. If OffscreenCanvas is availble, we + // let the graph runner initialize it by passing `undefined`. + const canvas = initializeCanvas ? (typeof OffscreenCanvas === 'undefined' ? 
+ document.createElement('canvas') : + undefined) : + null; + const instance = await createMediaPipeLib( + type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator); + await instance.setOptions(options); + return instance; +} + /** Base class for all MediaPipe Tasks. */ export abstract class TaskRunner { protected abstract baseOptions: BaseOptionsProto; @@ -47,23 +75,7 @@ export abstract class TaskRunner { O extends TaskRunnerOptions>( type: WasmMediaPipeConstructor, initializeCanvas: boolean, fileset: WasmFileset, options: O): Promise { - const fileLocator: FileLocator = { - locateFile() { - // The only file loaded with this mechanism is the Wasm binary - return fileset.wasmBinaryPath.toString(); - } - }; - - // Initialize a canvas if requested. If OffscreenCanvas is availble, we - // let the graph runner initialize it by passing `undefined`. - const canvas = initializeCanvas ? (typeof OffscreenCanvas === 'undefined' ? - document.createElement('canvas') : - undefined) : - null; - const instance = await createMediaPipeLib( - type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator); - await instance.setOptions(options); - return instance; + return createTaskRunner(type, initializeCanvas, fileset, options); } constructor( From 3ad03bee0be95376cf4606d39b201dab5a0afcb5 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 5 Dec 2022 14:48:07 -0800 Subject: [PATCH 04/64] Adds missing visibility rule. 
PiperOrigin-RevId: 493118880 --- mediapipe/calculators/tensor/BUILD | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 645189a07..577ac4111 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -366,6 +366,9 @@ cc_test( cc_library( name = "universal_sentence_encoder_preprocessor_calculator", srcs = ["universal_sentence_encoder_preprocessor_calculator.cc"], + visibility = [ + "//mediapipe/framework:mediapipe_internal", + ], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", From 99d1dd6fbb130f9f262365ae334b2ca22c819478 Mon Sep 17 00:00:00 2001 From: Mark McDonald Date: Mon, 5 Dec 2022 15:28:52 -0800 Subject: [PATCH 05/64] Internal change PiperOrigin-RevId: 493129643 --- docs/build_py_api_docs.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/docs/build_py_api_docs.py b/docs/build_py_api_docs.py index fe706acd3..46546012d 100644 --- a/docs/build_py_api_docs.py +++ b/docs/build_py_api_docs.py @@ -26,7 +26,6 @@ from absl import app from absl import flags from tensorflow_docs.api_generator import generate_lib -from tensorflow_docs.api_generator import public_api try: # mediapipe has not been set up to work with bazel yet, so catch & report. @@ -68,10 +67,7 @@ def gen_api_docs(): code_url_prefix=_URL_PREFIX.value, search_hints=_SEARCH_HINTS.value, site_path=_SITE_PATH.value, - # This callback ensures that docs are only generated for objects that - # are explicitly imported in your __init__.py files. There are other - # options but this is a good starting point. 
- callbacks=[public_api.explicit_package_contents_filter], + callbacks=[], ) doc_generator.build(_OUTPUT_DIR.value) From 1e76d47a71602ba0ac4a089f625bbd667a7f184b Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Mon, 5 Dec 2022 16:18:36 -0800 Subject: [PATCH 06/64] Checks if a custom global resource provider is used as the first step of loading the model resources on all platforms. PiperOrigin-RevId: 493141519 --- mediapipe/tasks/cc/core/BUILD | 1 + mediapipe/tasks/cc/core/model_resources.cc | 30 +++++++++++----------- mediapipe/util/resource_util.cc | 2 ++ mediapipe/util/resource_util_custom.h | 3 +++ 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index 202f3ea3c..f8004d257 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -117,6 +117,7 @@ cc_library_with_tflite( "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", "//mediapipe/tasks/cc/metadata:metadata_extractor", "//mediapipe/util:resource_util", + "//mediapipe/util:resource_util_custom", "//mediapipe/util/tflite:error_reporter", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", diff --git a/mediapipe/tasks/cc/core/model_resources.cc b/mediapipe/tasks/cc/core/model_resources.cc index d5c12ee95..7819f6213 100644 --- a/mediapipe/tasks/cc/core/model_resources.cc +++ b/mediapipe/tasks/cc/core/model_resources.cc @@ -33,6 +33,7 @@ limitations under the License. 
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/util/resource_util.h" +#include "mediapipe/util/resource_util_custom.h" #include "mediapipe/util/tflite/error_reporter.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -99,21 +100,20 @@ const tflite::Model* ModelResources::GetTfLiteModel() const { absl::Status ModelResources::BuildModelFromExternalFileProto() { if (model_file_->has_file_name()) { -#ifdef __EMSCRIPTEN__ - // In browsers, the model file may require a custom ResourceProviderFn to - // provide the model content. The open() method may not work in this case. - // Thus, loading the model content from the model file path in advance with - // the help of GetResourceContents. - MP_RETURN_IF_ERROR(mediapipe::GetResourceContents( - model_file_->file_name(), model_file_->mutable_file_content())); - model_file_->clear_file_name(); -#else - // If the model file name is a relative path, searches the file in a - // platform-specific location and returns the absolute path on success. - ASSIGN_OR_RETURN(std::string path_to_resource, - mediapipe::PathToResourceAsFile(model_file_->file_name())); - model_file_->set_file_name(path_to_resource); -#endif // __EMSCRIPTEN__ + if (HasCustomGlobalResourceProvider()) { + // If the model contents are provided via a custom ResourceProviderFn, the + // open() method may not work. Thus, loads the model content from the + // model file path in advance with the help of GetResourceContents. + MP_RETURN_IF_ERROR(GetResourceContents( + model_file_->file_name(), model_file_->mutable_file_content())); + model_file_->clear_file_name(); + } else { + // If the model file name is a relative path, searches the file in a + // platform-specific location and returns the absolute path on success. 
+ ASSIGN_OR_RETURN(std::string path_to_resource, + PathToResourceAsFile(model_file_->file_name())); + model_file_->set_file_name(path_to_resource); + } } ASSIGN_OR_RETURN( model_file_handler_, diff --git a/mediapipe/util/resource_util.cc b/mediapipe/util/resource_util.cc index 8f40154a0..38636f32e 100644 --- a/mediapipe/util/resource_util.cc +++ b/mediapipe/util/resource_util.cc @@ -37,6 +37,8 @@ absl::Status GetResourceContents(const std::string& path, std::string* output, return internal::DefaultGetResourceContents(path, output, read_as_binary); } +bool HasCustomGlobalResourceProvider() { return resource_provider_ != nullptr; } + void SetCustomGlobalResourceProvider(ResourceProviderFn fn) { resource_provider_ = std::move(fn); } diff --git a/mediapipe/util/resource_util_custom.h b/mediapipe/util/resource_util_custom.h index 6bc1513c6..e74af8b2e 100644 --- a/mediapipe/util/resource_util_custom.h +++ b/mediapipe/util/resource_util_custom.h @@ -10,6 +10,9 @@ namespace mediapipe { typedef std::function ResourceProviderFn; +// Returns true if files are provided via a custom resource provider. +bool HasCustomGlobalResourceProvider(); + // Overrides the behavior of GetResourceContents. void SetCustomGlobalResourceProvider(ResourceProviderFn fn); From 3174b20fbe8225c35433d86f3a82d29645bb82bb Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 5 Dec 2022 17:32:57 -0800 Subject: [PATCH 07/64] Move segmentation calculator and options out of 'components' folder. 
PiperOrigin-RevId: 493157929 --- mediapipe/tasks/cc/components/proto/BUILD | 24 ------------------- .../tasks/cc/vision/image_segmenter/BUILD | 8 +++---- .../image_segmenter/calculators}/BUILD | 4 ++-- .../tensors_to_segmentation_calculator.cc | 14 +++++------ .../tensors_to_segmentation_calculator.proto | 6 +++-- ...tensors_to_segmentation_calculator_test.cc | 4 +--- .../vision/image_segmenter/image_segmenter.cc | 4 ++-- .../image_segmenter/image_segmenter_graph.cc | 6 ++--- .../image_segmenter/image_segmenter_test.cc | 2 +- .../cc/vision/image_segmenter/proto/BUILD | 7 +++++- .../proto/image_segmenter_graph_options.proto | 4 ++-- .../proto/segmenter_options.proto | 4 ++-- mediapipe/tasks/python/vision/BUILD | 2 +- .../tasks/python/vision/image_segmenter.py | 2 +- 14 files changed, 35 insertions(+), 56 deletions(-) delete mode 100644 mediapipe/tasks/cc/components/proto/BUILD rename mediapipe/tasks/cc/{components/calculators/tensor => vision/image_segmenter/calculators}/BUILD (94%) rename mediapipe/tasks/cc/{components/calculators/tensor => vision/image_segmenter/calculators}/tensors_to_segmentation_calculator.cc (95%) rename mediapipe/tasks/cc/{components/calculators/tensor => vision/image_segmenter/calculators}/tensors_to_segmentation_calculator.proto (82%) rename mediapipe/tasks/cc/{components/calculators/tensor => vision/image_segmenter/calculators}/tensors_to_segmentation_calculator_test.cc (99%) rename mediapipe/tasks/cc/{components => vision/image_segmenter}/proto/segmenter_options.proto (92%) diff --git a/mediapipe/tasks/cc/components/proto/BUILD b/mediapipe/tasks/cc/components/proto/BUILD deleted file mode 100644 index 569023753..000000000 --- a/mediapipe/tasks/cc/components/proto/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -licenses(["notice"]) - -mediapipe_proto_library( - name = "segmenter_options_proto", - srcs = ["segmenter_options.proto"], -) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 2124fe6e0..4c9c6e69c 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -28,7 +28,6 @@ cc_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/vision/core:base_vision_task_api", @@ -36,6 +35,7 @@ cc_library( "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", @@ -56,17 +56,17 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components/calculators/tensor:tensors_to_segmentation_calculator", - 
"//mediapipe/tasks/cc/components/calculators/tensor:tensors_to_segmentation_calculator_cc_proto", "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator", + "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", "//mediapipe/tasks/metadata:metadata_schema_cc", "//mediapipe/util:label_map_cc_proto", "//mediapipe/util:label_map_util", diff --git a/mediapipe/tasks/cc/components/calculators/tensor/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD similarity index 94% rename from mediapipe/tasks/cc/components/calculators/tensor/BUILD rename to mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD index 6e4322a8f..dcd7fb407 100644 --- a/mediapipe/tasks/cc/components/calculators/tensor/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD @@ -25,7 +25,7 @@ mediapipe_proto_library( "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/framework/formats:image_format_proto", - "//mediapipe/tasks/cc/components/proto:segmenter_options_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_proto", "//mediapipe/util:label_map_proto", ], ) @@ -45,7 +45,7 @@ cc_library( "//mediapipe/framework/port:opencv_core", 
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:status", - "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/vision/utils:image_utils", "//mediapipe/util:label_map_cc_proto", "@com_google_absl//absl/status", diff --git a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc similarity index 95% rename from mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc rename to mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc index 40585848f..668de0057 100644 --- a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// TODO consolidate TensorsToSegmentationCalculator. #include #include #include @@ -35,14 +34,14 @@ limitations under the License. 
#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/status_macros.h" -#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" -#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/util/label_map.pb.h" +// TODO: consolidate TensorToSegmentationCalculator. namespace mediapipe { namespace tasks { - namespace { using ::mediapipe::Image; @@ -51,9 +50,9 @@ using ::mediapipe::api2::Input; using ::mediapipe::api2::Node; using ::mediapipe::api2::Output; using ::mediapipe::tasks::TensorsToSegmentationCalculatorOptions; -using ::mediapipe::tasks::components::proto::SegmenterOptions; using ::mediapipe::tasks::vision::GetImageLikeTensorShape; using ::mediapipe::tasks::vision::Shape; +using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; void StableSoftmax(absl::Span values, absl::Span activated_values) { @@ -90,7 +89,7 @@ void Sigmoid(absl::Span values, // the size to resize masks to. // // Output: -// Segmentation: Segmenation proto. +// Segmentation: Segmentation proto. 
// // Options: // See tensors_to_segmentation_calculator.proto @@ -132,8 +131,7 @@ class TensorsToSegmentationCalculator : public Node { absl::Status TensorsToSegmentationCalculator::Open( mediapipe::CalculatorContext* cc) { - options_ = - cc->Options(); + options_ = cc->Options(); RET_CHECK_NE(options_.segmenter_options().output_type(), SegmenterOptions::UNSPECIFIED) << "Must specify output_type as one of [CONFIDENCE_MASK|CATEGORY_MASK]."; diff --git a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.proto b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto similarity index 82% rename from mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.proto rename to mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto index c26cf910a..b0fdfdd32 100644 --- a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto @@ -15,10 +15,11 @@ limitations under the License. syntax = "proto2"; +// TODO: consolidate TensorToSegmentationCalculator. package mediapipe.tasks; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/segmenter_options.proto"; +import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto"; import "mediapipe/util/label_map.proto"; message TensorsToSegmentationCalculatorOptions { @@ -26,7 +27,8 @@ message TensorsToSegmentationCalculatorOptions { optional TensorsToSegmentationCalculatorOptions ext = 458105876; } - optional components.proto.SegmenterOptions segmenter_options = 1; + optional mediapipe.tasks.vision.image_segmenter.proto.SegmenterOptions + segmenter_options = 1; // Identifying information for each classification label. 
map label_items = 2; diff --git a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc similarity index 99% rename from mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator_test.cc rename to mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc index 55e46d72b..54fb9b816 100644 --- a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc @@ -33,10 +33,9 @@ limitations under the License. #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" -#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" namespace mediapipe { -namespace api2 { namespace { @@ -374,5 +373,4 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMaskResize) { expected_index, buffer_indices))); } -} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index 6dce1b4ea..bbee714c6 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -18,12 +18,12 @@ limitations under the License. 
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" -#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" namespace mediapipe { namespace tasks { @@ -44,7 +44,7 @@ constexpr int kMicroSecondsPerMilliSecond = 1000; using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::Image; -using ::mediapipe::tasks::components::proto::SegmenterOptions; +using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: image_segmenter::proto::ImageSegmenterGraphOptions; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index d5eb5af0d..5531968c1 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -26,16 +26,16 @@ limitations under the License. 
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map_util.h" @@ -54,10 +54,10 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::MultiSource; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::components::proto::SegmenterOptions; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; using ::mediapipe::tasks::vision::image_segmenter::proto:: ImageSegmenterGraphOptions; +using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ::tflite::Tensor; using ::tflite::TensorMetadata; using LabelItems = mediapipe::proto_ns::Map; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index 752a116dd..d5ea088a1 100644 --- 
a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -28,11 +28,11 @@ limitations under the License. #include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/status_matchers.h" -#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD index 3b14060f1..9523dd679 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD @@ -18,13 +18,18 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +mediapipe_proto_library( + name = "segmenter_options_proto", + srcs = ["segmenter_options.proto"], +) + mediapipe_proto_library( name = "image_segmenter_graph_options_proto", srcs = ["image_segmenter_graph_options.proto"], deps = [ + ":segmenter_options_proto", "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:segmenter_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto 
b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto index 166e2e8e0..4d8100842 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto @@ -18,8 +18,8 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_segmenter.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/segmenter_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; +import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.imagesegmenter.proto"; option java_outer_classname = "ImageSegmenterGraphOptionsProto"; @@ -37,5 +37,5 @@ message ImageSegmenterGraphOptions { optional string display_names_locale = 2 [default = "en"]; // Segmentation output options. - optional components.proto.SegmenterOptions segmenter_options = 3; + optional SegmenterOptions segmenter_options = 3; } diff --git a/mediapipe/tasks/cc/components/proto/segmenter_options.proto b/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto similarity index 92% rename from mediapipe/tasks/cc/components/proto/segmenter_options.proto rename to mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto index ca9986707..be2b8a589 100644 --- a/mediapipe/tasks/cc/components/proto/segmenter_options.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto @@ -15,9 +15,9 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks.components.proto; +package mediapipe.tasks.vision.image_segmenter.proto; -option java_package = "com.google.mediapipe.tasks.components.proto"; +option java_package = "com.google.mediapipe.tasks.vision.imagesegmenter.proto"; option java_outer_classname = "SegmenterOptionsProto"; // Shared options used by image segmentation tasks. 
diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index e94507eed..29e7577e8 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -69,8 +69,8 @@ py_library( "//mediapipe/python:_framework_bindings", "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", - "//mediapipe/tasks/cc/components/proto:segmenter_options_py_pb2", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_py_pb2", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_py_pb2", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", diff --git a/mediapipe/tasks/python/vision/image_segmenter.py b/mediapipe/tasks/python/vision/image_segmenter.py index 62fc8bb7c..22a37cb3e 100644 --- a/mediapipe/tasks/python/vision/image_segmenter.py +++ b/mediapipe/tasks/python/vision/image_segmenter.py @@ -21,8 +21,8 @@ from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet -from mediapipe.tasks.cc.components.proto import segmenter_options_pb2 from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_graph_options_pb2 +from mediapipe.tasks.cc.vision.image_segmenter.proto import segmenter_options_pb2 from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls From af43687f2e3c774ff7b0f1f4881d456952a6aadd Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 5 Dec 2022 20:07:29 -0800 Subject: [PATCH 08/64] Open-sources a unit test. 
PiperOrigin-RevId: 493184055 --- .../text_classifier/text_classifier_test.cc | 51 +++++++++++++++---- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc index 8f73914fc..799885eac 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc @@ -38,10 +38,7 @@ limitations under the License. #include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" -namespace mediapipe { -namespace tasks { -namespace text { -namespace text_classifier { +namespace mediapipe::tasks::text::text_classifier { namespace { using ::mediapipe::file::JoinPath; @@ -88,6 +85,8 @@ void ExpectApproximatelyEqual(const TextClassifierResult& actual, } } +} // namespace + class TextClassifierTest : public tflite_shims::testing::Test {}; TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) { @@ -217,8 +216,42 @@ TEST_F(TextClassifierTest, TextClassifierWithStringToBool) { MP_ASSERT_OK(classifier->Close()); } -} // namespace -} // namespace text_classifier -} // namespace text -} // namespace tasks -} // namespace mediapipe +TEST_F(TextClassifierTest, BertLongPositive) { + std::stringstream ss_for_positive_review; + ss_for_positive_review + << "it's a charming and often affecting journey and this is a long"; + for (int i = 0; i < kMaxSeqLen; ++i) { + ss_for_positive_review << " long"; + } + ss_for_positive_review << " movie review"; + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kTestBertModelPath); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, + TextClassifier::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult result, + classifier->Classify(ss_for_positive_review.str())); + TextClassifierResult expected; + std::vector 
categories; + +// Predicted scores are slightly different on Mac OS. +#ifdef __APPLE__ + categories.push_back( + {/*index=*/1, /*score=*/0.974181, /*category_name=*/"positive"}); + categories.push_back( + {/*index=*/0, /*score=*/0.025819, /*category_name=*/"negative"}); +#else + categories.push_back( + {/*index=*/1, /*score=*/0.985889, /*category_name=*/"positive"}); + categories.push_back( + {/*index=*/0, /*score=*/0.014112, /*category_name=*/"negative"}); +#endif // __APPLE__ + + expected.classifications.emplace_back( + Classifications{/*categories=*/categories, + /*head_index=*/0, + /*head_name=*/"probability"}); + ExpectApproximatelyEqual(result, expected); + MP_ASSERT_OK(classifier->Close()); +} + +} // namespace mediapipe::tasks::text::text_classifier From 1761cdcef4ff6fd37d04d15de765eccd7c0a5bcc Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Mon, 5 Dec 2022 22:11:00 -0800 Subject: [PATCH 09/64] Fix aar breakage caused by missing "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark". 
PiperOrigin-RevId: 493204770 --- .../java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index d91c03cc2..c6aba3c66 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -285,6 +285,7 @@ def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_l "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", From b0b38a0013c819a6db4156330cbbe2e0dab11bd8 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 6 Dec 2022 08:29:35 -0800 Subject: [PATCH 10/64] Internal change PiperOrigin-RevId: 493313240 --- mediapipe/gpu/metal_shared_resources.h | 40 +++++++++++ mediapipe/gpu/metal_shared_resources.mm | 73 ++++++++++++++++++++ mediapipe/gpu/metal_shared_resources_test.mm | 49 +++++++++++++ 3 files changed, 162 insertions(+) create mode 100644 mediapipe/gpu/metal_shared_resources.h create mode 100644 mediapipe/gpu/metal_shared_resources.mm create mode 100644 mediapipe/gpu/metal_shared_resources_test.mm diff --git a/mediapipe/gpu/metal_shared_resources.h b/mediapipe/gpu/metal_shared_resources.h new file mode 100644 index 000000000..341860a2d --- /dev/null +++ b/mediapipe/gpu/metal_shared_resources.h @@ -0,0 +1,40 @@ 
+#ifndef MEDIAPIPE_GPU_METAL_SHARED_RESOURCES_H_ +#define MEDIAPIPE_GPU_METAL_SHARED_RESOURCES_H_ + +#import +#import +#import +#import + +#ifndef __OBJC__ +#error This class must be built as Objective-C++. +#endif // !__OBJC__ + +@interface MPPMetalSharedResources : NSObject { +} + +- (instancetype)init NS_DESIGNATED_INITIALIZER; + +@property(readonly) id mtlDevice; +@property(readonly) id mtlCommandQueue; +#if COREVIDEO_SUPPORTS_METAL +@property(readonly) CVMetalTextureCacheRef mtlTextureCache; +#endif + +@end + +namespace mediapipe { + +class MetalSharedResources { + public: + MetalSharedResources(); + ~MetalSharedResources(); + MPPMetalSharedResources* resources() { return resources_; } + + private: + MPPMetalSharedResources* resources_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_METAL_SHARED_RESOURCES_H_ diff --git a/mediapipe/gpu/metal_shared_resources.mm b/mediapipe/gpu/metal_shared_resources.mm new file mode 100644 index 000000000..80d755a01 --- /dev/null +++ b/mediapipe/gpu/metal_shared_resources.mm @@ -0,0 +1,73 @@ +#import "mediapipe/gpu/metal_shared_resources.h" + +@interface MPPMetalSharedResources () +@end + +@implementation MPPMetalSharedResources { +} + +@synthesize mtlDevice = _mtlDevice; +@synthesize mtlCommandQueue = _mtlCommandQueue; +#if COREVIDEO_SUPPORTS_METAL +@synthesize mtlTextureCache = _mtlTextureCache; +#endif + +- (instancetype)init { + self = [super init]; + if (self) { + } + return self; +} + +- (void)dealloc { +#if COREVIDEO_SUPPORTS_METAL + if (_mtlTextureCache) { + CFRelease(_mtlTextureCache); + _mtlTextureCache = NULL; + } +#endif +} + +- (id)mtlDevice { + @synchronized(self) { + if (!_mtlDevice) { + _mtlDevice = MTLCreateSystemDefaultDevice(); + } + } + return _mtlDevice; +} + +- (id)mtlCommandQueue { + @synchronized(self) { + if (!_mtlCommandQueue) { + _mtlCommandQueue = [self.mtlDevice newCommandQueue]; + } + } + return _mtlCommandQueue; +} + +#if COREVIDEO_SUPPORTS_METAL +- 
(CVMetalTextureCacheRef)mtlTextureCache { + @synchronized(self) { + if (!_mtlTextureCache) { + CVReturn __unused err = + CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache); + NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d ; device %@", err, + self.mtlDevice); + // TODO: register and flush metal caches too. + } + } + return _mtlTextureCache; +} +#endif + +@end + +namespace mediapipe { + +MetalSharedResources::MetalSharedResources() { + resources_ = [[MPPMetalSharedResources alloc] init]; +} +MetalSharedResources::~MetalSharedResources() {} + +} // namespace mediapipe diff --git a/mediapipe/gpu/metal_shared_resources_test.mm b/mediapipe/gpu/metal_shared_resources_test.mm new file mode 100644 index 000000000..9eb53a9b7 --- /dev/null +++ b/mediapipe/gpu/metal_shared_resources_test.mm @@ -0,0 +1,49 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import + +#include + +#include "absl/memory/memory.h" +#include "mediapipe/framework/port/threadpool.h" + +#import "mediapipe/gpu/gpu_shared_data_internal.h" +#import "mediapipe/gpu/metal_shared_resources.h" + +@interface MPPMetalSharedResourcesTests : XCTestCase { +} +@end + +@implementation MPPMetalSharedResourcesTests + +// This test verifies that the internal Objective-C object is correctly +// released when the C++ wrapper is released. 
+- (void)testCorrectlyReleased { + __weak id metalRes = nil; + std::weak_ptr weakGpuRes; + @autoreleasepool { + auto maybeGpuRes = mediapipe::GpuResources::Create(); + XCTAssertTrue(maybeGpuRes.ok()); + weakGpuRes = *maybeGpuRes; + metalRes = (**maybeGpuRes).metal_shared().resources(); + XCTAssertNotEqual(weakGpuRes.lock(), nullptr); + XCTAssertNotNil(metalRes); + } + XCTAssertEqual(weakGpuRes.lock(), nullptr); + XCTAssertNil(metalRes); +} + +@end From fb0b96115f148c8c293f6cc3ddc7b3ed67b8043c Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 6 Dec 2022 08:33:51 -0800 Subject: [PATCH 11/64] Open up mediapipe core calculators' visibility. PiperOrigin-RevId: 493314353 --- mediapipe/calculators/core/BUILD | 88 +--------------------------- mediapipe/calculators/internal/BUILD | 6 +- 2 files changed, 4 insertions(+), 90 deletions(-) diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 3b658eb5b..29bca5fa6 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -17,12 +17,11 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) mediapipe_proto_library( name = "concatenate_vector_calculator_proto", srcs = ["concatenate_vector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -32,7 +31,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "dequantize_byte_array_calculator_proto", srcs = ["dequantize_byte_array_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -42,7 +40,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "packet_cloner_calculator_proto", srcs = ["packet_cloner_calculator.proto"], - 
visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -52,7 +49,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "packet_resampler_calculator_proto", srcs = ["packet_resampler_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -62,7 +58,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "packet_thinner_calculator_proto", srcs = ["packet_thinner_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -72,7 +67,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "split_vector_calculator_proto", srcs = ["split_vector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -82,7 +76,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "quantize_float_vector_calculator_proto", srcs = ["quantize_float_vector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -92,7 +85,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "sequence_shift_calculator_proto", srcs = ["sequence_shift_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -102,7 +94,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "gate_calculator_proto", srcs = ["gate_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -112,7 +103,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = 
"constant_side_packet_calculator_proto", srcs = ["constant_side_packet_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -124,7 +114,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "clip_vector_size_calculator_proto", srcs = ["clip_vector_size_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -134,7 +123,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "flow_limiter_calculator_proto", srcs = ["flow_limiter_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -144,7 +132,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "graph_profile_calculator_proto", srcs = ["graph_profile_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -154,7 +141,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "get_vector_item_calculator_proto", srcs = ["get_vector_item_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -164,7 +150,6 @@ mediapipe_proto_library( cc_library( name = "add_header_calculator", srcs = ["add_header_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -193,7 +178,6 @@ cc_library( name = "begin_loop_calculator", srcs = ["begin_loop_calculator.cc"], hdrs = ["begin_loop_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_contract", @@ -216,7 +200,6 @@ cc_library( name = 
"end_loop_calculator", srcs = ["end_loop_calculator.cc"], hdrs = ["end_loop_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_contract", @@ -258,7 +241,6 @@ cc_test( cc_library( name = "concatenate_vector_calculator_hdr", hdrs = ["concatenate_vector_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -284,7 +266,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework/api2:node", @@ -311,7 +292,6 @@ cc_library( cc_library( name = "concatenate_detection_vector_calculator", srcs = ["concatenate_detection_vector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":concatenate_vector_calculator", "//mediapipe/framework:calculator_framework", @@ -323,7 +303,6 @@ cc_library( cc_library( name = "concatenate_proto_list_calculator", srcs = ["concatenate_proto_list_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -372,7 +351,6 @@ cc_library( name = "clip_vector_size_calculator", srcs = ["clip_vector_size_calculator.cc"], hdrs = ["clip_vector_size_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":clip_vector_size_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -388,7 +366,6 @@ cc_library( cc_library( name = "clip_detection_vector_size_calculator", srcs = ["clip_detection_vector_size_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":clip_vector_size_calculator", "//mediapipe/framework:calculator_framework", @@ -415,9 +392,6 @@ cc_test( cc_library( name = "counting_source_calculator", srcs = ["counting_source_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ 
"//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -430,9 +404,6 @@ cc_library( cc_library( name = "make_pair_calculator", srcs = ["make_pair_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -461,9 +432,6 @@ cc_test( cc_library( name = "matrix_multiply_calculator", srcs = ["matrix_multiply_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -477,9 +445,6 @@ cc_library( cc_library( name = "matrix_subtract_calculator", srcs = ["matrix_subtract_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -493,9 +458,6 @@ cc_library( cc_library( name = "mux_calculator", srcs = ["mux_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -508,9 +470,6 @@ cc_library( cc_library( name = "non_zero_calculator", srcs = ["non_zero_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -556,9 +515,6 @@ cc_test( cc_library( name = "packet_cloner_calculator", srcs = ["packet_cloner_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ ":packet_cloner_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -587,7 +543,6 @@ cc_test( cc_library( name = "packet_inner_join_calculator", srcs = ["packet_inner_join_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -611,7 +566,6 @@ cc_test( cc_library( name = "packet_thinner_calculator", srcs = ["packet_thinner_calculator.cc"], - visibility = 
["//visibility:public"], deps = [ "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto", "//mediapipe/framework:calculator_context", @@ -643,9 +597,6 @@ cc_test( cc_library( name = "pass_through_calculator", srcs = ["pass_through_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", @@ -656,9 +607,6 @@ cc_library( cc_library( name = "round_robin_demux_calculator", srcs = ["round_robin_demux_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -670,9 +618,6 @@ cc_library( cc_library( name = "immediate_mux_calculator", srcs = ["immediate_mux_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -684,7 +629,6 @@ cc_library( cc_library( name = "packet_presence_calculator", srcs = ["packet_presence_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", @@ -713,7 +657,6 @@ cc_test( cc_library( name = "previous_loopback_calculator", srcs = ["previous_loopback_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", @@ -729,7 +672,6 @@ cc_library( cc_library( name = "flow_limiter_calculator", srcs = ["flow_limiter_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":flow_limiter_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -746,7 +688,6 @@ cc_library( cc_library( name = "string_to_int_calculator", srcs = ["string_to_int_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:integral_types", @@ -759,7 +700,6 @@ cc_library( cc_library( name = 
"default_side_packet_calculator", srcs = ["default_side_packet_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -771,7 +711,6 @@ cc_library( cc_library( name = "side_packet_to_stream_calculator", srcs = ["side_packet_to_stream_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:logging", @@ -822,9 +761,6 @@ cc_library( name = "packet_resampler_calculator", srcs = ["packet_resampler_calculator.cc"], hdrs = ["packet_resampler_calculator.h"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -884,7 +820,6 @@ cc_test( cc_test( name = "matrix_multiply_calculator_test", srcs = ["matrix_multiply_calculator_test.cc"], - visibility = ["//visibility:private"], deps = [ ":matrix_multiply_calculator", "//mediapipe/framework:calculator_framework", @@ -900,7 +835,6 @@ cc_test( cc_test( name = "matrix_subtract_calculator_test", srcs = ["matrix_subtract_calculator_test.cc"], - visibility = ["//visibility:private"], deps = [ ":matrix_subtract_calculator", "//mediapipe/framework:calculator_framework", @@ -950,7 +884,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":split_vector_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", @@ -996,7 +929,6 @@ cc_test( cc_library( name = "split_proto_list_calculator", srcs = ["split_proto_list_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":split_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1028,7 +960,6 @@ cc_test( cc_library( name = "dequantize_byte_array_calculator", srcs = ["dequantize_byte_array_calculator.cc"], - visibility = ["//visibility:public"], deps = [ 
":dequantize_byte_array_calculator_cc_proto", "//mediapipe/framework:calculator_context", @@ -1054,7 +985,6 @@ cc_test( cc_library( name = "quantize_float_vector_calculator", srcs = ["quantize_float_vector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":quantize_float_vector_calculator_cc_proto", "//mediapipe/framework:calculator_context", @@ -1080,7 +1010,6 @@ cc_test( cc_library( name = "sequence_shift_calculator", srcs = ["sequence_shift_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":sequence_shift_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1105,7 +1034,6 @@ cc_test( cc_library( name = "gate_calculator", srcs = ["gate_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":gate_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1131,7 +1059,6 @@ cc_test( cc_library( name = "matrix_to_vector_calculator", srcs = ["matrix_to_vector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1167,7 +1094,6 @@ cc_test( cc_library( name = "merge_calculator", srcs = ["merge_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1193,7 +1119,6 @@ cc_test( cc_library( name = "stream_to_side_packet_calculator", srcs = ["stream_to_side_packet_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -1219,7 +1144,6 @@ cc_test( cc_library( name = "constant_side_packet_calculator", srcs = ["constant_side_packet_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":constant_side_packet_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1249,7 +1173,6 @@ cc_test( cc_library( name = "graph_profile_calculator", srcs = ["graph_profile_calculator.cc"], - visibility = 
["//visibility:public"], deps = [ ":graph_profile_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1291,7 +1214,6 @@ cc_library( name = "get_vector_item_calculator", srcs = ["get_vector_item_calculator.cc"], hdrs = ["get_vector_item_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":get_vector_item_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1325,7 +1247,6 @@ cc_library( name = "vector_indices_calculator", srcs = ["vector_indices_calculator.cc"], hdrs = ["vector_indices_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1351,7 +1272,6 @@ cc_library( name = "vector_size_calculator", srcs = ["vector_size_calculator.cc"], hdrs = ["vector_size_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1365,9 +1285,6 @@ cc_library( cc_library( name = "packet_sequencer_calculator", srcs = ["packet_sequencer_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:contract", @@ -1402,7 +1319,6 @@ cc_library( name = "merge_to_vector_calculator", srcs = ["merge_to_vector_calculator.cc"], hdrs = ["merge_to_vector_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -1416,7 +1332,6 @@ cc_library( mediapipe_proto_library( name = "bypass_calculator_proto", srcs = ["bypass_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1426,7 +1341,6 @@ mediapipe_proto_library( cc_library( name = "bypass_calculator", srcs = ["bypass_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":bypass_calculator_cc_proto", 
"//mediapipe/framework:calculator_framework", diff --git a/mediapipe/calculators/internal/BUILD b/mediapipe/calculators/internal/BUILD index 54b6c20f1..caade2dc3 100644 --- a/mediapipe/calculators/internal/BUILD +++ b/mediapipe/calculators/internal/BUILD @@ -21,7 +21,7 @@ package(default_visibility = ["//visibility:private"]) proto_library( name = "callback_packet_calculator_proto", srcs = ["callback_packet_calculator.proto"], - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = ["//mediapipe/framework:calculator_proto"], ) @@ -29,14 +29,14 @@ mediapipe_cc_proto_library( name = "callback_packet_calculator_cc_proto", srcs = ["callback_packet_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [":callback_packet_calculator_proto"], ) cc_library( name = "callback_packet_calculator", srcs = ["callback_packet_calculator.cc"], - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [ ":callback_packet_calculator_cc_proto", "//mediapipe/framework:calculator_base", From ab0b0ab558c633bc996c41923f9325269cc76e3c Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 6 Dec 2022 10:22:31 -0800 Subject: [PATCH 12/64] Change visibility for MP Tasks Web to public PiperOrigin-RevId: 493343996 --- mediapipe/tasks/web/audio/BUILD | 1 + mediapipe/tasks/web/audio/audio_classifier/BUILD | 2 ++ mediapipe/tasks/web/audio/audio_embedder/BUILD | 2 ++ mediapipe/tasks/web/core/BUILD | 1 + mediapipe/tasks/web/text/BUILD | 1 + mediapipe/tasks/web/text/text_classifier/BUILD | 2 ++ mediapipe/tasks/web/text/text_embedder/BUILD | 2 ++ mediapipe/tasks/web/vision/BUILD | 1 + mediapipe/tasks/web/vision/gesture_recognizer/BUILD | 2 ++ mediapipe/tasks/web/vision/hand_landmarker/BUILD | 2 ++ mediapipe/tasks/web/vision/image_classifier/BUILD | 2 ++ 
mediapipe/tasks/web/vision/image_embedder/BUILD | 2 ++ mediapipe/tasks/web/vision/object_detector/BUILD | 2 ++ 13 files changed, 22 insertions(+) diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD index d08602521..9d26f1118 100644 --- a/mediapipe/tasks/web/audio/BUILD +++ b/mediapipe/tasks/web/audio/BUILD @@ -7,6 +7,7 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_ts_library( name = "audio_lib", srcs = ["index.ts"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/audio/audio_classifier", "//mediapipe/tasks/web/audio/audio_embedder", diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index 6f785dd0d..dc82a4a24 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -11,6 +11,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "audio_classifier", srcs = ["audio_classifier.ts"], + visibility = ["//visibility:public"], deps = [ ":audio_classifier_types", "//mediapipe/framework:calculator_jspb_proto", @@ -35,6 +36,7 @@ mediapipe_ts_declaration( "audio_classifier_options.d.ts", "audio_classifier_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", diff --git a/mediapipe/tasks/web/audio/audio_embedder/BUILD b/mediapipe/tasks/web/audio/audio_embedder/BUILD index 0555bb639..dc84d0cd6 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/web/audio/audio_embedder/BUILD @@ -11,6 +11,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "audio_embedder", srcs = ["audio_embedder.ts"], + visibility = ["//visibility:public"], deps = [ ":audio_embedder_types", "//mediapipe/framework:calculator_jspb_proto", @@ -35,6 +36,7 @@ mediapipe_ts_declaration( "audio_embedder_options.d.ts", "audio_embedder_result.d.ts", ], + 
visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/core", diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index de429690d..be1b71f5d 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -28,6 +28,7 @@ mediapipe_ts_library( mediapipe_ts_library( name = "fileset_resolver", srcs = ["fileset_resolver.ts"], + visibility = ["//visibility:public"], deps = [":core"], ) diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD index 159db1a0d..32f43d4b6 100644 --- a/mediapipe/tasks/web/text/BUILD +++ b/mediapipe/tasks/web/text/BUILD @@ -7,6 +7,7 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_ts_library( name = "text_lib", srcs = ["index.ts"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/core:fileset_resolver", "//mediapipe/tasks/web/text/text_classifier", diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD index 2a7de21d6..07f78ac20 100644 --- a/mediapipe/tasks/web/text/text_classifier/BUILD +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -12,6 +12,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "text_classifier", srcs = ["text_classifier.ts"], + visibility = ["//visibility:public"], deps = [ ":text_classifier_types", "//mediapipe/framework:calculator_jspb_proto", @@ -36,6 +37,7 @@ mediapipe_ts_declaration( "text_classifier_options.d.ts", "text_classifier_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD index 17d105258..7d796fb7e 100644 --- a/mediapipe/tasks/web/text/text_embedder/BUILD +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -12,6 
+12,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "text_embedder", srcs = ["text_embedder.ts"], + visibility = ["//visibility:public"], deps = [ ":text_embedder_types", "//mediapipe/framework:calculator_jspb_proto", @@ -36,6 +37,7 @@ mediapipe_ts_declaration( "text_embedder_options.d.ts", "text_embedder_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/core", diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index 42bc0a494..93493e873 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -7,6 +7,7 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_ts_library( name = "vision_lib", srcs = ["index.ts"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/core:fileset_resolver", "//mediapipe/tasks/web/vision/gesture_recognizer", diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD index ddfd1a327..6e2e56196 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -12,6 +12,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "gesture_recognizer", srcs = ["gesture_recognizer.ts"], + visibility = ["//visibility:public"], deps = [ ":gesture_recognizer_types", "//mediapipe/framework:calculator_jspb_proto", @@ -42,6 +43,7 @@ mediapipe_ts_declaration( "gesture_recognizer_options.d.ts", "gesture_recognizer_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:landmark", diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD index fc3e6ef1f..520898e34 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD +++ 
b/mediapipe/tasks/web/vision/hand_landmarker/BUILD @@ -12,6 +12,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "hand_landmarker", srcs = ["hand_landmarker.ts"], + visibility = ["//visibility:public"], deps = [ ":hand_landmarker_types", "//mediapipe/framework:calculator_jspb_proto", @@ -38,6 +39,7 @@ mediapipe_ts_declaration( "hand_landmarker_options.d.ts", "hand_landmarker_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:landmark", diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD index ebe64ecf4..848c162ae 100644 --- a/mediapipe/tasks/web/vision/image_classifier/BUILD +++ b/mediapipe/tasks/web/vision/image_classifier/BUILD @@ -11,6 +11,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "image_classifier", srcs = ["image_classifier.ts"], + visibility = ["//visibility:public"], deps = [ ":image_classifier_types", "//mediapipe/framework:calculator_jspb_proto", @@ -35,6 +36,7 @@ mediapipe_ts_declaration( "image_classifier_options.d.ts", "image_classifier_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", diff --git a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD index 2f012dc5e..6c9d80fb1 100644 --- a/mediapipe/tasks/web/vision/image_embedder/BUILD +++ b/mediapipe/tasks/web/vision/image_embedder/BUILD @@ -11,6 +11,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "image_embedder", srcs = ["image_embedder.ts"], + visibility = ["//visibility:public"], deps = [ ":image_embedder_types", "//mediapipe/framework:calculator_jspb_proto", @@ -36,6 +37,7 @@ mediapipe_ts_declaration( "image_embedder_options.d.ts", "image_embedder_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ 
"//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/core", diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index 198585258..f73790895 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -12,6 +12,7 @@ licenses(["notice"]) mediapipe_ts_library( name = "object_detector", srcs = ["object_detector.ts"], + visibility = ["//visibility:public"], deps = [ ":object_detector_types", "//mediapipe/framework:calculator_jspb_proto", @@ -32,6 +33,7 @@ mediapipe_ts_declaration( "object_detector_options.d.ts", "object_detector_result.d.ts", ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/core", From c6e6f9e0b9b35d055cd83016e468a8c30a7b153b Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 6 Dec 2022 11:05:47 -0800 Subject: [PATCH 13/64] Fix aar breakage caused by missing "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite". 
PiperOrigin-RevId: 493357585 --- .../java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index c6aba3c66..727d020a6 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -21,7 +21,6 @@ _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite", - "//mediapipe/tasks/cc/components/proto:segmenter_options_java_proto_lite", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite", @@ -43,6 +42,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", From 6deef1a5f13c4af5e38abe96f8aabbba733dcdcb Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 6 Dec 2022 
12:07:51 -0800 Subject: [PATCH 14/64] Allow specifying tag_suffix in the templated CreateModelResources method. PiperOrigin-RevId: 493375701 --- mediapipe/tasks/cc/core/model_task_graph.cc | 2 +- mediapipe/tasks/cc/core/model_task_graph.h | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/mediapipe/tasks/cc/core/model_task_graph.cc b/mediapipe/tasks/cc/core/model_task_graph.cc index 66434483b..0cb556ec2 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.cc +++ b/mediapipe/tasks/cc/core/model_task_graph.cc @@ -186,7 +186,7 @@ absl::StatusOr ModelTaskGraph::CreateModelResources( absl::StatusOr ModelTaskGraph::CreateModelAssetBundleResources( SubgraphContext* sc, std::unique_ptr external_file, - const std::string tag_suffix) { + std::string tag_suffix) { auto model_resources_cache_service = sc->Service(kModelResourcesCacheService); bool has_file_pointer_meta = external_file->has_file_pointer_meta(); // if external file is set by file pointer, no need to add the model asset diff --git a/mediapipe/tasks/cc/core/model_task_graph.h b/mediapipe/tasks/cc/core/model_task_graph.h index 50dcc903b..3068b2c46 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.h +++ b/mediapipe/tasks/cc/core/model_task_graph.h @@ -59,14 +59,16 @@ class ModelTaskGraph : public Subgraph { // creates a local model resources object that can only be used in the graph // construction stage. The returned model resources pointer will provide graph // authors with the access to the metadata extractor and the tflite model. + // If more than one model resources are created in a graph, the model + // resources graph service add the tag_suffix to support multiple resources. 
template absl::StatusOr CreateModelResources( - SubgraphContext* sc) { + SubgraphContext* sc, std::string tag_suffix = "") { auto external_file = std::make_unique(); external_file->Swap(sc->MutableOptions() ->mutable_base_options() ->mutable_model_asset()); - return CreateModelResources(sc, std::move(external_file)); + return CreateModelResources(sc, std::move(external_file), tag_suffix); } // If the model resources graph service is available, creates a model @@ -83,7 +85,7 @@ class ModelTaskGraph : public Subgraph { // resources. absl::StatusOr CreateModelResources( SubgraphContext* sc, std::unique_ptr external_file, - const std::string tag_suffix = ""); + std::string tag_suffix = ""); // If the model resources graph service is available, creates a model asset // bundle resources object from the subgraph context, and caches the created From cdc14522e2821a60ec1ee208430e364917e21985 Mon Sep 17 00:00:00 2001 From: Khanh LeViet Date: Tue, 6 Dec 2022 13:01:06 -0800 Subject: [PATCH 15/64] Added issue templates for MP Preview. PiperOrigin-RevId: 493389856 --- .../ISSUE_TEMPLATE/11-tasks-issue.md | 25 +++++++++++++++++++ .../ISSUE_TEMPLATE/12-model-maker-issue.md | 25 +++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 mediapipe/opensource_only/ISSUE_TEMPLATE/11-tasks-issue.md create mode 100644 mediapipe/opensource_only/ISSUE_TEMPLATE/12-model-maker-issue.md diff --git a/mediapipe/opensource_only/ISSUE_TEMPLATE/11-tasks-issue.md b/mediapipe/opensource_only/ISSUE_TEMPLATE/11-tasks-issue.md new file mode 100644 index 000000000..ab7b38368 --- /dev/null +++ b/mediapipe/opensource_only/ISSUE_TEMPLATE/11-tasks-issue.md @@ -0,0 +1,25 @@ +--- +name: "Tasks Issue" +about: Use this template for assistance with a specific task, such as "Gesture Recognition" or "Object Detection", including inference model usage/training etc. +labels: type:support + +--- +Please make sure that this is a [Tasks](https://developers.google.com/mediapipe/solutions) issue. 
+ +**System information** (Please provide as much relevant information as possible) +- Have I written custom code (as opposed to using a stock example script provided in MediaPipe): +- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, Android 11, iOS 14.4): +- MediaPipe Tasks SDK version: +- Task name (e.g. Object detection, Gesture recognition etc.): +- Programming Language and version ( e.g. C++, Python, Java): + +**Describe the expected behavior:** + +**Standalone code you may have used to try to get what you need :** + +If there is a problem, provide a reproducible test case that is the bare minimum necessary to generate the problem. If possible, please share a link to Colab, GitHub repo link or anything that we can use to reproduce the problem: + +**Other info / Complete Logs :** +Include any logs or source code that would be helpful to +diagnose the problem. If including tracebacks, please include the full +traceback. Large logs and files should be attached: diff --git a/mediapipe/opensource_only/ISSUE_TEMPLATE/12-model-maker-issue.md b/mediapipe/opensource_only/ISSUE_TEMPLATE/12-model-maker-issue.md new file mode 100644 index 000000000..687360957 --- /dev/null +++ b/mediapipe/opensource_only/ISSUE_TEMPLATE/12-model-maker-issue.md @@ -0,0 +1,25 @@ +--- +name: "Model Maker Issue" +about: Use this template for assistance with a specific task, such as "Gesture Recognition" or "Object Detection", including inference model usage/training etc. +labels: type:support + +--- +Please make sure that this is a [Model Maker](https://developers.google.com/mediapipe/solutions) issue. + +**System information** (Please provide as much relevant information as possible) +- Have I written custom code (as opposed to using a stock example script provided in MediaPipe): +- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): +- Python version (e.g. 3.8): +- [MediaPipe Model Maker version](https://pypi.org/project/mediapipe-model-maker/): +- Task name (e.g. 
Image classification, Gesture recognition etc.): + +**Describe the expected behavior:** + +**Standalone code you may have used to try to get what you need :** + +If there is a problem, provide a reproducible test case that is the bare minimum necessary to generate the problem. If possible, please share a link to Colab, GitHub repo link or anything that we can use to reproduce the problem: + +**Other info / Complete Logs :** +Include any logs or source code that would be helpful to +diagnose the problem. If including tracebacks, please include the full +traceback. Large logs and files should be attached: From 0f32072804a4e078c9f64ae8cb48d9b1777a679f Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 6 Dec 2022 14:01:49 -0800 Subject: [PATCH 16/64] Move ISSUE_TEMPLATAE files to .github folder PiperOrigin-RevId: 493405734 --- .../opensource_only => .github}/ISSUE_TEMPLATE/11-tasks-issue.md | 0 .../ISSUE_TEMPLATE/12-model-maker-issue.md | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename {mediapipe/opensource_only => .github}/ISSUE_TEMPLATE/11-tasks-issue.md (100%) rename {mediapipe/opensource_only => .github}/ISSUE_TEMPLATE/12-model-maker-issue.md (100%) diff --git a/mediapipe/opensource_only/ISSUE_TEMPLATE/11-tasks-issue.md b/.github/ISSUE_TEMPLATE/11-tasks-issue.md similarity index 100% rename from mediapipe/opensource_only/ISSUE_TEMPLATE/11-tasks-issue.md rename to .github/ISSUE_TEMPLATE/11-tasks-issue.md diff --git a/mediapipe/opensource_only/ISSUE_TEMPLATE/12-model-maker-issue.md b/.github/ISSUE_TEMPLATE/12-model-maker-issue.md similarity index 100% rename from mediapipe/opensource_only/ISSUE_TEMPLATE/12-model-maker-issue.md rename to .github/ISSUE_TEMPLATE/12-model-maker-issue.md From 9bc7b120de85d4991292d831bd844264c783350b Mon Sep 17 00:00:00 2001 From: Khanh LeViet Date: Tue, 6 Dec 2022 15:12:25 -0800 Subject: [PATCH 17/64] Tweaked the issue templates. 
PiperOrigin-RevId: 493424927 --- .github/ISSUE_TEMPLATE/11-tasks-issue.md | 2 +- .github/ISSUE_TEMPLATE/12-model-maker-issue.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/11-tasks-issue.md b/.github/ISSUE_TEMPLATE/11-tasks-issue.md index ab7b38368..264371120 100644 --- a/.github/ISSUE_TEMPLATE/11-tasks-issue.md +++ b/.github/ISSUE_TEMPLATE/11-tasks-issue.md @@ -1,6 +1,6 @@ --- name: "Tasks Issue" -about: Use this template for assistance with a specific task, such as "Gesture Recognition" or "Object Detection", including inference model usage/training etc. +about: Use this template for assistance with using MediaPipe Tasks to deploy on-device ML solutions (e.g. gesture recognition etc.) on supported platforms. labels: type:support --- diff --git a/.github/ISSUE_TEMPLATE/12-model-maker-issue.md b/.github/ISSUE_TEMPLATE/12-model-maker-issue.md index 687360957..258390d5e 100644 --- a/.github/ISSUE_TEMPLATE/12-model-maker-issue.md +++ b/.github/ISSUE_TEMPLATE/12-model-maker-issue.md @@ -1,6 +1,6 @@ --- name: "Model Maker Issue" -about: Use this template for assistance with a specific task, such as "Gesture Recognition" or "Object Detection", including inference model usage/training etc. +about: Use this template for assistance with using MediaPipe Model Maker to create custom on-device ML solutions. 
labels: type:support --- From fca0f5806b470a47a3c74a7085d32c32a12d61f1 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 6 Dec 2022 15:16:42 -0800 Subject: [PATCH 18/64] Create Build Rules for Apple Frameworks PiperOrigin-RevId: 493426040 --- mediapipe/examples/ios/common/BUILD | 10 ++-- mediapipe/examples/ios/faceeffect/BUILD | 10 ++-- mediapipe/gpu/BUILD | 64 ++++++++-------------- mediapipe/objc/BUILD | 68 ++++++++++------------- third_party/apple_frameworks/BUILD | 73 +++++++++++++++++++++++++ 5 files changed, 134 insertions(+), 91 deletions(-) create mode 100644 third_party/apple_frameworks/BUILD diff --git a/mediapipe/examples/ios/common/BUILD b/mediapipe/examples/ios/common/BUILD index 9b8f8a968..bfa770cec 100644 --- a/mediapipe/examples/ios/common/BUILD +++ b/mediapipe/examples/ios/common/BUILD @@ -29,12 +29,6 @@ objc_library( "Base.lproj/LaunchScreen.storyboard", "Base.lproj/Main.storyboard", ], - sdk_frameworks = [ - "AVFoundation", - "CoreGraphics", - "CoreMedia", - "UIKit", - ], visibility = [ "//mediapipe:__subpackages__", ], @@ -42,6 +36,10 @@ objc_library( "//mediapipe/objc:mediapipe_framework_ios", "//mediapipe/objc:mediapipe_input_sources_ios", "//mediapipe/objc:mediapipe_layer_renderer", + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreMedia", + "//third_party/apple_frameworks:UIKit", ], ) diff --git a/mediapipe/examples/ios/faceeffect/BUILD b/mediapipe/examples/ios/faceeffect/BUILD index 50a6f68bd..e0c3abb86 100644 --- a/mediapipe/examples/ios/faceeffect/BUILD +++ b/mediapipe/examples/ios/faceeffect/BUILD @@ -73,13 +73,11 @@ objc_library( "//mediapipe/modules/face_landmark:face_landmark.tflite", ], features = ["-layering_check"], - sdk_frameworks = [ - "AVFoundation", - "CoreGraphics", - "CoreMedia", - "UIKit", - ], deps = [ + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:CoreGraphics", + 
"//third_party/apple_frameworks:CoreMedia", + "//third_party/apple_frameworks:UIKit", "//mediapipe/objc:mediapipe_framework_ios", "//mediapipe/objc:mediapipe_input_sources_ios", "//mediapipe/objc:mediapipe_layer_renderer", diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 7a8aa6557..f5cb9f715 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -472,13 +472,13 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "Accelerate", - "CoreGraphics", - "CoreVideo", - ], visibility = ["//visibility:public"], - deps = ["//mediapipe/objc:util"], + deps = [ + "//mediapipe/objc:util", + "//third_party/apple_frameworks:Accelerate", + "//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreVideo", + ], ) objc_library( @@ -510,13 +510,11 @@ objc_library( "-x objective-c++", "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", "@com_google_absl//absl/time", "@google_toolbox_for_mac//:GTM_Defines", ], @@ -808,15 +806,13 @@ objc_library( "-Wno-shorten-64-to-32", ], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":gpu_shared_data_internal", ":graph_support", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", "@google_toolbox_for_mac//:GTM_Defines", ], ) @@ -1020,16 +1016,14 @@ objc_library( name = "metal_copy_calculator", srcs = ["MetalCopyCalculator.mm"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", ":simple_shaders_mtl", "//mediapipe/gpu:copy_calculator_cc_proto", "//mediapipe/objc:mediapipe_framework_ios", + 
"//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", ], alwayslink = 1, ) @@ -1038,15 +1032,13 @@ objc_library( name = "metal_rgb_weight_calculator", srcs = ["MetalRgbWeightCalculator.mm"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", ":simple_shaders_mtl", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", ], alwayslink = 1, ) @@ -1055,15 +1047,13 @@ objc_library( name = "metal_sobel_calculator", srcs = ["MetalSobelCalculator.mm"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", ":simple_shaders_mtl", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", ], alwayslink = 1, ) @@ -1072,15 +1062,13 @@ objc_library( name = "metal_sobel_compute_calculator", srcs = ["MetalSobelComputeCalculator.mm"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", ":simple_shaders_mtl", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", ], alwayslink = 1, ) @@ -1090,15 +1078,13 @@ objc_library( srcs = ["MPSSobelCalculator.mm"], copts = ["-std=c++17"], features = ["-layering_check"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - "MetalPerformanceShaders", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", + "//third_party/apple_frameworks:MetalPerformanceShaders", ], alwayslink = 1, ) @@ -1106,15 +1092,13 @@ objc_library( objc_library( name = 
"mps_threshold_calculator", srcs = ["MPSThresholdCalculator.mm"], - sdk_frameworks = [ - "CoreVideo", - "Metal", - "MetalPerformanceShaders", - ], visibility = ["//visibility:public"], deps = [ ":MPPMetalHelper", "//mediapipe/objc:mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Metal", + "//third_party/apple_frameworks:MetalPerformanceShaders", ], alwayslink = 1, ) diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD index fafdfee8a..c71c02b6d 100644 --- a/mediapipe/objc/BUILD +++ b/mediapipe/objc/BUILD @@ -68,7 +68,6 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = ["Accelerate"], # This build rule is public to allow external customers to build their own iOS apps. visibility = ["//visibility:public"], deps = [ @@ -90,6 +89,7 @@ objc_library( "//mediapipe/gpu:metal_shared_resources", "//mediapipe/gpu:pixel_buffer_pool_util", "//mediapipe/util:cpu_util", + "//third_party/apple_frameworks:Accelerate", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -120,13 +120,13 @@ objc_library( ], "//conditions:default": [], }), - sdk_frameworks = [ - "AVFoundation", - "CoreVideo", - "Foundation", - ], # This build rule is public to allow external customers to build their own iOS apps. 
visibility = ["//visibility:public"], + deps = [ + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Foundation", + ], ) objc_library( @@ -140,16 +140,14 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "Foundation", - "GLKit", - ], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":mediapipe_framework_ios", "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_quad_renderer", "//mediapipe/gpu:gl_simple_shaders", + "//third_party/apple_frameworks:Foundation", + "//third_party/apple_frameworks:GLKit", ], ) @@ -164,16 +162,14 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "Foundation", - "GLKit", - ], # This build rule is public to allow external customers to build their own iOS apps. visibility = ["//visibility:public"], deps = [ ":mediapipe_framework_ios", ":mediapipe_gl_view_renderer", "//mediapipe/gpu:gl_calculator_helper", + "//third_party/apple_frameworks:Foundation", + "//third_party/apple_frameworks:GLKit", ], ) @@ -188,13 +184,11 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "CoreVideo", - "Foundation", - ], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":mediapipe_framework_ios", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:Foundation", "@com_google_absl//absl/strings", ], ) @@ -211,23 +205,21 @@ objc_library( copts = [ "-Wno-shorten-64-to-32", ], - sdk_frameworks = [ - "AVFoundation", - "Accelerate", - "CoreGraphics", - "CoreMedia", - "CoreVideo", - "GLKit", - "OpenGLES", - "QuartzCore", - "UIKit", - ], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":CGImageRefUtils", ":Weakify", ":mediapipe_framework_ios", "//mediapipe/framework:calculator_framework", + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:Accelerate", + 
"//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreMedia", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:GLKit", + "//third_party/apple_frameworks:OpenGLES", + "//third_party/apple_frameworks:QuartzCore", + "//third_party/apple_frameworks:UIKit", ], ) @@ -245,16 +237,6 @@ objc_library( data = [ "testdata/googlelogo_color_272x92dp.png", ], - sdk_frameworks = [ - "AVFoundation", - "Accelerate", - "CoreGraphics", - "CoreMedia", - "CoreVideo", - "GLKit", - "QuartzCore", - "UIKit", - ], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":CGImageRefUtils", @@ -263,6 +245,14 @@ objc_library( ":mediapipe_framework_ios", ":mediapipe_input_sources_ios", "//mediapipe/calculators/core:pass_through_calculator", + "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:Accelerate", + "//third_party/apple_frameworks:CoreGraphics", + "//third_party/apple_frameworks:CoreMedia", + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:GLKit", + "//third_party/apple_frameworks:QuartzCore", + "//third_party/apple_frameworks:UIKit", ], ) diff --git a/third_party/apple_frameworks/BUILD b/third_party/apple_frameworks/BUILD new file mode 100644 index 000000000..05f830e81 --- /dev/null +++ b/third_party/apple_frameworks/BUILD @@ -0,0 +1,73 @@ +# Build rules to inject Apple Frameworks + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "CoreGraphics", + linkopts = ["-framework CoreGraphics"], +) + +cc_library( + name = "CoreMedia", + linkopts = ["-framework CoreMedia"], +) + +cc_library( + name = "UIKit", + linkopts = ["-framework UIKit"], +) + +cc_library( + name = "Accelerate", + linkopts = ["-framework Accelerate"], +) + +cc_library( + name = "CoreVideo", + linkopts = ["-framework CoreVideo"], +) + +cc_library( + name = "Metal", + linkopts = ["-framework Metal"], +) + +cc_library( + name = "MetalPerformanceShaders", + linkopts 
= ["-framework MetalPerformanceShaders"], +) + +cc_library( + name = "AVFoundation", + linkopts = ["-framework AVFoundation"], +) + +cc_library( + name = "Foundation", + linkopts = ["-framework Foundation"], +) + +cc_library( + name = "CoreImage", + linkopts = ["-framework CoreImage"], +) + +cc_library( + name = "XCTest", + linkopts = ["-framework XCTest"], +) + +cc_library( + name = "GLKit", + linkopts = ["-framework GLKit"], +) + +cc_library( + name = "OpenGLES", + linkopts = ["-framework OpenGLES"], +) + +cc_library( + name = "QuartzCore", + linkopts = ["-framework QuartzCore"], +) From 576c6da173c4b84b13787c0e6926acab05118880 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 6 Dec 2022 15:22:03 -0800 Subject: [PATCH 19/64] Internal change PiperOrigin-RevId: 493427500 --- mediapipe/tasks/python/audio/BUILD | 4 +- .../tasks/python/audio/audio_classifier.py | 34 ++++++++-- .../tasks/python/audio/audio_embedder.py | 21 ++++-- .../tasks/python/components/processors/BUILD | 9 --- .../python/components/processors/__init__.py | 3 - .../components/processors/embedder_options.py | 68 ------------------- mediapipe/tasks/python/components/utils/BUILD | 5 +- .../components/utils/cosine_similarity.py | 2 - mediapipe/tasks/python/test/audio/BUILD | 2 - .../test/audio/audio_classifier_test.py | 20 ++---- .../python/test/audio/audio_embedder_test.py | 10 +-- mediapipe/tasks/python/test/text/BUILD | 2 - .../python/test/text/text_classifier_test.py | 2 - .../python/test/text/text_embedder_test.py | 10 +-- mediapipe/tasks/python/test/vision/BUILD | 2 - .../test/vision/image_classifier_test.py | 52 +++++--------- .../python/test/vision/image_embedder_test.py | 10 +-- mediapipe/tasks/python/text/BUILD | 4 +- .../tasks/python/text/text_classifier.py | 35 ++++++++-- mediapipe/tasks/python/text/text_embedder.py | 20 ++++-- mediapipe/tasks/python/vision/BUILD | 4 +- .../tasks/python/vision/image_classifier.py | 35 ++++++++-- .../tasks/python/vision/image_embedder.py | 20 ++++-- 
23 files changed, 162 insertions(+), 212 deletions(-) delete mode 100644 mediapipe/tasks/python/components/processors/embedder_options.py diff --git a/mediapipe/tasks/python/audio/BUILD b/mediapipe/tasks/python/audio/BUILD index 2e5815ff0..ce7c5ce08 100644 --- a/mediapipe/tasks/python/audio/BUILD +++ b/mediapipe/tasks/python/audio/BUILD @@ -29,11 +29,11 @@ py_library( "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_py_pb2", "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2", "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/audio/core:base_audio_task_api", "//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:classification_result", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", @@ -51,11 +51,11 @@ py_library( "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_py_pb2", "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2", "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/audio/core:base_audio_task_api", "//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/components/utils:cosine_similarity", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", diff --git a/mediapipe/tasks/python/audio/audio_classifier.py 
b/mediapipe/tasks/python/audio/audio_classifier.py index d82b6fe27..cc87d6221 100644 --- a/mediapipe/tasks/python/audio/audio_classifier.py +++ b/mediapipe/tasks/python/audio/audio_classifier.py @@ -21,11 +21,11 @@ from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import packet from mediapipe.tasks.cc.audio.audio_classifier.proto import audio_classifier_graph_options_pb2 from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module from mediapipe.tasks.python.audio.core import base_audio_task_api from mediapipe.tasks.python.components.containers import audio_data as audio_data_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module -from mediapipe.tasks.python.components.processors import classifier_options as classifier_options_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -34,7 +34,7 @@ AudioClassifierResult = classification_result_module.ClassificationResult _AudioClassifierGraphOptionsProto = audio_classifier_graph_options_pb2.AudioClassifierGraphOptions _AudioData = audio_data_module.AudioData _BaseOptions = base_options_module.BaseOptions -_ClassifierOptions = classifier_options_module.ClassifierOptions +_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions _RunningMode = running_mode_module.AudioTaskRunningMode _TaskInfo = task_info_module.TaskInfo @@ -62,16 +62,31 @@ class AudioClassifierOptions: mode for running classification on the audio stream, such as from microphone. 
In this mode, the "result_callback" below must be specified to receive the classification results asynchronously. - classifier_options: Options for configuring the classifier behavior, such as - score threshold, number of results, etc. + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + max_results: The maximum number of top-scored classification results to + return. + score_threshold: Overrides the ones provided in the model metadata. Results + below this value are rejected. + category_allowlist: Allowlist of category names. If non-empty, + classification results whose category name is not in this set will be + filtered out. Duplicate or unknown category names are ignored. Mutually + exclusive with `category_denylist`. + category_denylist: Denylist of category names. If non-empty, classification + results whose category name is in this set will be filtered out. Duplicate + or unknown category names are ignored. Mutually exclusive with + `category_allowlist`. result_callback: The user-defined result callback for processing audio stream data. The result callback should only be specified when the running mode is set to the audio stream mode. 
""" base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS - classifier_options: Optional[_ClassifierOptions] = dataclasses.field( - default_factory=_ClassifierOptions) + display_names_locale: Optional[str] = None + max_results: Optional[int] = None + score_threshold: Optional[float] = None + category_allowlist: Optional[List[str]] = None + category_denylist: Optional[List[str]] = None result_callback: Optional[Callable[[AudioClassifierResult, int], None]] = None @doc_controls.do_not_generate_docs @@ -79,7 +94,12 @@ class AudioClassifierOptions: """Generates an AudioClassifierOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True - classifier_options_proto = self.classifier_options.to_pb2() + classifier_options_proto = _ClassifierOptionsProto( + score_threshold=self.score_threshold, + category_allowlist=self.category_allowlist, + category_denylist=self.category_denylist, + display_names_locale=self.display_names_locale, + max_results=self.max_results) return _AudioClassifierGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/python/audio/audio_embedder.py b/mediapipe/tasks/python/audio/audio_embedder.py index 629e21882..4c37783e9 100644 --- a/mediapipe/tasks/python/audio/audio_embedder.py +++ b/mediapipe/tasks/python/audio/audio_embedder.py @@ -21,11 +21,11 @@ from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import packet from mediapipe.tasks.cc.audio.audio_embedder.proto import audio_embedder_graph_options_pb2 from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 +from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2 from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module from mediapipe.tasks.python.audio.core import base_audio_task_api from 
mediapipe.tasks.python.components.containers import audio_data as audio_data_module from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module -from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module from mediapipe.tasks.python.components.utils import cosine_similarity from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module @@ -35,7 +35,7 @@ AudioEmbedderResult = embedding_result_module.EmbeddingResult _AudioEmbedderGraphOptionsProto = audio_embedder_graph_options_pb2.AudioEmbedderGraphOptions _AudioData = audio_data_module.AudioData _BaseOptions = base_options_module.BaseOptions -_EmbedderOptions = embedder_options_module.EmbedderOptions +_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions _RunningMode = running_mode_module.AudioTaskRunningMode _TaskInfo = task_info_module.TaskInfo @@ -63,16 +63,22 @@ class AudioEmbedderOptions: stream mode for running embedding extraction on the audio stream, such as from microphone. In this mode, the "result_callback" below must be specified to receive the embedding results asynchronously. - embedder_options: Options for configuring the embedder behavior, such as - l2_normalize and quantize. + l2_normalize: Whether to normalize the returned feature vector with L2 norm. + Use this option only if the model does not already contain a native + L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and + L2 norm is thus achieved through TF Lite inference. + quantize: Whether the returned embedding should be quantized to bytes via + scalar quantization. Embeddings are implicitly assumed to be unit-norm and + therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + the l2_normalize option if this is not the case. result_callback: The user-defined result callback for processing audio stream data. 
The result callback should only be specified when the running mode is set to the audio stream mode. """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS - embedder_options: Optional[_EmbedderOptions] = dataclasses.field( - default_factory=_EmbedderOptions) + l2_normalize: Optional[bool] = None + quantize: Optional[bool] = None result_callback: Optional[Callable[[AudioEmbedderResult, int], None]] = None @doc_controls.do_not_generate_docs @@ -80,7 +86,8 @@ class AudioEmbedderOptions: """Generates an AudioEmbedderOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True - embedder_options_proto = self.embedder_options.to_pb2() + embedder_options_proto = _EmbedderOptionsProto( + l2_normalize=self.l2_normalize, quantize=self.quantize) return _AudioEmbedderGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/python/components/processors/BUILD b/mediapipe/tasks/python/components/processors/BUILD index eef368db0..f87a579b0 100644 --- a/mediapipe/tasks/python/components/processors/BUILD +++ b/mediapipe/tasks/python/components/processors/BUILD @@ -28,12 +28,3 @@ py_library( "//mediapipe/tasks/python/core:optional_dependencies", ], ) - -py_library( - name = "embedder_options", - srcs = ["embedder_options.py"], - deps = [ - "//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2", - "//mediapipe/tasks/python/core:optional_dependencies", - ], -) diff --git a/mediapipe/tasks/python/components/processors/__init__.py b/mediapipe/tasks/python/components/processors/__init__.py index adcb38757..0eb73abe0 100644 --- a/mediapipe/tasks/python/components/processors/__init__.py +++ b/mediapipe/tasks/python/components/processors/__init__.py @@ -15,12 +15,9 @@ """MediaPipe Tasks Components Processors API.""" import mediapipe.tasks.python.components.processors.classifier_options -import 
mediapipe.tasks.python.components.processors.embedder_options ClassifierOptions = classifier_options.ClassifierOptions -EmbedderOptions = embedder_options.EmbedderOptions # Remove unnecessary modules to avoid duplication in API docs. del classifier_options -del embedder_options del mediapipe diff --git a/mediapipe/tasks/python/components/processors/embedder_options.py b/mediapipe/tasks/python/components/processors/embedder_options.py deleted file mode 100644 index c86a91105..000000000 --- a/mediapipe/tasks/python/components/processors/embedder_options.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Embedder options data class.""" - -import dataclasses -from typing import Any, Optional - -from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2 -from mediapipe.tasks.python.core.optional_dependencies import doc_controls - -_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions - - -@dataclasses.dataclass -class EmbedderOptions: - """Shared options used by all embedding extraction tasks. - - Attributes: - l2_normalize: Whether to normalize the returned feature vector with L2 norm. - Use this option only if the model does not already contain a native - L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and - L2 norm is thus achieved through TF Lite inference. 
- quantize: Whether the returned embedding should be quantized to bytes via - scalar quantization. Embeddings are implicitly assumed to be unit-norm and - therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use - the l2_normalize option if this is not the case. - """ - - l2_normalize: Optional[bool] = None - quantize: Optional[bool] = None - - @doc_controls.do_not_generate_docs - def to_pb2(self) -> _EmbedderOptionsProto: - """Generates a EmbedderOptions protobuf object.""" - return _EmbedderOptionsProto( - l2_normalize=self.l2_normalize, quantize=self.quantize) - - @classmethod - @doc_controls.do_not_generate_docs - def create_from_pb2(cls, pb2_obj: _EmbedderOptionsProto) -> 'EmbedderOptions': - """Creates a `EmbedderOptions` object from the given protobuf object.""" - return EmbedderOptions( - l2_normalize=pb2_obj.l2_normalize, quantize=pb2_obj.quantize) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - - Args: - other: The object to be compared with. - - Returns: - True if the objects are equal. 
- """ - if not isinstance(other, EmbedderOptions): - return False - - return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/components/utils/BUILD b/mediapipe/tasks/python/components/utils/BUILD index b64d04c72..31114f326 100644 --- a/mediapipe/tasks/python/components/utils/BUILD +++ b/mediapipe/tasks/python/components/utils/BUILD @@ -23,8 +23,5 @@ licenses(["notice"]) py_library( name = "cosine_similarity", srcs = ["cosine_similarity.py"], - deps = [ - "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", - ], + deps = ["//mediapipe/tasks/python/components/containers:embedding_result"], ) diff --git a/mediapipe/tasks/python/components/utils/cosine_similarity.py b/mediapipe/tasks/python/components/utils/cosine_similarity.py index 486c02ece..ff8979458 100644 --- a/mediapipe/tasks/python/components/utils/cosine_similarity.py +++ b/mediapipe/tasks/python/components/utils/cosine_similarity.py @@ -16,10 +16,8 @@ import numpy as np from mediapipe.tasks.python.components.containers import embedding_result -from mediapipe.tasks.python.components.processors import embedder_options _Embedding = embedding_result.Embedding -_EmbedderOptions = embedder_options.EmbedderOptions def _compute_cosine_similarity(u, v): diff --git a/mediapipe/tasks/python/test/audio/BUILD b/mediapipe/tasks/python/test/audio/BUILD index 9278cea55..43f1d417c 100644 --- a/mediapipe/tasks/python/test/audio/BUILD +++ b/mediapipe/tasks/python/test/audio/BUILD @@ -30,7 +30,6 @@ py_test( "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:classification_result", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", ], @@ -48,7 +47,6 @@ py_test( 
"//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", ], diff --git a/mediapipe/tasks/python/test/audio/audio_classifier_test.py b/mediapipe/tasks/python/test/audio/audio_classifier_test.py index 0d067e587..75146547c 100644 --- a/mediapipe/tasks/python/test/audio/audio_classifier_test.py +++ b/mediapipe/tasks/python/test/audio/audio_classifier_test.py @@ -27,7 +27,6 @@ from mediapipe.tasks.python.audio import audio_classifier from mediapipe.tasks.python.audio.core import audio_task_running_mode from mediapipe.tasks.python.components.containers import audio_data as audio_data_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils @@ -36,7 +35,6 @@ _AudioClassifierOptions = audio_classifier.AudioClassifierOptions _AudioClassifierResult = classification_result_module.ClassificationResult _AudioData = audio_data_module.AudioData _BaseOptions = base_options_module.BaseOptions -_ClassifierOptions = classifier_options.ClassifierOptions _RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode _YAMNET_MODEL_FILE = 'yamnet_audio_classifier_with_metadata.tflite' @@ -210,8 +208,7 @@ class AudioClassifierTest(parameterized.TestCase): with _AudioClassifier.create_from_options( _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), - classifier_options=_ClassifierOptions( - max_results=1))) as classifier: + max_results=1)) as classifier: for audio_file in [_SPEECH_WAV_16K_MONO, 
_SPEECH_WAV_16K_MONO]: classification_result_list = classifier.classify( self._read_wav_file(audio_file)) @@ -222,8 +219,7 @@ class AudioClassifierTest(parameterized.TestCase): with _AudioClassifier.create_from_options( _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), - classifier_options=_ClassifierOptions( - score_threshold=0.9))) as classifier: + score_threshold=0.9)) as classifier: for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]: classification_result_list = classifier.classify( self._read_wav_file(audio_file)) @@ -234,8 +230,7 @@ class AudioClassifierTest(parameterized.TestCase): with _AudioClassifier.create_from_options( _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), - classifier_options=_ClassifierOptions( - category_allowlist=['Speech']))) as classifier: + category_allowlist=['Speech'])) as classifier: for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]: classification_result_list = classifier.classify( self._read_wav_file(audio_file)) @@ -250,8 +245,8 @@ class AudioClassifierTest(parameterized.TestCase): r'exclusive options.'): options = _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), - classifier_options=_ClassifierOptions( - category_allowlist=['foo'], category_denylist=['bar'])) + category_allowlist=['foo'], + category_denylist=['bar']) with _AudioClassifier.create_from_options(options) as unused_classifier: pass @@ -278,8 +273,7 @@ class AudioClassifierTest(parameterized.TestCase): _AudioClassifierOptions( base_options=_BaseOptions( model_asset_path=self.two_heads_model_path), - classifier_options=_ClassifierOptions( - max_results=1))) as classifier: + max_results=1)) as classifier: for audio_file in [_TWO_HEADS_WAV_16K_MONO, _TWO_HEADS_WAV_44K_MONO]: classification_result_list = classifier.classify( self._read_wav_file(audio_file)) @@ -364,7 +358,7 @@ class 
AudioClassifierTest(parameterized.TestCase): options = _AudioClassifierOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), running_mode=_RUNNING_MODE.AUDIO_STREAM, - classifier_options=_ClassifierOptions(max_results=1), + max_results=1, result_callback=save_result) classifier = _AudioClassifier.create_from_options(options) audio_data_list = self._read_wav_file_as_stream(audio_file) diff --git a/mediapipe/tasks/python/test/audio/audio_embedder_test.py b/mediapipe/tasks/python/test/audio/audio_embedder_test.py index 2e38ea2ee..f280235d7 100644 --- a/mediapipe/tasks/python/test/audio/audio_embedder_test.py +++ b/mediapipe/tasks/python/test/audio/audio_embedder_test.py @@ -26,7 +26,6 @@ from scipy.io import wavfile from mediapipe.tasks.python.audio import audio_embedder from mediapipe.tasks.python.audio.core import audio_task_running_mode from mediapipe.tasks.python.components.containers import audio_data as audio_data_module -from mediapipe.tasks.python.components.processors import embedder_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils @@ -35,7 +34,6 @@ _AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions _AudioEmbedderResult = audio_embedder.AudioEmbedderResult _AudioData = audio_data_module.AudioData _BaseOptions = base_options_module.BaseOptions -_EmbedderOptions = embedder_options.EmbedderOptions _RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode _YAMNET_MODEL_FILE = 'yamnet_embedding_metadata.tflite' @@ -172,9 +170,7 @@ class AudioEmbedderTest(parameterized.TestCase): raise ValueError('model_file_type is invalid.') options = _AudioEmbedderOptions( - base_options=base_options, - embedder_options=_EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize)) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) with _AudioEmbedder.create_from_options(options) as embedder: embedding_result0_list = 
embedder.embed(self._read_wav_file(audio_file0)) @@ -291,8 +287,8 @@ class AudioEmbedderTest(parameterized.TestCase): options = _AudioEmbedderOptions( base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), running_mode=_RUNNING_MODE.AUDIO_STREAM, - embedder_options=_EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize), + l2_normalize=l2_normalize, + quantize=quantize, result_callback=save_result) with _AudioEmbedder.create_from_options(options) as embedder: diff --git a/mediapipe/tasks/python/test/text/BUILD b/mediapipe/tasks/python/test/text/BUILD index 38e56bdb2..0e2b06012 100644 --- a/mediapipe/tasks/python/test/text/BUILD +++ b/mediapipe/tasks/python/test/text/BUILD @@ -28,7 +28,6 @@ py_test( deps = [ "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:classification_result", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/text:text_classifier", @@ -44,7 +43,6 @@ py_test( ], deps = [ "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/text:text_embedder", diff --git a/mediapipe/tasks/python/test/text/text_classifier_test.py b/mediapipe/tasks/python/test/text/text_classifier_test.py index 8678d2194..8df7dce86 100644 --- a/mediapipe/tasks/python/test/text/text_classifier_test.py +++ b/mediapipe/tasks/python/test/text/text_classifier_test.py @@ -21,14 +21,12 @@ from absl.testing import parameterized from mediapipe.tasks.python.components.containers import category from mediapipe.tasks.python.components.containers import classification_result as classification_result_module -from mediapipe.tasks.python.components.processors import classifier_options from 
mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.text import text_classifier TextClassifierResult = classification_result_module.ClassificationResult _BaseOptions = base_options_module.BaseOptions -_ClassifierOptions = classifier_options.ClassifierOptions _Category = category.Category _Classifications = classification_result_module.Classifications _TextClassifier = text_classifier.TextClassifier diff --git a/mediapipe/tasks/python/test/text/text_embedder_test.py b/mediapipe/tasks/python/test/text/text_embedder_test.py index c9090026c..1346ba373 100644 --- a/mediapipe/tasks/python/test/text/text_embedder_test.py +++ b/mediapipe/tasks/python/test/text/text_embedder_test.py @@ -21,13 +21,11 @@ from absl.testing import parameterized import numpy as np from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module -from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.text import text_embedder _BaseOptions = base_options_module.BaseOptions -_EmbedderOptions = embedder_options_module.EmbedderOptions _Embedding = embedding_result_module.Embedding _TextEmbedder = text_embedder.TextEmbedder _TextEmbedderOptions = text_embedder.TextEmbedderOptions @@ -128,10 +126,8 @@ class TextEmbedderTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - embedder_options = _EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize) options = _TextEmbedderOptions( - base_options=base_options, embedder_options=embedder_options) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) embedder = _TextEmbedder.create_from_options(options) # Extracts both embeddings. 
@@ -178,10 +174,8 @@ class TextEmbedderTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - embedder_options = _EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize) options = _TextEmbedderOptions( - base_options=base_options, embedder_options=embedder_options) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) with _TextEmbedder.create_from_options(options) as embedder: # Extracts both embeddings. positive_text0 = "it's a charming and often affecting journey" diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 066107421..48ecc30b3 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -49,7 +49,6 @@ py_test( "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/components/containers:rect", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:image_classifier", @@ -69,7 +68,6 @@ py_test( "//mediapipe/python:_framework_bindings", "//mediapipe/tasks/python/components/containers:embedding_result", "//mediapipe/tasks/python/components/containers:rect", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:image_embedder", diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 77f16278f..cbeaf36bd 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -26,7 +26,6 @@ from mediapipe.python._framework_bindings import image from 
mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module from mediapipe.tasks.python.components.containers import rect -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import image_classifier @@ -36,7 +35,6 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode ImageClassifierResult = classification_result_module.ClassificationResult _Rect = rect.Rect _BaseOptions = base_options_module.BaseOptions -_ClassifierOptions = classifier_options.ClassifierOptions _Category = category_module.Category _Classifications = classification_result_module.Classifications _Image = image.Image @@ -171,9 +169,8 @@ class ImageClassifierTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - custom_classifier_options = _ClassifierOptions(max_results=max_results) options = _ImageClassifierOptions( - base_options=base_options, classifier_options=custom_classifier_options) + base_options=base_options, max_results=max_results) classifier = _ImageClassifier.create_from_options(options) # Performs image classification on the input. @@ -200,9 +197,8 @@ class ImageClassifierTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - custom_classifier_options = _ClassifierOptions(max_results=max_results) options = _ImageClassifierOptions( - base_options=base_options, classifier_options=custom_classifier_options) + base_options=base_options, max_results=max_results) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. 
image_result = classifier.classify(self.test_image) @@ -212,9 +208,7 @@ class ImageClassifierTest(parameterized.TestCase): def test_classify_succeeds_with_region_of_interest(self): base_options = _BaseOptions(model_asset_path=self.model_path) - custom_classifier_options = _ClassifierOptions(max_results=1) - options = _ImageClassifierOptions( - base_options=base_options, classifier_options=custom_classifier_options) + options = _ImageClassifierOptions(base_options=base_options, max_results=1) with _ImageClassifier.create_from_options(options) as classifier: # Load the test image. test_image = _Image.create_from_file( @@ -230,11 +224,9 @@ class ImageClassifierTest(parameterized.TestCase): _generate_soccer_ball_results().to_pb2()) def test_score_threshold_option(self): - custom_classifier_options = _ClassifierOptions( - score_threshold=_SCORE_THRESHOLD) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + score_threshold=_SCORE_THRESHOLD) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -249,11 +241,9 @@ class ImageClassifierTest(parameterized.TestCase): f'{classification}') def test_max_results_option(self): - custom_classifier_options = _ClassifierOptions( - score_threshold=_SCORE_THRESHOLD) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + score_threshold=_SCORE_THRESHOLD) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. 
image_result = classifier.classify(self.test_image) @@ -263,11 +253,9 @@ class ImageClassifierTest(parameterized.TestCase): len(categories), _MAX_RESULTS, 'Too many results returned.') def test_allow_list_option(self): - custom_classifier_options = _ClassifierOptions( - category_allowlist=_ALLOW_LIST) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + category_allowlist=_ALLOW_LIST) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -280,10 +268,9 @@ class ImageClassifierTest(parameterized.TestCase): f'Label {label} found but not in label allow list') def test_deny_list_option(self): - custom_classifier_options = _ClassifierOptions(category_denylist=_DENY_LIST) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + category_denylist=_DENY_LIST) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. 
image_result = classifier.classify(self.test_image) @@ -301,19 +288,17 @@ class ImageClassifierTest(parameterized.TestCase): ValueError, r'`category_allowlist` and `category_denylist` are mutually ' r'exclusive options.'): - custom_classifier_options = _ClassifierOptions( - category_allowlist=['foo'], category_denylist=['bar']) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + category_allowlist=['foo'], + category_denylist=['bar']) with _ImageClassifier.create_from_options(options) as unused_classifier: pass def test_empty_classification_outputs(self): - custom_classifier_options = _ClassifierOptions(score_threshold=1) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - classifier_options=custom_classifier_options) + score_threshold=1) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -386,11 +371,10 @@ class ImageClassifierTest(parameterized.TestCase): classifier.classify_for_video(self.test_image, 0) def test_classify_for_video(self): - custom_classifier_options = _ClassifierOptions(max_results=4) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.VIDEO, - classifier_options=custom_classifier_options) + max_results=4) with _ImageClassifier.create_from_options(options) as classifier: for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( @@ -399,11 +383,10 @@ class ImageClassifierTest(parameterized.TestCase): _generate_burger_results().to_pb2()) def test_classify_for_video_succeeds_with_region_of_interest(self): - custom_classifier_options = _ClassifierOptions(max_results=1) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.VIDEO, 
- classifier_options=custom_classifier_options) + max_results=1) with _ImageClassifier.create_from_options(options) as classifier: # Load the test image. test_image = _Image.create_from_file( @@ -439,11 +422,10 @@ class ImageClassifierTest(parameterized.TestCase): classifier.classify_for_video(self.test_image, 0) def test_classify_async_calls_with_illegal_timestamp(self): - custom_classifier_options = _ClassifierOptions(max_results=4) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - classifier_options=custom_classifier_options, + max_results=4, result_callback=mock.MagicMock()) with _ImageClassifier.create_from_options(options) as classifier: classifier.classify_async(self.test_image, 100) @@ -466,12 +448,11 @@ class ImageClassifierTest(parameterized.TestCase): self.assertLess(observed_timestamp_ms, timestamp_ms) self.observed_timestamp_ms = timestamp_ms - custom_classifier_options = _ClassifierOptions( - max_results=4, score_threshold=threshold) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - classifier_options=custom_classifier_options, + max_results=4, + score_threshold=threshold, result_callback=check_result) with _ImageClassifier.create_from_options(options) as classifier: for timestamp in range(0, 300, 30): @@ -496,11 +477,10 @@ class ImageClassifierTest(parameterized.TestCase): self.assertLess(observed_timestamp_ms, timestamp_ms) self.observed_timestamp_ms = timestamp_ms - custom_classifier_options = _ClassifierOptions(max_results=1) options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - classifier_options=custom_classifier_options, + max_results=1, result_callback=check_result) with _ImageClassifier.create_from_options(options) as classifier: for timestamp in range(0, 300, 30): diff --git 
a/mediapipe/tasks/python/test/vision/image_embedder_test.py b/mediapipe/tasks/python/test/vision/image_embedder_test.py index 4bb96bad6..11c0cf002 100644 --- a/mediapipe/tasks/python/test/vision/image_embedder_test.py +++ b/mediapipe/tasks/python/test/vision/image_embedder_test.py @@ -24,7 +24,6 @@ import numpy as np from mediapipe.python._framework_bindings import image as image_module from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module from mediapipe.tasks.python.components.containers import rect -from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import image_embedder @@ -33,7 +32,6 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni _Rect = rect.Rect _BaseOptions = base_options_module.BaseOptions -_EmbedderOptions = embedder_options_module.EmbedderOptions _Embedding = embedding_result_module.Embedding _Image = image_module.Image _ImageEmbedder = image_embedder.ImageEmbedder @@ -142,10 +140,8 @@ class ImageEmbedderTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - embedder_options = _EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize) options = _ImageEmbedderOptions( - base_options=base_options, embedder_options=embedder_options) + base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) embedder = _ImageEmbedder.create_from_options(options) image_processing_options = None @@ -186,10 +182,8 @@ class ImageEmbedderTest(parameterized.TestCase): # Should never happen raise ValueError('model_file_type is invalid.') - embedder_options = _EmbedderOptions( - l2_normalize=l2_normalize, quantize=quantize) options = _ImageEmbedderOptions( - base_options=base_options, embedder_options=embedder_options) + 
base_options=base_options, l2_normalize=l2_normalize, quantize=quantize) with _ImageEmbedder.create_from_options(options) as embedder: # Extracts both embeddings. diff --git a/mediapipe/tasks/python/text/BUILD b/mediapipe/tasks/python/text/BUILD index 10b4b8a6e..e2a51cdbd 100644 --- a/mediapipe/tasks/python/text/BUILD +++ b/mediapipe/tasks/python/text/BUILD @@ -28,9 +28,9 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:classification_result", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", @@ -47,9 +47,9 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:embedding_result", - "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/components/utils:cosine_similarity", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", diff --git a/mediapipe/tasks/python/text/text_classifier.py b/mediapipe/tasks/python/text/text_classifier.py index 9711e8b3a..fdb20f0ef 100644 --- a/mediapipe/tasks/python/text/text_classifier.py +++ b/mediapipe/tasks/python/text/text_classifier.py @@ -14,14 +14,14 @@ """MediaPipe text classifier task.""" import dataclasses -from typing import Optional +from typing 
import Optional, List from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2 from mediapipe.tasks.python.components.containers import classification_result as classification_result_module -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -30,7 +30,7 @@ from mediapipe.tasks.python.text.core import base_text_task_api TextClassifierResult = classification_result_module.ClassificationResult _BaseOptions = base_options_module.BaseOptions _TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions -_ClassifierOptions = classifier_options.ClassifierOptions +_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions _TaskInfo = task_info_module.TaskInfo _CLASSIFICATIONS_STREAM_NAME = 'classifications_out' @@ -46,17 +46,38 @@ class TextClassifierOptions: Attributes: base_options: Base options for the text classifier task. - classifier_options: Options for the text classification task. + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + max_results: The maximum number of top-scored classification results to + return. + score_threshold: Overrides the ones provided in the model metadata. Results + below this value are rejected. + category_allowlist: Allowlist of category names. If non-empty, + classification results whose category name is not in this set will be + filtered out. Duplicate or unknown category names are ignored. 
Mutually + exclusive with `category_denylist`. + category_denylist: Denylist of category names. If non-empty, classification + results whose category name is in this set will be filtered out. Duplicate + or unknown category names are ignored. Mutually exclusive with + `category_allowlist`. """ base_options: _BaseOptions - classifier_options: Optional[_ClassifierOptions] = dataclasses.field( - default_factory=_ClassifierOptions) + display_names_locale: Optional[str] = None + max_results: Optional[int] = None + score_threshold: Optional[float] = None + category_allowlist: Optional[List[str]] = None + category_denylist: Optional[List[str]] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _TextClassifierGraphOptionsProto: """Generates an TextClassifierOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() - classifier_options_proto = self.classifier_options.to_pb2() + classifier_options_proto = _ClassifierOptionsProto( + score_threshold=self.score_threshold, + category_allowlist=self.category_allowlist, + category_denylist=self.category_denylist, + display_names_locale=self.display_names_locale, + max_results=self.max_results) return _TextClassifierGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/python/text/text_embedder.py b/mediapipe/tasks/python/text/text_embedder.py index a9e560ac9..be899636d 100644 --- a/mediapipe/tasks/python/text/text_embedder.py +++ b/mediapipe/tasks/python/text/text_embedder.py @@ -19,9 +19,9 @@ from typing import Optional from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 +from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2 from mediapipe.tasks.cc.text.text_embedder.proto import text_embedder_graph_options_pb2 from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module -from 
mediapipe.tasks.python.components.processors import embedder_options from mediapipe.tasks.python.components.utils import cosine_similarity from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module @@ -31,7 +31,7 @@ from mediapipe.tasks.python.text.core import base_text_task_api TextEmbedderResult = embedding_result_module.EmbeddingResult _BaseOptions = base_options_module.BaseOptions _TextEmbedderGraphOptionsProto = text_embedder_graph_options_pb2.TextEmbedderGraphOptions -_EmbedderOptions = embedder_options.EmbedderOptions +_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions _TaskInfo = task_info_module.TaskInfo _EMBEDDINGS_OUT_STREAM_NAME = 'embeddings_out' @@ -47,17 +47,25 @@ class TextEmbedderOptions: Attributes: base_options: Base options for the text embedder task. - embedder_options: Options for the text embedder task. + l2_normalize: Whether to normalize the returned feature vector with L2 norm. + Use this option only if the model does not already contain a native + L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and + L2 norm is thus achieved through TF Lite inference. + quantize: Whether the returned embedding should be quantized to bytes via + scalar quantization. Embeddings are implicitly assumed to be unit-norm and + therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + the l2_normalize option if this is not the case. 
""" base_options: _BaseOptions - embedder_options: Optional[_EmbedderOptions] = dataclasses.field( - default_factory=_EmbedderOptions) + l2_normalize: Optional[bool] = None + quantize: Optional[bool] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _TextEmbedderGraphOptionsProto: """Generates an TextEmbedderOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() - embedder_options_proto = self.embedder_options.to_pb2() + embedder_options_proto = _EmbedderOptionsProto( + l2_normalize=self.l2_normalize, quantize=self.quantize) return _TextEmbedderGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 29e7577e8..241ca4341 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -47,10 +47,10 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/components/containers:rect", - "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", @@ -89,9 +89,9 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:embedding_result", - 
"//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/components/utils:cosine_similarity", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index 6cbce7860..b60d18e31 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -14,17 +14,17 @@ """MediaPipe image classifier task.""" import dataclasses -from typing import Callable, Mapping, Optional +from typing import Callable, Mapping, Optional, List from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2 from mediapipe.tasks.python.components.containers import classification_result as classification_result_module from mediapipe.tasks.python.components.containers import rect -from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -36,7 +36,7 @@ ImageClassifierResult = classification_result_module.ClassificationResult _NormalizedRect = rect.NormalizedRect _BaseOptions = base_options_module.BaseOptions _ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions -_ClassifierOptions = classifier_options.ClassifierOptions +_ClassifierOptionsProto = 
classifier_options_pb2.ClassifierOptions _RunningMode = vision_task_running_mode.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo @@ -63,15 +63,31 @@ class ImageClassifierOptions: objects on single image inputs. 2) The video mode for classifying objects on the decoded frames of a video. 3) The live stream mode for classifying objects on a live stream of input data, such as from camera. - classifier_options: Options for the image classification task. + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + max_results: The maximum number of top-scored classification results to + return. + score_threshold: Overrides the ones provided in the model metadata. Results + below this value are rejected. + category_allowlist: Allowlist of category names. If non-empty, + classification results whose category name is not in this set will be + filtered out. Duplicate or unknown category names are ignored. Mutually + exclusive with `category_denylist`. + category_denylist: Denylist of category names. If non-empty, classification + results whose category name is in this set will be filtered out. Duplicate + or unknown category names are ignored. Mutually exclusive with + `category_allowlist`. result_callback: The user-defined result callback for processing live stream data. The result callback should only be specified when the running mode is set to the live stream mode. 
""" base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - classifier_options: Optional[_ClassifierOptions] = dataclasses.field( - default_factory=_ClassifierOptions) + display_names_locale: Optional[str] = None + max_results: Optional[int] = None + score_threshold: Optional[float] = None + category_allowlist: Optional[List[str]] = None + category_denylist: Optional[List[str]] = None result_callback: Optional[Callable[ [ImageClassifierResult, image_module.Image, int], None]] = None @@ -80,7 +96,12 @@ class ImageClassifierOptions: """Generates an ImageClassifierOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True - classifier_options_proto = self.classifier_options.to_pb2() + classifier_options_proto = _ClassifierOptionsProto( + score_threshold=self.score_threshold, + category_allowlist=self.category_allowlist, + category_denylist=self.category_denylist, + display_names_locale=self.display_names_locale, + max_results=self.max_results) return _ImageClassifierGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/python/vision/image_embedder.py b/mediapipe/tasks/python/vision/image_embedder.py index a58dca3ae..0bae21bda 100644 --- a/mediapipe/tasks/python/vision/image_embedder.py +++ b/mediapipe/tasks/python/vision/image_embedder.py @@ -21,9 +21,9 @@ from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 +from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2 from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2 from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module -from 
mediapipe.tasks.python.components.processors import embedder_options from mediapipe.tasks.python.components.utils import cosine_similarity from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module @@ -35,7 +35,7 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni ImageEmbedderResult = embedding_result_module.EmbeddingResult _BaseOptions = base_options_module.BaseOptions _ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions -_EmbedderOptions = embedder_options.EmbedderOptions +_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions _RunningMode = running_mode_module.VisionTaskRunningMode _TaskInfo = task_info_module.TaskInfo _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions @@ -62,15 +62,22 @@ class ImageEmbedderOptions: image on single image inputs. 2) The video mode for embedding image on the decoded frames of a video. 3) The live stream mode for embedding image on a live stream of input data, such as from camera. - embedder_options: Options for the image embedder task. + l2_normalize: Whether to normalize the returned feature vector with L2 norm. + Use this option only if the model does not already contain a native + L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and + L2 norm is thus achieved through TF Lite inference. + quantize: Whether the returned embedding should be quantized to bytes via + scalar quantization. Embeddings are implicitly assumed to be unit-norm and + therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + the l2_normalize option if this is not the case. result_callback: The user-defined result callback for processing live stream data. The result callback should only be specified when the running mode is set to the live stream mode. 
""" base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - embedder_options: Optional[_EmbedderOptions] = dataclasses.field( - default_factory=_EmbedderOptions) + l2_normalize: Optional[bool] = None + quantize: Optional[bool] = None result_callback: Optional[Callable[ [ImageEmbedderResult, image_module.Image, int], None]] = None @@ -79,7 +86,8 @@ class ImageEmbedderOptions: """Generates an ImageEmbedderOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True - embedder_options_proto = self.embedder_options.to_pb2() + embedder_options_proto = _EmbedderOptionsProto( + l2_normalize=self.l2_normalize, quantize=self.quantize) return _ImageEmbedderGraphOptionsProto( base_options=base_options_proto, From 1167f61f9825cc80e3e81b53b08a59f1a19ef456 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 6 Dec 2022 18:02:35 -0800 Subject: [PATCH 20/64] Remove generic Options template argument from TaskRunner PiperOrigin-RevId: 493462947 --- mediapipe/tasks/web/audio/core/BUILD | 5 +---- .../tasks/web/audio/core/audio_task_runner.ts | 3 +-- mediapipe/tasks/web/core/task_runner.ts | 14 ++++++-------- .../web/text/text_classifier/text_classifier.ts | 2 +- .../tasks/web/text/text_embedder/text_embedder.ts | 2 +- .../tasks/web/vision/core/vision_task_runner.ts | 3 +-- 6 files changed, 11 insertions(+), 18 deletions(-) diff --git a/mediapipe/tasks/web/audio/core/BUILD b/mediapipe/tasks/web/audio/core/BUILD index 9ab6c7bee..cea689838 100644 --- a/mediapipe/tasks/web/audio/core/BUILD +++ b/mediapipe/tasks/web/audio/core/BUILD @@ -7,8 +7,5 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_ts_library( name = "audio_task_runner", srcs = ["audio_task_runner.ts"], - deps = [ - "//mediapipe/tasks/web/core", - "//mediapipe/tasks/web/core:task_runner", - ], + deps = ["//mediapipe/tasks/web/core:task_runner"], ) diff --git 
a/mediapipe/tasks/web/audio/core/audio_task_runner.ts b/mediapipe/tasks/web/audio/core/audio_task_runner.ts index 00cfe0253..24d78378d 100644 --- a/mediapipe/tasks/web/audio/core/audio_task_runner.ts +++ b/mediapipe/tasks/web/audio/core/audio_task_runner.ts @@ -15,10 +15,9 @@ */ import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; /** Base class for all MediaPipe Audio Tasks. */ -export abstract class AudioTaskRunner extends TaskRunner { +export abstract class AudioTaskRunner extends TaskRunner { private defaultSampleRate = 48000; /** diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index e2ab42e31..71e159dce 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -37,10 +37,9 @@ export class GraphRunnerImageLib extends GraphRunnerImageLibType {} * supported and loads the relevant WASM binary. * @return A fully instantiated instance of `T`. */ -export async function -createTaskRunner, O extends TaskRunnerOptions>( +export async function createTaskRunner( type: WasmMediaPipeConstructor, initializeCanvas: boolean, - fileset: WasmFileset, options: O): Promise { + fileset: WasmFileset, options: TaskRunnerOptions): Promise { const fileLocator: FileLocator = { locateFile() { // The only file loaded with this mechanism is the Wasm binary @@ -61,7 +60,7 @@ createTaskRunner, O extends TaskRunnerOptions>( } /** Base class for all MediaPipe Tasks. */ -export abstract class TaskRunner { +export abstract class TaskRunner { protected abstract baseOptions: BaseOptionsProto; protected graphRunner: GraphRunnerImageLib; private processingErrors: Error[] = []; @@ -71,10 +70,9 @@ export abstract class TaskRunner { * supported and loads the relevant WASM binary. * @return A fully instantiated instance of `T`. 
*/ - protected static async createInstance, - O extends TaskRunnerOptions>( + protected static async createInstance( type: WasmMediaPipeConstructor, initializeCanvas: boolean, - fileset: WasmFileset, options: O): Promise { + fileset: WasmFileset, options: TaskRunnerOptions): Promise { return createTaskRunner(type, initializeCanvas, fileset, options); } @@ -92,7 +90,7 @@ export abstract class TaskRunner { } /** Configures the shared options of a MediaPipe Task. */ - async setOptions(options: O): Promise { + async setOptions(options: TaskRunnerOptions): Promise { if (options.baseOptions) { this.baseOptions = await convertBaseOptionsToProto( options.baseOptions, this.baseOptions); diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index 8810d4b42..4a8588836 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -41,7 +41,7 @@ const TEXT_CLASSIFIER_GRAPH = // tslint:disable:jspb-use-builder-pattern /** Performs Natural Language classification. */ -export class TextClassifier extends TaskRunner { +export class TextClassifier extends TaskRunner { private classificationResult: TextClassifierResult = {classifications: []}; private readonly options = new TextClassifierGraphOptions(); diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 62f9b06db..cd5bc644e 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -45,7 +45,7 @@ const TEXT_EMBEDDER_CALCULATOR = /** * Performs embedding extraction on text. 
*/ -export class TextEmbedder extends TaskRunner { +export class TextEmbedder extends TaskRunner { private embeddingResult: TextEmbedderResult = {embeddings: []}; private readonly options = new TextEmbedderGraphOptionsProto(); diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index 78b4859f2..3432b521b 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -20,8 +20,7 @@ import {ImageSource} from '../../../../web/graph_runner/graph_runner'; import {VisionTaskOptions} from './vision_task_options'; /** Base class for all MediaPipe Vision Tasks. */ -export abstract class VisionTaskRunner extends - TaskRunner { +export abstract class VisionTaskRunner extends TaskRunner { /** Configures the shared options of a vision task. */ override async setOptions(options: VisionTaskOptions): Promise { await super.setOptions(options); From 402834b4f2236ed2d707f0d20c0ebd2d1a42a721 Mon Sep 17 00:00:00 2001 From: Mark McDonald Date: Tue, 6 Dec 2022 19:46:33 -0800 Subject: [PATCH 21/64] Internal change PiperOrigin-RevId: 493480322 --- docs/build_model_maker_api_docs.py | 81 ++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 docs/build_model_maker_api_docs.py diff --git a/docs/build_model_maker_api_docs.py b/docs/build_model_maker_api_docs.py new file mode 100644 index 000000000..7732b7d56 --- /dev/null +++ b/docs/build_model_maker_api_docs.py @@ -0,0 +1,81 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""MediaPipe Model Maker reference docs generation script. + +This script generates API reference docs for the `mediapipe` PIP package. + +$> pip install -U git+https://github.com/tensorflow/docs mediapipe-model-maker +$> python build_model_maker_api_docs.py +""" + +import os + +from absl import app +from absl import flags + +from tensorflow_docs.api_generator import generate_lib + +try: + # mediapipe has not been set up to work with bazel yet, so catch & report. + import mediapipe_model_maker # pytype: disable=import-error +except ImportError as e: + raise ImportError('Please `pip install mediapipe-model-maker`.') from e + + +PROJECT_SHORT_NAME = 'mediapipe_model_maker' +PROJECT_FULL_NAME = 'MediaPipe Model Maker' + +_OUTPUT_DIR = flags.DEFINE_string( + 'output_dir', + default='/tmp/generated_docs', + help='Where to write the resulting docs.') + +_URL_PREFIX = flags.DEFINE_string( + 'code_url_prefix', + 'https://github.com/google/mediapipe/tree/master/mediapipe/model_maker', + 'The url prefix for links to code.') + +_SEARCH_HINTS = flags.DEFINE_bool( + 'search_hints', True, + 'Include metadata search hints in the generated files') + +_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python', + 'Path prefix in the _toc.yaml') + + +def gen_api_docs(): + """Generates API docs for the mediapipe-model-maker package.""" + + doc_generator = generate_lib.DocGenerator( + root_title=PROJECT_FULL_NAME, + py_modules=[(PROJECT_SHORT_NAME, mediapipe_model_maker)], + 
base_dir=os.path.dirname(mediapipe_model_maker.__file__), + code_url_prefix=_URL_PREFIX.value, + search_hints=_SEARCH_HINTS.value, + site_path=_SITE_PATH.value, + callbacks=[], + ) + + doc_generator.build(_OUTPUT_DIR.value) + + print('Docs output to:', _OUTPUT_DIR.value) + + +def main(_): + gen_api_docs() + + +if __name__ == '__main__': + app.run(main) From 523d16dffab5d066879b300230cc9ac26ad49128 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 6 Dec 2022 23:54:11 -0800 Subject: [PATCH 22/64] Make GpuBuffer a shared_ptr to a storage collection PiperOrigin-RevId: 493519590 --- mediapipe/gpu/BUILD | 2 + mediapipe/gpu/gpu_buffer.cc | 102 +++++++++++++++++++++--------- mediapipe/gpu/gpu_buffer.h | 105 +++++++++++++++++-------------- mediapipe/gpu/gpu_buffer_test.cc | 22 +++++++ 4 files changed, 156 insertions(+), 75 deletions(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index f5cb9f715..009eb3f9e 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -289,7 +289,9 @@ cc_library( deps = [ ":gpu_buffer_format", ":gpu_buffer_storage", + "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/port:logging", ":gpu_buffer_storage_image_frame", diff --git a/mediapipe/gpu/gpu_buffer.cc b/mediapipe/gpu/gpu_buffer.cc index 388960b11..628e86099 100644 --- a/mediapipe/gpu/gpu_buffer.cc +++ b/mediapipe/gpu/gpu_buffer.cc @@ -3,6 +3,7 @@ #include #include +#include "absl/functional/bind_front.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "mediapipe/framework/port/logging.h" @@ -25,57 +26,101 @@ struct StorageTypeFormatter { } // namespace std::string GpuBuffer::DebugString() const { - return absl::StrCat("GpuBuffer[", - absl::StrJoin(storages_, ", ", StorageTypeFormatter()), - "]"); + return holder_ ? 
absl::StrCat("GpuBuffer[", width(), "x", height(), " ", + format(), " as ", holder_->DebugString(), "]") + : "GpuBuffer[invalid]"; } -internal::GpuBufferStorage* GpuBuffer::GetStorageForView( +std::string GpuBuffer::StorageHolder::DebugString() const { + absl::MutexLock lock(&mutex_); + return absl::StrJoin(storages_, ", ", StorageTypeFormatter()); +} + +internal::GpuBufferStorage* GpuBuffer::StorageHolder::GetStorageForView( TypeId view_provider_type, bool for_writing) const { - const std::shared_ptr* chosen_storage = nullptr; + std::shared_ptr chosen_storage; + std::function()> conversion; - // First see if any current storage supports the view. - for (const auto& s : storages_) { - if (s->can_down_cast_to(view_provider_type)) { - chosen_storage = &s; - break; - } - } - - // Then try to convert existing storages to one that does. - // TODO: choose best conversion. - if (!chosen_storage) { + { + absl::MutexLock lock(&mutex_); + // First see if any current storage supports the view. for (const auto& s : storages_) { - if (auto converter = internal::GpuBufferStorageRegistry::Get() - .StorageConverterForViewProvider( - view_provider_type, s->storage_type())) { - if (auto new_storage = converter(s)) { - storages_.push_back(new_storage); - chosen_storage = &storages_.back(); + if (s->can_down_cast_to(view_provider_type)) { + chosen_storage = s; + break; + } + } + + // Then try to convert existing storages to one that does. + // TODO: choose best conversion. + if (!chosen_storage) { + for (const auto& s : storages_) { + if (auto converter = internal::GpuBufferStorageRegistry::Get() + .StorageConverterForViewProvider( + view_provider_type, s->storage_type())) { + conversion = absl::bind_front(converter, s); break; } } } } + // Avoid invoking a converter or factory while holding the mutex. + // Two reasons: + // 1. Readers that don't need a conversion will not be blocked. + // 2. 
We use mutexes to make sure GL contexts are not used simultaneously on + // different threads, and we also rely on Mutex's deadlock detection + // heuristic, which enforces a consistent mutex acquisition order. + // This function is likely to be called within a GL context, and the + // conversion function may in turn use a GL context, and this may cause a + // false positive in the deadlock detector. + // TODO: we could use Mutex::ForgetDeadlockInfo instead. + if (conversion) { + auto new_storage = conversion(); + absl::MutexLock lock(&mutex_); + // Another reader might have already completed and inserted the same + // conversion. TODO: prevent this? + for (const auto& s : storages_) { + if (s->can_down_cast_to(view_provider_type)) { + chosen_storage = s; + break; + } + } + if (!chosen_storage) { + storages_.push_back(std::move(new_storage)); + chosen_storage = storages_.back(); + } + } + if (for_writing) { + // This will temporarily hold storages to be released, and do so while the + // lock is not held (see above). + decltype(storages_) old_storages; + using std::swap; if (chosen_storage) { // Discard all other storages. - storages_ = {*chosen_storage}; - chosen_storage = &storages_.back(); + absl::MutexLock lock(&mutex_); + swap(old_storages, storages_); + storages_ = {chosen_storage}; } else { // Allocate a new storage supporting the requested view. if (auto factory = internal::GpuBufferStorageRegistry::Get() .StorageFactoryForViewProvider(view_provider_type)) { - if (auto new_storage = factory(width(), height(), format())) { + if (auto new_storage = factory(width_, height_, format_)) { + absl::MutexLock lock(&mutex_); + swap(old_storages, storages_); storages_ = {std::move(new_storage)}; - chosen_storage = &storages_.back(); + chosen_storage = storages_.back(); } } } } - return chosen_storage ? chosen_storage->get() : nullptr; + + // It is ok to return a non-owning storage pointer here because this object + // ensures the storage's lifetime. 
Overwriting a GpuBuffer while readers are + // active would violate this, but it's not allowed in MediaPipe. + return chosen_storage ? chosen_storage.get() : nullptr; } internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie( @@ -84,8 +129,7 @@ internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie( GpuBuffer::GetStorageForView(view_provider_type, for_writing); CHECK(chosen_storage) << "no view provider found for requested view " << view_provider_type.name() << "; storages available: " - << absl::StrJoin(storages_, ", ", - StorageTypeFormatter()); + << (holder_ ? holder_->DebugString() : "invalid"); DCHECK(chosen_storage->can_down_cast_to(view_provider_type)); return *chosen_storage; } diff --git a/mediapipe/gpu/gpu_buffer.h b/mediapipe/gpu/gpu_buffer.h index 56507d92f..b9a88aa53 100644 --- a/mediapipe/gpu/gpu_buffer.h +++ b/mediapipe/gpu/gpu_buffer.h @@ -15,9 +15,12 @@ #ifndef MEDIAPIPE_GPU_GPU_BUFFER_H_ #define MEDIAPIPE_GPU_GPU_BUFFER_H_ +#include +#include #include #include +#include "absl/synchronization/mutex.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/gpu/gpu_buffer_format.h" #include "mediapipe/gpu/gpu_buffer_storage.h" @@ -56,8 +59,7 @@ class GpuBuffer { // Creates an empty buffer of a given size and format. It will be allocated // when a view is requested. GpuBuffer(int width, int height, Format format) - : GpuBuffer(std::make_shared(width, height, - format)) {} + : holder_(std::make_shared(width, height, format)) {} // Copy and move constructors and assignment operators are supported. GpuBuffer(const GpuBuffer& other) = default; @@ -70,9 +72,8 @@ class GpuBuffer { // are not portable. Applications and calculators should normally obtain // GpuBuffers in a portable way from the framework, e.g. using // GpuBufferMultiPool. 
- explicit GpuBuffer(std::shared_ptr storage) { - storages_.push_back(std::move(storage)); - } + explicit GpuBuffer(std::shared_ptr storage) + : holder_(std::make_shared(std::move(storage))) {} #if !MEDIAPIPE_DISABLE_GPU && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER // This is used to support backward-compatible construction of GpuBuffer from @@ -84,9 +85,11 @@ class GpuBuffer { : GpuBuffer(internal::AsGpuBufferStorage(storage_convertible)) {} #endif // !MEDIAPIPE_DISABLE_GPU && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - int width() const { return current_storage().width(); } - int height() const { return current_storage().height(); } - GpuBufferFormat format() const { return current_storage().format(); } + int width() const { return holder_ ? holder_->width() : 0; } + int height() const { return holder_ ? holder_->height() : 0; } + GpuBufferFormat format() const { + return holder_ ? holder_->format() : GpuBufferFormat::kUnknown; + } // Converts to true iff valid. explicit operator bool() const { return operator!=(nullptr); } @@ -122,31 +125,17 @@ class GpuBuffer { // using views. template std::shared_ptr internal_storage() const { - for (const auto& s : storages_) - if (s->down_cast()) return std::static_pointer_cast(s); - return nullptr; + return holder_ ? holder_->internal_storage() : nullptr; } std::string DebugString() const; private: - class PlaceholderGpuBufferStorage - : public internal::GpuBufferStorageImpl { - public: - PlaceholderGpuBufferStorage(int width, int height, Format format) - : width_(width), height_(height), format_(format) {} - int width() const override { return width_; } - int height() const override { return height_; } - GpuBufferFormat format() const override { return format_; } - - private: - int width_ = 0; - int height_ = 0; - GpuBufferFormat format_ = GpuBufferFormat::kUnknown; - }; - internal::GpuBufferStorage* GetStorageForView(TypeId view_provider_type, - bool for_writing) const; + bool for_writing) const { + return holder_ ? 
holder_->GetStorageForView(view_provider_type, for_writing) + : nullptr; + } internal::GpuBufferStorage& GetStorageForViewOrDie(TypeId view_provider_type, bool for_writing) const; @@ -158,25 +147,49 @@ class GpuBuffer { .template down_cast(); } - std::shared_ptr& no_storage() const { - static auto placeholder = - std::static_pointer_cast( - std::make_shared( - 0, 0, GpuBufferFormat::kUnknown)); - return placeholder; - } + // This class manages a set of alternative storages for the contents of a + // GpuBuffer. GpuBuffer was originally designed as a reference-type object, + // where a copy represents another reference to the same contents, so multiple + // GpuBuffer instances can share the same StorageHolder. + class StorageHolder { + public: + explicit StorageHolder(std::shared_ptr storage) + : StorageHolder(storage->width(), storage->height(), + storage->format()) { + storages_.push_back(std::move(storage)); + } + explicit StorageHolder(int width, int height, Format format) + : width_(width), height_(height), format_(format) {} - const internal::GpuBufferStorage& current_storage() const { - return storages_.empty() ? *no_storage() : *storages_[0]; - } + int width() const { return width_; } + int height() const { return height_; } + GpuBufferFormat format() const { return format_; } - internal::GpuBufferStorage& current_storage() { - return storages_.empty() ? *no_storage() : *storages_[0]; - } + internal::GpuBufferStorage* GetStorageForView(TypeId view_provider_type, + bool for_writing) const; - // This is mutable because view methods that do not change the contents may - // still need to allocate new storages. 
- mutable std::vector> storages_; + template + std::shared_ptr internal_storage() const { + absl::MutexLock lock(&mutex_); + for (const auto& s : storages_) + if (s->down_cast()) return std::static_pointer_cast(s); + return nullptr; + } + + std::string DebugString() const; + + private: + int width_ = 0; + int height_ = 0; + GpuBufferFormat format_ = GpuBufferFormat::kUnknown; + // This is mutable because view methods that do not change the contents may + // still need to allocate new storages. + mutable absl::Mutex mutex_; + mutable std::vector> storages_ + ABSL_GUARDED_BY(mutex_); + }; + + std::shared_ptr holder_; #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER friend CVPixelBufferRef GetCVPixelBufferRef(const GpuBuffer& buffer); @@ -184,15 +197,15 @@ class GpuBuffer { }; inline bool GpuBuffer::operator==(std::nullptr_t other) const { - return storages_.empty(); + return holder_ == other; } inline bool GpuBuffer::operator==(const GpuBuffer& other) const { - return storages_ == other.storages_; + return holder_ == other.holder_; } inline GpuBuffer& GpuBuffer::operator=(std::nullptr_t other) { - storages_.clear(); + holder_ = other; return *this; } diff --git a/mediapipe/gpu/gpu_buffer_test.cc b/mediapipe/gpu/gpu_buffer_test.cc index 145b71806..e4be617db 100644 --- a/mediapipe/gpu/gpu_buffer_test.cc +++ b/mediapipe/gpu/gpu_buffer_test.cc @@ -20,6 +20,7 @@ #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/tool/test_util.h" +#include "mediapipe/gpu/gl_texture_buffer.h" #include "mediapipe/gpu/gl_texture_util.h" #include "mediapipe/gpu/gpu_buffer_storage_ahwb.h" #include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" @@ -228,5 +229,26 @@ TEST_F(GpuBufferTest, GlTextureViewRetainsWhatItNeeds) { EXPECT_TRUE(true); } +TEST_F(GpuBufferTest, CopiesShareConversions) { + GpuBuffer buffer(300, 200, GpuBufferFormat::kBGRA32); + { + std::shared_ptr view = buffer.GetWriteView(); + FillImageFrameRGBA(*view, 255, 0, 
0, 255); + } + + GpuBuffer other_handle = buffer; + RunInGlContext([&buffer] { + TempGlFramebuffer fb; + auto view = buffer.GetReadView(0); + }); + + // Check that other_handle also sees the same GlTextureBuffer as buffer. + // Note that this is deliberately written so that it still passes on platforms + // where we use another storage for GL textures (they will both be null). + // TODO: expose more accessors for testing? + EXPECT_EQ(other_handle.internal_storage(), + buffer.internal_storage()); +} + } // anonymous namespace } // namespace mediapipe From aad797197bbb4c4170cd21c6baf18084bee84446 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Wed, 7 Dec 2022 07:14:46 -0800 Subject: [PATCH 23/64] TensorV1 EGL.h include fix. PiperOrigin-RevId: 493596083 --- mediapipe/framework/formats/tensor.h | 5 ++--- mediapipe/framework/formats/tensor_ahwb.cc | 5 +++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index ecd63c8c6..3ed72c6fd 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -39,10 +39,9 @@ #endif // MEDIAPIPE_NO_JNI #ifdef MEDIAPIPE_TENSOR_USE_AHWB +#include +#include #include - -#include "third_party/GL/gl/include/EGL/egl.h" -#include "third_party/GL/gl/include/EGL/eglext.h" #endif // MEDIAPIPE_TENSOR_USE_AHWB #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #include "mediapipe/gpu/gl_base.h" diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index b11f6b55b..90d89c40a 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -4,12 +4,13 @@ #include "mediapipe/framework/formats/tensor.h" #ifdef MEDIAPIPE_TENSOR_USE_AHWB +#include +#include + #include "absl/synchronization/mutex.h" #include "mediapipe/framework/port.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/gpu/gl_base.h" -#include 
"third_party/GL/gl/include/EGL/egl.h" -#include "third_party/GL/gl/include/EGL/eglext.h" #endif // MEDIAPIPE_TENSOR_USE_AHWB namespace mediapipe { From d9688b769f5207aff13bd782d94dd4d2ad8dcd92 Mon Sep 17 00:00:00 2001 From: Khanh LeViet Date: Wed, 7 Dec 2022 08:13:51 -0800 Subject: [PATCH 24/64] Hide internal APIs from mediapipe pip package's API docs. PiperOrigin-RevId: 493607984 --- .../tasks/python/core/optional_dependencies.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/mediapipe/tasks/python/core/optional_dependencies.py b/mediapipe/tasks/python/core/optional_dependencies.py index d4f6a6abc..b1a0ed538 100644 --- a/mediapipe/tasks/python/core/optional_dependencies.py +++ b/mediapipe/tasks/python/core/optional_dependencies.py @@ -13,6 +13,13 @@ # limitations under the License. """MediaPipe Tasks' common but optional dependencies.""" -doc_controls = lambda: None -no_op = lambda x: x -setattr(doc_controls, 'do_not_generate_docs', no_op) +# TensorFlow isn't a dependency of mediapipe pip package. It's only +# required in the API docgen pipeline so we'll ignore it if tensorflow is not +# installed. 
+try: + from tensorflow.tools.docs import doc_controls +except ModuleNotFoundError: + # Replace the real doc_controls.do_not_generate_docs with an no-op + doc_controls = lambda: None + no_op = lambda x: x + setattr(doc_controls, 'do_not_generate_docs', no_op) From d84eec387bb277b4f379f360d97bbf734cb3ae13 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 7 Dec 2022 10:50:12 -0800 Subject: [PATCH 25/64] Add missing import to InferenceCalculator.proto PiperOrigin-RevId: 493649869 --- mediapipe/calculators/tensor/inference_calculator.proto | 1 + mediapipe/tasks/web/BUILD | 3 --- mediapipe/tasks/web/rollup.config.mjs | 6 ------ package.json | 1 - yarn.lock | 8 -------- 5 files changed, 1 insertion(+), 18 deletions(-) diff --git a/mediapipe/calculators/tensor/inference_calculator.proto b/mediapipe/calculators/tensor/inference_calculator.proto index 46552803b..78a0039bc 100644 --- a/mediapipe/calculators/tensor/inference_calculator.proto +++ b/mediapipe/calculators/tensor/inference_calculator.proto @@ -17,6 +17,7 @@ syntax = "proto2"; package mediapipe; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; option java_package = "com.google.mediapipe.calculator.proto"; option java_outer_classname = "InferenceCalculatorProto"; diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD index 20e717433..bc9e84147 100644 --- a/mediapipe/tasks/web/BUILD +++ b/mediapipe/tasks/web/BUILD @@ -44,7 +44,6 @@ rollup_bundle( ":audio_lib", "@npm//@rollup/plugin-commonjs", "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-replace", "@npm//@rollup/plugin-terser", "@npm//google-protobuf", ], @@ -88,7 +87,6 @@ rollup_bundle( ":text_lib", "@npm//@rollup/plugin-commonjs", "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-replace", "@npm//@rollup/plugin-terser", "@npm//google-protobuf", ], @@ -132,7 +130,6 @@ rollup_bundle( ":vision_lib", "@npm//@rollup/plugin-commonjs", 
"@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-replace", "@npm//@rollup/plugin-terser", "@npm//google-protobuf", ], diff --git a/mediapipe/tasks/web/rollup.config.mjs b/mediapipe/tasks/web/rollup.config.mjs index e633bf702..3b5119530 100644 --- a/mediapipe/tasks/web/rollup.config.mjs +++ b/mediapipe/tasks/web/rollup.config.mjs @@ -1,15 +1,9 @@ import resolve from '@rollup/plugin-node-resolve'; import commonjs from '@rollup/plugin-commonjs'; -import replace from '@rollup/plugin-replace'; import terser from '@rollup/plugin-terser'; export default { plugins: [ - // Workaround for https://github.com/protocolbuffers/protobuf-javascript/issues/151 - replace({ - 'var calculator_options_pb = {};': 'var calculator_options_pb = {}; var mediapipe_framework_calculator_options_pb = calculator_options_pb;', - delimiters: ['', ''] - }), resolve(), commonjs(), terser() diff --git a/package.json b/package.json index 22a035b74..6ad0b52c0 100644 --- a/package.json +++ b/package.json @@ -7,7 +7,6 @@ "@bazel/typescript": "^5.7.1", "@rollup/plugin-commonjs": "^23.0.2", "@rollup/plugin-node-resolve": "^15.0.1", - "@rollup/plugin-replace": "^5.0.1", "@rollup/plugin-terser": "^0.1.0", "@types/google-protobuf": "^3.15.6", "@types/offscreencanvas": "^2019.7.0", diff --git a/yarn.lock b/yarn.lock index 19c32e322..91b50456e 100644 --- a/yarn.lock +++ b/yarn.lock @@ -148,14 +148,6 @@ is-module "^1.0.0" resolve "^1.22.1" -"@rollup/plugin-replace@^5.0.1": - version "5.0.1" - resolved "https://registry.yarnpkg.com/@rollup/plugin-replace/-/plugin-replace-5.0.1.tgz#49a57af3e6df111a9e75dea3f3572741f4c5c83e" - integrity sha512-Z3MfsJ4CK17BfGrZgvrcp/l6WXoKb0kokULO+zt/7bmcyayokDaQ2K3eDJcRLCTAlp5FPI4/gz9MHAsosz4Rag== - dependencies: - "@rollup/pluginutils" "^5.0.1" - magic-string "^0.26.4" - "@rollup/plugin-terser@^0.1.0": version "0.1.0" resolved "https://registry.yarnpkg.com/@rollup/plugin-terser/-/plugin-terser-0.1.0.tgz#7530c0f11667637419d71820461646c418526041" From 
80c605459c2361840c1c0eab05dfa260d7dcfedc Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 7 Dec 2022 11:24:15 -0800 Subject: [PATCH 26/64] Open up framework visibility. PiperOrigin-RevId: 493660013 --- mediapipe/framework/deps/BUILD | 23 ++++---------- mediapipe/framework/port/BUILD | 42 +------------------------ mediapipe/framework/profiler/BUILD | 5 ++- mediapipe/framework/tool/testdata/BUILD | 7 +++-- 4 files changed, 13 insertions(+), 64 deletions(-) diff --git a/mediapipe/framework/deps/BUILD b/mediapipe/framework/deps/BUILD index 95ab21707..27bc105c8 100644 --- a/mediapipe/framework/deps/BUILD +++ b/mediapipe/framework/deps/BUILD @@ -20,7 +20,9 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = [ + "//mediapipe:__subpackages__", +]) bzl_library( name = "expand_template_bzl", @@ -50,13 +52,11 @@ mediapipe_proto_library( cc_library( name = "aligned_malloc_and_free", hdrs = ["aligned_malloc_and_free.h"], - visibility = ["//visibility:public"], ) cc_library( name = "cleanup", hdrs = ["cleanup.h"], - visibility = ["//visibility:public"], deps = ["@com_google_absl//absl/base:core_headers"], ) @@ -86,7 +86,6 @@ cc_library( # Use this library through "mediapipe/framework/port:gtest_main". 
visibility = [ "//mediapipe/framework/port:__pkg__", - "//third_party/visionai/algorithms/tracking:__pkg__", ], deps = [ "//mediapipe/framework/port:core_proto", @@ -108,7 +107,6 @@ cc_library( name = "file_helpers", srcs = ["file_helpers.cc"], hdrs = ["file_helpers.h"], - visibility = ["//visibility:public"], deps = [ ":file_path", "//mediapipe/framework/port:status", @@ -134,7 +132,6 @@ cc_library( cc_library( name = "image_resizer", hdrs = ["image_resizer.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:opencv_imgproc", ], @@ -151,7 +148,9 @@ cc_library( cc_library( name = "mathutil", hdrs = ["mathutil.h"], - visibility = ["//visibility:public"], + visibility = [ + "//mediapipe:__subpackages__", + ], deps = [ "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -171,7 +170,6 @@ cc_library( cc_library( name = "no_destructor", hdrs = ["no_destructor.h"], - visibility = ["//visibility:public"], ) cc_library( @@ -190,7 +188,6 @@ cc_library( cc_library( name = "random", hdrs = ["random_base.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/port:integral_types"], ) @@ -211,14 +208,12 @@ cc_library( name = "registration_token", srcs = ["registration_token.cc"], hdrs = ["registration_token.h"], - visibility = ["//visibility:public"], ) cc_library( name = "registration", srcs = ["registration.cc"], hdrs = ["registration.h"], - visibility = ["//visibility:public"], deps = [ ":registration_token", "//mediapipe/framework/port:logging", @@ -279,7 +274,6 @@ cc_library( hdrs = [ "re2.h", ], - visibility = ["//visibility:public"], ) cc_library( @@ -310,7 +304,6 @@ cc_library( cc_library( name = "thread_options", hdrs = ["thread_options.h"], - visibility = ["//visibility:public"], ) cc_library( @@ -356,7 +349,6 @@ cc_library( cc_test( name = "mathutil_unittest", srcs = ["mathutil_unittest.cc"], - visibility = ["//visibility:public"], deps = [ ":mathutil", 
"//mediapipe/framework/port:benchmark", @@ -368,7 +360,6 @@ cc_test( name = "registration_token_test", srcs = ["registration_token_test.cc"], linkstatic = 1, - visibility = ["//visibility:public"], deps = [ ":registration_token", "//mediapipe/framework/port:gtest_main", @@ -381,7 +372,6 @@ cc_test( timeout = "long", srcs = ["safe_int_test.cc"], linkstatic = 1, - visibility = ["//visibility:public"], deps = [ ":intops", "//mediapipe/framework/port:gtest_main", @@ -393,7 +383,6 @@ cc_test( name = "monotonic_clock_test", srcs = ["monotonic_clock_test.cc"], linkstatic = 1, - visibility = ["//visibility:public"], deps = [ ":clock", "//mediapipe/framework/port:gtest_main", diff --git a/mediapipe/framework/port/BUILD b/mediapipe/framework/port/BUILD index e499ca3a6..1039dc1c6 100644 --- a/mediapipe/framework/port/BUILD +++ b/mediapipe/framework/port/BUILD @@ -18,7 +18,7 @@ licenses(["notice"]) package( - default_visibility = ["//visibility:private"], + default_visibility = ["//visibility:public"], features = ["-parse_headers"], ) @@ -28,7 +28,6 @@ config_setting( define_values = { "USE_MEDIAPIPE_THREADPOOL": "1", }, - visibility = ["//visibility:public"], ) #TODO : remove from OSS. 
@@ -37,13 +36,11 @@ config_setting( define_values = { "USE_MEDIAPIPE_THREADPOOL": "0", }, - visibility = ["//visibility:public"], ) cc_library( name = "aligned_malloc_and_free", hdrs = ["aligned_malloc_and_free.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/deps:aligned_malloc_and_free", "@com_google_absl//absl/base:core_headers", @@ -57,7 +54,6 @@ cc_library( "advanced_proto_inc.h", "proto_ns.h", ], - visibility = ["//visibility:public"], deps = [ ":advanced_proto_lite", ":core_proto", @@ -72,7 +68,6 @@ cc_library( "advanced_proto_lite_inc.h", "proto_ns.h", ], - visibility = ["//visibility:public"], deps = [ ":core_proto", "//mediapipe/framework:port", @@ -83,7 +78,6 @@ cc_library( cc_library( name = "any_proto", hdrs = ["any_proto.h"], - visibility = ["//visibility:public"], deps = [ ":core_proto", ], @@ -94,7 +88,6 @@ cc_library( hdrs = [ "commandlineflags.h", ], - visibility = ["//visibility:public"], deps = [ "//third_party:glog", "@com_google_absl//absl/flags:flag", @@ -107,7 +100,6 @@ cc_library( "core_proto_inc.h", "proto_ns.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "@com_google_protobuf//:protobuf", @@ -117,7 +109,6 @@ cc_library( cc_library( name = "file_helpers", hdrs = ["file_helpers.h"], - visibility = ["//visibility:public"], deps = [ ":status", "//mediapipe/framework/deps:file_helpers", @@ -128,7 +119,6 @@ cc_library( cc_library( name = "image_resizer", hdrs = ["image_resizer.h"], - visibility = ["//visibility:public"], deps = select({ "//conditions:default": [ "//mediapipe/framework/deps:image_resizer", @@ -140,14 +130,12 @@ cc_library( cc_library( name = "integral_types", hdrs = ["integral_types.h"], - visibility = ["//visibility:public"], ) cc_library( name = "benchmark", testonly = 1, hdrs = ["benchmark.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_benchmark//:benchmark", ], @@ -158,7 +146,6 @@ cc_library( hdrs = [ "re2.h", ], - visibility = 
["//visibility:public"], deps = [ "//mediapipe/framework/deps:re2", ], @@ -173,7 +160,6 @@ cc_library( "gtest-spi.h", "status_matchers.h", ], - visibility = ["//visibility:public"], deps = [ ":status_matchers", "//mediapipe/framework/deps:message_matchers", @@ -190,7 +176,6 @@ cc_library( "gtest-spi.h", "status_matchers.h", ], - visibility = ["//visibility:public"], deps = [ ":status_matchers", "//mediapipe/framework/deps:message_matchers", @@ -204,7 +189,6 @@ cc_library( hdrs = [ "logging.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//third_party:glog", @@ -217,7 +201,6 @@ cc_library( hdrs = [ "map_util.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/deps:map_util", @@ -227,7 +210,6 @@ cc_library( cc_library( name = "numbers", hdrs = ["numbers.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:numbers"], ) @@ -238,13 +220,11 @@ config_setting( define_values = { "MEDIAPIPE_DISABLE_OPENCV": "1", }, - visibility = ["//visibility:public"], ) cc_library( name = "opencv_core", hdrs = ["opencv_core_inc.h"], - visibility = ["//visibility:public"], deps = [ "//third_party:opencv", ], @@ -253,7 +233,6 @@ cc_library( cc_library( name = "opencv_imgproc", hdrs = ["opencv_imgproc_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//third_party:opencv", @@ -263,7 +242,6 @@ cc_library( cc_library( name = "opencv_imgcodecs", hdrs = ["opencv_imgcodecs_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//third_party:opencv", @@ -273,7 +251,6 @@ cc_library( cc_library( name = "opencv_highgui", hdrs = ["opencv_highgui_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//third_party:opencv", @@ -283,7 +260,6 @@ cc_library( cc_library( name = "opencv_video", hdrs = ["opencv_video_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//mediapipe/framework:port", @@ 
-294,7 +270,6 @@ cc_library( cc_library( name = "opencv_features2d", hdrs = ["opencv_features2d_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//third_party:opencv", @@ -304,7 +279,6 @@ cc_library( cc_library( name = "opencv_calib3d", hdrs = ["opencv_calib3d_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//third_party:opencv", @@ -314,7 +288,6 @@ cc_library( cc_library( name = "opencv_videoio", hdrs = ["opencv_videoio_inc.h"], - visibility = ["//visibility:public"], deps = [ ":opencv_core", "//mediapipe/framework:port", @@ -328,7 +301,6 @@ cc_library( "parse_text_proto.h", "proto_ns.h", ], - visibility = ["//visibility:public"], deps = [ ":core_proto", ":logging", @@ -339,14 +311,12 @@ cc_library( cc_library( name = "point", hdrs = ["point2.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:point"], ) cc_library( name = "port", hdrs = ["port.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "@com_google_absl//absl/base:core_headers", @@ -356,14 +326,12 @@ cc_library( cc_library( name = "rectangle", hdrs = ["rectangle.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:rectangle"], ) cc_library( name = "ret_check", hdrs = ["ret_check.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/deps:ret_check", @@ -373,7 +341,6 @@ cc_library( cc_library( name = "singleton", hdrs = ["singleton.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:singleton"], ) @@ -382,7 +349,6 @@ cc_library( hdrs = [ "source_location.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/deps:source_location", @@ -397,7 +363,6 @@ cc_library( "status_builder.h", "status_macros.h", ], - visibility = ["//visibility:public"], deps = [ ":source_location", "//mediapipe/framework:port", @@ -412,7 +377,6 @@ cc_library( 
hdrs = [ "statusor.h", ], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "@com_google_absl//absl/status:statusor", @@ -423,7 +387,6 @@ cc_library( name = "status_matchers", testonly = 1, hdrs = ["status_matchers.h"], - visibility = ["//visibility:private"], deps = [ ":status", "@com_google_googletest//:gtest", @@ -433,7 +396,6 @@ cc_library( cc_library( name = "threadpool", hdrs = ["threadpool.h"], - visibility = ["//visibility:public"], deps = select({ "//conditions:default": [":threadpool_impl_default_to_google"], "//mediapipe:android": [":threadpool_impl_default_to_mediapipe"], @@ -460,7 +422,6 @@ alias( cc_library( name = "topologicalsorter", hdrs = ["topologicalsorter.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/deps:topologicalsorter", @@ -470,6 +431,5 @@ cc_library( cc_library( name = "vector", hdrs = ["vector.h"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/deps:vector"], ) diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD index b53a1ac39..2947b9844 100644 --- a/mediapipe/framework/profiler/BUILD +++ b/mediapipe/framework/profiler/BUILD @@ -140,7 +140,7 @@ cc_library( name = "circular_buffer", hdrs = ["circular_buffer.h"], visibility = [ - "//visibility:public", + "//mediapipe:__subpackages__", ], deps = [ "//mediapipe/framework/port:integral_types", @@ -151,7 +151,6 @@ cc_test( name = "circular_buffer_test", size = "small", srcs = ["circular_buffer_test.cc"], - visibility = ["//visibility:public"], deps = [ ":circular_buffer", "//mediapipe/framework/port:gtest_main", @@ -164,7 +163,7 @@ cc_library( name = "trace_buffer", srcs = ["trace_buffer.h"], hdrs = ["trace_buffer.h"], - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework/profiler:__subpackages__"], deps = [ ":circular_buffer", "//mediapipe/framework:calculator_profile_cc_proto", diff --git 
a/mediapipe/framework/tool/testdata/BUILD b/mediapipe/framework/tool/testdata/BUILD index 906688520..f9aab7b72 100644 --- a/mediapipe/framework/tool/testdata/BUILD +++ b/mediapipe/framework/tool/testdata/BUILD @@ -20,7 +20,9 @@ load( licenses(["notice"]) -package(default_visibility = ["//mediapipe:__subpackages__"]) +package(default_visibility = [ + "//mediapipe:__subpackages__", +]) filegroup( name = "test_graph", @@ -40,7 +42,6 @@ mediapipe_simple_subgraph( testonly = 1, graph = "dub_quad_test_subgraph.pbtxt", register_as = "DubQuadTestSubgraph", - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:test_calculators", ], @@ -51,7 +52,7 @@ mediapipe_simple_subgraph( testonly = 1, graph = "nested_test_subgraph.pbtxt", register_as = "NestedTestSubgraph", - visibility = ["//visibility:public"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [ ":dub_quad_test_subgraph", "//mediapipe/framework:test_calculators", From 3c0ddf16b4c2b04cfff07d0db0aba48411468e9c Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 7 Dec 2022 11:37:04 -0800 Subject: [PATCH 27/64] Fix an incorrect model sanity check in the object detector graph. PiperOrigin-RevId: 493663592 --- .../tasks/cc/vision/object_detector/object_detector_graph.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc index f5dc7e061..a1625c16c 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc @@ -532,8 +532,7 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); // Checks that the model has 4 outputs. 
auto& model = *model_resources.GetTfLiteModel(); - if (model.subgraphs()->size() != 1 || - (*model.subgraphs())[0]->outputs()->size() != 4) { + if (model.subgraphs()->size() != 1) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrFormat("Expected a model with a single subgraph, found %d.", From 2811e0c5c81e0ac7d39eab8c32efbe694de45940 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 7 Dec 2022 12:13:25 -0800 Subject: [PATCH 28/64] Open Source the first set of MediaPipe Tasks tests for Web PiperOrigin-RevId: 493673279 --- mediapipe/framework/port/build_config.bzl | 2 + .../tasks/web/components/processors/BUILD | 76 ++++ .../processors/base_options.test.ts | 127 ++++++ .../processors/classifier_options.test.ts | 114 +++++ .../processors/classifier_result.test.ts | 80 ++++ .../processors/embedder_options.test.ts | 93 ++++ .../processors/embedder_result.test.ts | 75 ++++ mediapipe/tasks/web/components/utils/BUILD | 16 + .../utils/cosine_similarity.test.ts | 85 ++++ mediapipe/tasks/web/core/BUILD | 33 ++ mediapipe/tasks/web/core/task_runner.ts | 7 +- mediapipe/tasks/web/core/task_runner_test.ts | 107 +++++ .../tasks/web/core/task_runner_test_utils.ts | 113 +++++ package.json | 5 + tsconfig.json | 2 +- yarn.lock | 419 ++++++++++++++++-- 16 files changed, 1308 insertions(+), 46 deletions(-) create mode 100644 mediapipe/tasks/web/components/processors/base_options.test.ts create mode 100644 mediapipe/tasks/web/components/processors/classifier_options.test.ts create mode 100644 mediapipe/tasks/web/components/processors/classifier_result.test.ts create mode 100644 mediapipe/tasks/web/components/processors/embedder_options.test.ts create mode 100644 mediapipe/tasks/web/components/processors/embedder_result.test.ts create mode 100644 mediapipe/tasks/web/components/utils/cosine_similarity.test.ts create mode 100644 mediapipe/tasks/web/core/task_runner_test.ts create mode 100644 mediapipe/tasks/web/core/task_runner_test_utils.ts diff --git 
a/mediapipe/framework/port/build_config.bzl b/mediapipe/framework/port/build_config.bzl index eaabda856..94a4a5646 100644 --- a/mediapipe/framework/port/build_config.bzl +++ b/mediapipe/framework/port/build_config.bzl @@ -228,6 +228,8 @@ def mediapipe_ts_library( srcs = srcs, visibility = visibility, deps = deps + [ + "@npm//@types/jasmine", + "@npm//@types/node", "@npm//@types/offscreencanvas", "@npm//@types/google-protobuf", ], diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD index 86e743928..148a08238 100644 --- a/mediapipe/tasks/web/components/processors/BUILD +++ b/mediapipe/tasks/web/components/processors/BUILD @@ -1,5 +1,6 @@ # This package contains options shared by all MediaPipe Tasks for Web. +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -13,6 +14,22 @@ mediapipe_ts_library( ], ) +mediapipe_ts_library( + name = "classifier_options_test_lib", + testonly = True, + srcs = ["classifier_options.test.ts"], + deps = [ + ":classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_jspb_proto", + "//mediapipe/tasks/web/core:classifier_options", + ], +) + +jasmine_node_test( + name = "classifier_options_test", + deps = [":classifier_options_test_lib"], +) + mediapipe_ts_library( name = "classifier_result", srcs = ["classifier_result.ts"], @@ -22,6 +39,22 @@ mediapipe_ts_library( ], ) +mediapipe_ts_library( + name = "classifier_result_test_lib", + testonly = True, + srcs = ["classifier_result.test.ts"], + deps = [ + ":classifier_result", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + ], +) + +jasmine_node_test( + name = "classifier_result_test", + deps = [":classifier_result_test_lib"], +) + mediapipe_ts_library( name = 
"embedder_result", srcs = ["embedder_result.ts"], @@ -31,6 +64,21 @@ mediapipe_ts_library( ], ) +mediapipe_ts_library( + name = "embedder_result_test_lib", + testonly = True, + srcs = ["embedder_result.test.ts"], + deps = [ + ":embedder_result", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + ], +) + +jasmine_node_test( + name = "embedder_result_test", + deps = [":embedder_result_test_lib"], +) + mediapipe_ts_library( name = "embedder_options", srcs = ["embedder_options.ts"], @@ -40,6 +88,22 @@ mediapipe_ts_library( ], ) +mediapipe_ts_library( + name = "embedder_options_test_lib", + testonly = True, + srcs = ["embedder_options.test.ts"], + deps = [ + ":embedder_options", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_jspb_proto", + "//mediapipe/tasks/web/core:embedder_options", + ], +) + +jasmine_node_test( + name = "embedder_options_test", + deps = [":embedder_options_test_lib"], +) + mediapipe_ts_library( name = "base_options", srcs = [ @@ -53,3 +117,15 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", ], ) + +mediapipe_ts_library( + name = "base_options_test_lib", + testonly = True, + srcs = ["base_options.test.ts"], + deps = [":base_options"], +) + +jasmine_node_test( + name = "base_options_test", + deps = [":base_options_test_lib"], +) diff --git a/mediapipe/tasks/web/components/processors/base_options.test.ts b/mediapipe/tasks/web/components/processors/base_options.test.ts new file mode 100644 index 000000000..46c2277e9 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/base_options.test.ts @@ -0,0 +1,127 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +// Placeholder for internal dependency on trusted resource URL builder + +import {convertBaseOptionsToProto} from './base_options'; + +describe('convertBaseOptionsToProto()', () => { + const mockBytes = new Uint8Array([0, 1, 2, 3]); + const mockBytesResult = { + modelAsset: { + fileContent: Buffer.from(mockBytes).toString('base64'), + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined, + }, + useStreamMode: false, + acceleration: { + xnnpack: undefined, + gpu: undefined, + tflite: {}, + }, + }; + + let fetchSpy: jasmine.Spy; + + beforeEach(() => { + fetchSpy = jasmine.createSpy().and.callFake(async url => { + expect(url).toEqual('foo'); + return { + arrayBuffer: () => mockBytes.buffer, + } as unknown as Response; + }); + global.fetch = fetchSpy; + }); + + it('verifies that at least one model asset option is provided', async () => { + await expectAsync(convertBaseOptionsToProto({})) + .toBeRejectedWithError( + /Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/); + }); + + it('verifies that no more than one model asset option is provided', async () => { + await expectAsync(convertBaseOptionsToProto({ + modelAssetPath: `foo`, + modelAssetBuffer: new Uint8Array([]) + })) + .toBeRejectedWithError( + /Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/); + }); + + it('downloads model', async () => { + const baseOptionsProto = await convertBaseOptionsToProto({ + modelAssetPath: `foo`, 
+ }); + + expect(fetchSpy).toHaveBeenCalled(); + expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); + }); + + it('does not download model when bytes are provided', async () => { + const baseOptionsProto = await convertBaseOptionsToProto({ + modelAssetBuffer: new Uint8Array(mockBytes), + }); + + expect(fetchSpy).not.toHaveBeenCalled(); + expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); + }); + + it('can enable CPU delegate', async () => { + const baseOptionsProto = await convertBaseOptionsToProto({ + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'cpu', + }); + expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); + }); + + it('can enable GPU delegate', async () => { + const baseOptionsProto = await convertBaseOptionsToProto({ + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'gpu', + }); + expect(baseOptionsProto.toObject()).toEqual({ + ...mockBytesResult, + acceleration: { + xnnpack: undefined, + gpu: { + useAdvancedGpuApi: false, + api: 0, + allowPrecisionLoss: true, + cachedKernelPath: undefined, + serializedModelDir: undefined, + modelToken: undefined, + usage: 2, + }, + tflite: undefined, + }, + }); + }); + + it('can reset delegate', async () => { + let baseOptionsProto = await convertBaseOptionsToProto({ + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'gpu', + }); + // Clear backend + baseOptionsProto = + await convertBaseOptionsToProto({delegate: undefined}, baseOptionsProto); + expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); + }); +}); diff --git a/mediapipe/tasks/web/components/processors/classifier_options.test.ts b/mediapipe/tasks/web/components/processors/classifier_options.test.ts new file mode 100644 index 000000000..928bda426 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/classifier_options.test.ts @@ -0,0 +1,114 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {ClassifierOptions as ClassifierOptionsProto} from '../../../../tasks/cc/components/processors/proto/classifier_options_pb'; +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; + +import {convertClassifierOptionsToProto} from './classifier_options'; + +interface TestCase { + optionName: keyof ClassifierOptions; + protoName: string; + customValue: unknown; + defaultValue: unknown; +} + +describe('convertClassifierOptionsToProto()', () => { + function verifyOption( + actualClassifierOptions: ClassifierOptionsProto, + expectedClassifierOptions: Record = {}): void { + expect(actualClassifierOptions.toObject()) + .toEqual(jasmine.objectContaining(expectedClassifierOptions)); + } + + const testCases: TestCase[] = [ + { + optionName: 'maxResults', + protoName: 'maxResults', + customValue: 5, + defaultValue: -1 + }, + { + optionName: 'displayNamesLocale', + protoName: 'displayNamesLocale', + customValue: 'en', + defaultValue: 'en' + }, + { + optionName: 'scoreThreshold', + protoName: 'scoreThreshold', + customValue: 0.1, + defaultValue: undefined + }, + { + optionName: 'categoryAllowlist', + protoName: 'categoryAllowlistList', + customValue: ['foo'], + defaultValue: [] + }, + { + optionName: 'categoryDenylist', + protoName: 'categoryDenylistList', + customValue: ['bar'], + defaultValue: [] + }, + ]; + + for (const testCase of testCases) { + it(`can set 
${testCase.optionName}`, () => { + const classifierOptionsProto = convertClassifierOptionsToProto( + {[testCase.optionName]: testCase.customValue}); + verifyOption( + classifierOptionsProto, {[testCase.protoName]: testCase.customValue}); + }); + + it(`can clear ${testCase.optionName}`, () => { + let classifierOptionsProto = convertClassifierOptionsToProto( + {[testCase.optionName]: testCase.customValue}); + verifyOption( + classifierOptionsProto, {[testCase.protoName]: testCase.customValue}); + + classifierOptionsProto = + convertClassifierOptionsToProto({[testCase.optionName]: undefined}); + verifyOption( + classifierOptionsProto, + {[testCase.protoName]: testCase.defaultValue}); + }); + } + + it('overwrites options', () => { + let classifierOptionsProto = + convertClassifierOptionsToProto({maxResults: 1}); + verifyOption(classifierOptionsProto, {'maxResults': 1}); + + classifierOptionsProto = convertClassifierOptionsToProto( + {maxResults: 2}, classifierOptionsProto); + verifyOption(classifierOptionsProto, {'maxResults': 2}); + }); + + it('merges options', () => { + let classifierOptionsProto = + convertClassifierOptionsToProto({maxResults: 1}); + verifyOption(classifierOptionsProto, {'maxResults': 1}); + + classifierOptionsProto = convertClassifierOptionsToProto( + {displayNamesLocale: 'en'}, classifierOptionsProto); + verifyOption( + classifierOptionsProto, {'maxResults': 1, 'displayNamesLocale': 'en'}); + }); +}); diff --git a/mediapipe/tasks/web/components/processors/classifier_result.test.ts b/mediapipe/tasks/web/components/processors/classifier_result.test.ts new file mode 100644 index 000000000..4b93d0a76 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/classifier_result.test.ts @@ -0,0 +1,80 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; + +import {convertFromClassificationResultProto} from './classifier_result'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +describe('convertFromClassificationResultProto()', () => { + it('transforms custom values', () => { + const classificationResult = new ClassificationResult(); + classificationResult.setTimestampMs(1); + const classifcations = new Classifications(); + classifcations.setHeadIndex(1); + classifcations.setHeadName('headName'); + const classificationList = new ClassificationList(); + const clasification = new Classification(); + clasification.setIndex(2); + clasification.setScore(0.3); + clasification.setDisplayName('displayName'); + clasification.setLabel('categoryName'); + classificationList.addClassification(clasification); + classifcations.setClassificationList(classificationList); + classificationResult.addClassifications(classifcations); + + const result = convertFromClassificationResultProto(classificationResult); + + expect(result).toEqual({ + classifications: [{ + categories: [{ + index: 2, + score: 0.3, + displayName: 'displayName', + categoryName: 'categoryName' + }], + headIndex: 1, + headName: 'headName' + }], + timestampMs: 1 + }); + }); + + it('transforms default values', () => { + const classificationResult = 
new ClassificationResult(); + const classifcations = new Classifications(); + const classificationList = new ClassificationList(); + const clasification = new Classification(); + classificationList.addClassification(clasification); + classifcations.setClassificationList(classificationList); + classificationResult.addClassifications(classifcations); + + const result = convertFromClassificationResultProto(classificationResult); + + expect(result).toEqual({ + classifications: [{ + categories: [{index: 0, score: 0, displayName: '', categoryName: ''}], + headIndex: 0, + headName: '' + }], + }); + }); +}); diff --git a/mediapipe/tasks/web/components/processors/embedder_options.test.ts b/mediapipe/tasks/web/components/processors/embedder_options.test.ts new file mode 100644 index 000000000..b879a6b29 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/embedder_options.test.ts @@ -0,0 +1,93 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import 'jasmine'; + +import {EmbedderOptions as EmbedderOptionsProto} from '../../../../tasks/cc/components/processors/proto/embedder_options_pb'; +import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; + +import {convertEmbedderOptionsToProto} from './embedder_options'; + +interface TestCase { + optionName: keyof EmbedderOptions; + protoName: string; + customValue: unknown; + defaultValue: unknown; +} + +describe('convertEmbedderOptionsToProto()', () => { + function verifyOption( + actualEmbedderOptions: EmbedderOptionsProto, + expectedEmbedderOptions: Record = {}): void { + expect(actualEmbedderOptions.toObject()) + .toEqual(jasmine.objectContaining(expectedEmbedderOptions)); + } + + const testCases: TestCase[] = [ + { + optionName: 'l2Normalize', + protoName: 'l2Normalize', + customValue: true, + defaultValue: undefined + }, + { + optionName: 'quantize', + protoName: 'quantize', + customValue: true, + defaultValue: undefined + }, + ]; + + for (const testCase of testCases) { + it(`can set ${testCase.optionName}`, () => { + const embedderOptionsProto = convertEmbedderOptionsToProto( + {[testCase.optionName]: testCase.customValue}); + verifyOption( + embedderOptionsProto, {[testCase.protoName]: testCase.customValue}); + }); + + it(`can clear ${testCase.optionName}`, () => { + let embedderOptionsProto = convertEmbedderOptionsToProto( + {[testCase.optionName]: testCase.customValue}); + verifyOption( + embedderOptionsProto, {[testCase.protoName]: testCase.customValue}); + + embedderOptionsProto = + convertEmbedderOptionsToProto({[testCase.optionName]: undefined}); + verifyOption( + embedderOptionsProto, {[testCase.protoName]: testCase.defaultValue}); + }); + } + + it('overwrites options', () => { + let embedderOptionsProto = + convertEmbedderOptionsToProto({l2Normalize: true}); + verifyOption(embedderOptionsProto, {'l2Normalize': true}); + + embedderOptionsProto = convertEmbedderOptionsToProto( + {l2Normalize: false}, 
embedderOptionsProto); + verifyOption(embedderOptionsProto, {'l2Normalize': false}); + }); + + it('replaces options', () => { + let embedderOptionsProto = convertEmbedderOptionsToProto({quantize: true}); + verifyOption(embedderOptionsProto, {'quantize': true}); + + embedderOptionsProto = convertEmbedderOptionsToProto( + {l2Normalize: true}, embedderOptionsProto); + verifyOption(embedderOptionsProto, {'l2Normalize': true, 'quantize': true}); + }); +}); diff --git a/mediapipe/tasks/web/components/processors/embedder_result.test.ts b/mediapipe/tasks/web/components/processors/embedder_result.test.ts new file mode 100644 index 000000000..97ba935c8 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/embedder_result.test.ts @@ -0,0 +1,75 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {Embedding, EmbeddingResult, FloatEmbedding, QuantizedEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; + +import {convertFromEmbeddingResultProto} from './embedder_result'; + +// The OSS JS API does not support the builder pattern. 
+// tslint:disable:jspb-use-builder-pattern + +describe('convertFromEmbeddingResultProto()', () => { + it('transforms custom values', () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const floatEmbedding = new FloatEmbedding(); + floatEmbedding.setValuesList([0.1, 0.9]); + + embedding.setFloatEmbedding(floatEmbedding); + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + resultProto.setTimestampMs(1); + + const embedderResult = convertFromEmbeddingResultProto(resultProto); + const embeddings = embedderResult.embeddings; + const timestampMs = embedderResult.timestampMs; + expect(embeddings.length).toEqual(1); + expect(embeddings[0]) + .toEqual( + {floatEmbedding: [0.1, 0.9], headIndex: 1, headName: 'headName'}); + expect(timestampMs).toEqual(1); + }); + + it('transforms custom quantized values', () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const quantizedEmbedding = new QuantizedEmbedding(); + const quantizedValues = new Uint8Array([1, 2, 3]); + quantizedEmbedding.setValues(quantizedValues); + + embedding.setQuantizedEmbedding(quantizedEmbedding); + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + resultProto.setTimestampMs(1); + + const embedderResult = convertFromEmbeddingResultProto(resultProto); + const embeddings = embedderResult.embeddings; + const timestampMs = embedderResult.timestampMs; + expect(embeddings.length).toEqual(1); + expect(embeddings[0]).toEqual({ + quantizedEmbedding: new Uint8Array([1, 2, 3]), + headIndex: 1, + headName: 'headName' + }); + expect(timestampMs).toEqual(1); + }); +}); diff --git a/mediapipe/tasks/web/components/utils/BUILD b/mediapipe/tasks/web/components/utils/BUILD index 1c1ba69ca..f4a215e48 100644 --- a/mediapipe/tasks/web/components/utils/BUILD +++ b/mediapipe/tasks/web/components/utils/BUILD @@ -1,4 +1,5 @@ 
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -9,3 +10,18 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/containers:embedding_result", ], ) + +mediapipe_ts_library( + name = "cosine_similarity_test_lib", + testonly = True, + srcs = ["cosine_similarity.test.ts"], + deps = [ + ":cosine_similarity", + "//mediapipe/tasks/web/components/containers:embedding_result", + ], +) + +jasmine_node_test( + name = "cosine_similarity_test", + deps = [":cosine_similarity_test_lib"], +) diff --git a/mediapipe/tasks/web/components/utils/cosine_similarity.test.ts b/mediapipe/tasks/web/components/utils/cosine_similarity.test.ts new file mode 100644 index 000000000..f442caa20 --- /dev/null +++ b/mediapipe/tasks/web/components/utils/cosine_similarity.test.ts @@ -0,0 +1,85 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + *

Licensed under the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. You may obtain a + * copy of the License at + * + *

http://www.apache.org/licenses/LICENSE-2.0 + * + *

Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; + +import {computeCosineSimilarity} from './cosine_similarity'; + +describe('computeCosineSimilarity', () => { + it('fails with quantized and float embeddings', () => { + const u: Embedding = {floatEmbedding: [1.0], headIndex: 0, headName: ''}; + const v: Embedding = { + quantizedEmbedding: new Uint8Array([1.0]), + headIndex: 0, + headName: '' + }; + + expect(() => computeCosineSimilarity(u, v)) + .toThrowError( + /Cannot compute cosine similarity between quantized and float embeddings/); + }); + + it('fails with zero norm', () => { + const u = {floatEmbedding: [0.0], headIndex: 0, headName: ''}; + expect(() => computeCosineSimilarity(u, u)) + .toThrowError( + /Cannot compute cosine similarity on embedding with 0 norm/); + }); + + it('fails with different sizes', () => { + const u: + Embedding = {floatEmbedding: [1.0, 2.0], headIndex: 0, headName: ''}; + const v: Embedding = { + floatEmbedding: [1.0, 2.0, 3.0], + headIndex: 0, + headName: '' + }; + + expect(() => computeCosineSimilarity(u, v)) + .toThrowError( + /Cannot compute cosine similarity between embeddings of different sizes/); + }); + + it('succeeds with float embeddings', () => { + const u: Embedding = { + floatEmbedding: [1.0, 0.0, 0.0, 0.0], + headIndex: 0, + headName: '' + }; + const v: Embedding = { + floatEmbedding: [0.5, 0.5, 0.5, 0.5], + headIndex: 0, + headName: '' + }; + + expect(computeCosineSimilarity(u, v)).toEqual(0.5); + }); + + it('succeeds with quantized embeddings', () => { + const u: Embedding = { + quantizedEmbedding: new Uint8Array([255, 128, 128, 128]), + headIndex: 0, + 
headName: '' + }; + const v: Embedding = { + quantizedEmbedding: new Uint8Array([0, 128, 128, 128]), + headIndex: 0, + headName: '' + }; + + expect(computeCosineSimilarity(u, v)).toEqual(-1.0); + }); +}); diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index be1b71f5d..1721661f5 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -1,6 +1,7 @@ # This package contains options shared by all MediaPipe Tasks for Web. load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -32,6 +33,38 @@ mediapipe_ts_library( deps = [":core"], ) +mediapipe_ts_library( + name = "task_runner_test_utils", + testonly = True, + srcs = [ + "task_runner_test_utils.ts", + ], + deps = [ + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/web/graph_runner:graph_runner_ts", + "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", + ], +) + +mediapipe_ts_library( + name = "task_runner_test_lib", + testonly = True, + srcs = [ + "task_runner_test.ts", + ], + deps = [ + ":task_runner", + ":task_runner_test_utils", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +jasmine_node_test( + name = "task_runner_test", + deps = [":task_runner_test_lib"], +) + mediapipe_ts_declaration( name = "classifier_options", srcs = ["classifier_options.d.ts"], diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index 71e159dce..6712c4d89 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -77,9 +77,10 @@ export abstract class TaskRunner { } constructor( - wasmModule: WasmModule, - glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - this.graphRunner = new GraphRunnerImageLib(wasmModule, glCanvas); 
+ wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null, + graphRunner?: GraphRunnerImageLib) { + this.graphRunner = + graphRunner ?? new GraphRunnerImageLib(wasmModule, glCanvas); // Disables the automatic render-to-screen code, which allows for pure // CPU processing. diff --git a/mediapipe/tasks/web/core/task_runner_test.ts b/mediapipe/tasks/web/core/task_runner_test.ts new file mode 100644 index 000000000..c9aad9d25 --- /dev/null +++ b/mediapipe/tasks/web/core/task_runner_test.ts @@ -0,0 +1,107 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +import 'jasmine'; + +import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; +import {TaskRunner} from '../../../tasks/web/core/task_runner'; +import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils'; +import {ErrorListener} from '../../../web/graph_runner/graph_runner'; + +import {GraphRunnerImageLib} from './task_runner'; + +class TaskRunnerFake extends TaskRunner { + protected baseOptions = new BaseOptionsProto(); + private errorListener: ErrorListener|undefined; + private errors: string[] = []; + + static createFake(): TaskRunnerFake { + const wasmModule = createSpyWasmModule(); + return new TaskRunnerFake(wasmModule); + } + + constructor(wasmModuleFake: SpyWasmModule) { + super( + wasmModuleFake, /* glCanvas= */ null, + jasmine.createSpyObj([ + 'setAutoRenderToScreen', 'setGraph', 'finishProcessing', + 'registerModelResourcesGraphService', 'attachErrorListener' + ])); + const graphRunner = this.graphRunner as jasmine.SpyObj; + expect(graphRunner.registerModelResourcesGraphService).toHaveBeenCalled(); + expect(graphRunner.setAutoRenderToScreen).toHaveBeenCalled(); + graphRunner.attachErrorListener.and.callFake(listener => { + this.errorListener = listener; + }); + graphRunner.setGraph.and.callFake(() => { + this.throwErrors(); + }); + graphRunner.finishProcessing.and.callFake(() => { + this.throwErrors(); + }); + } + + enqueueError(message: string): void { + this.errors.push(message); + } + + override finishProcessing(): void { + super.finishProcessing(); + } + + override setGraph(graphData: Uint8Array, isBinary: boolean): void { + super.setGraph(graphData, isBinary); + } + + private throwErrors(): void { + expect(this.errorListener).toBeDefined(); + for (const error of this.errors) { + this.errorListener!(/* errorCode= */ -1, error); + } + this.errors = []; + } +} + +describe('TaskRunner', () => { + it('handles errors during graph update', () => { + const taskRunner = 
TaskRunnerFake.createFake(); + taskRunner.enqueueError('Test error'); + + expect(() => { + taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); + }).toThrowError('Test error'); + }); + + it('handles errors during graph execution', () => { + const taskRunner = TaskRunnerFake.createFake(); + taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); + + taskRunner.enqueueError('Test error'); + + expect(() => { + taskRunner.finishProcessing(); + }).toThrowError('Test error'); + }); + + it('can handle multiple errors', () => { + const taskRunner = TaskRunnerFake.createFake(); + taskRunner.enqueueError('Test error 1'); + taskRunner.enqueueError('Test error 2'); + + expect(() => { + taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); + }).toThrowError(/Test error 1, Test error 2/); + }); +}); diff --git a/mediapipe/tasks/web/core/task_runner_test_utils.ts b/mediapipe/tasks/web/core/task_runner_test_utils.ts new file mode 100644 index 000000000..2a1161a55 --- /dev/null +++ b/mediapipe/tasks/web/core/task_runner_test_utils.ts @@ -0,0 +1,113 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+import 'jasmine';
+
+import {CalculatorGraphConfig} from '../../../framework/calculator_pb';
+import {WasmModule} from '../../../web/graph_runner/graph_runner';
+import {WasmModuleRegisterModelResources} from '../../../web/graph_runner/register_model_resources_graph_service';
+
+type SpyWasmModuleInternal = WasmModule&WasmModuleRegisterModelResources;
+
+/**
+ * Convenience type for our fake WasmModule for Jasmine testing.
+ */
+export declare type SpyWasmModule = jasmine.SpyObj;
+
+/**
+ * Factory function for creating a fake WasmModule for our Jasmine tests,
+ * allowing our APIs to no longer rely on the Wasm layer so they can run tests
+ * in pure JS/TS (and optionally spy on the calls).
+ */
+export function createSpyWasmModule(): SpyWasmModule {
+  return jasmine.createSpyObj([
+    '_setAutoRenderToScreen', 'stringToNewUTF8', '_attachProtoListener',
+    '_attachProtoVectorListener', '_free', '_waitUntilIdle',
+    '_addStringToInputStream', '_registerModelResourcesGraphService',
+    '_configureAudio'
+  ]);
+}
+
+/**
+ * Sets up our equality testing to use a custom float equality checking function
+ * to avoid incorrect test results due to minor floating point inaccuracies.
+ */
+export function addJasmineCustomFloatEqualityTester() {
+  jasmine.addCustomEqualityTester((a, b) => {  // Custom float equality
+    if (a === +a && b === +b && (a !== (a | 0) || b !== (b | 0))) {
+      return Math.abs(a - b) < 5e-8;
+    }
+    return;
+  });
+}
+
+/** The minimum interface provided by a test fake. */
+export interface MediapipeTasksFake {
+  graph: CalculatorGraphConfig|undefined;
+  calculatorName: string;
+  attachListenerSpies: jasmine.Spy[];
+}
+
+/** A map of field paths to values */
+export type FieldPathToValue = [string[] | string, unknown];
+
+/**
+ * Verifies that the graph has been initialized and that it contains the
+ * provided options.
+ */ +export function verifyGraph( + tasksFake: MediapipeTasksFake, + expectedCalculatorOptions?: FieldPathToValue, + expectedBaseOptions?: FieldPathToValue, + ): void { + expect(tasksFake.graph).toBeDefined(); + expect(tasksFake.graph!.getNodeList().length).toBe(1); + const node = tasksFake.graph!.getNodeList()[0].toObject(); + expect(node).toEqual( + jasmine.objectContaining({calculator: tasksFake.calculatorName})); + + if (expectedBaseOptions) { + const [fieldPath, value] = expectedBaseOptions; + let proto = (node.options as {ext: {baseOptions: unknown}}).ext.baseOptions; + for (const fieldName of ( + Array.isArray(fieldPath) ? fieldPath : [fieldPath])) { + proto = ((proto ?? {}) as Record)[fieldName]; + } + expect(proto).toEqual(value); + } + + if (expectedCalculatorOptions) { + const [fieldPath, value] = expectedCalculatorOptions; + let proto = (node.options as {ext: unknown}).ext; + for (const fieldName of ( + Array.isArray(fieldPath) ? fieldPath : [fieldPath])) { + proto = ((proto ?? {}) as Record)[fieldName]; + } + expect(proto).toEqual(value); + } +} + +/** + * Verifies all listeners (as exposed by `.attachListenerSpies`) have been + * attached at least once since the last call to `verifyListenersRegistered()`. + * This helps us to ensure that listeners are re-registered with every graph + * update. 
+ */ +export function verifyListenersRegistered(tasksFake: MediapipeTasksFake): void { + for (const spy of tasksFake.attachListenerSpies) { + expect(spy.calls.count()).toBeGreaterThanOrEqual(1); + spy.calls.reset(); + } +} diff --git a/package.json b/package.json index 6ad0b52c0..89b62bc83 100644 --- a/package.json +++ b/package.json @@ -3,14 +3,19 @@ "version": "0.0.0-alphga", "description": "MediaPipe GitHub repo", "devDependencies": { + "@bazel/jasmine": "^5.7.2", "@bazel/rollup": "^5.7.1", "@bazel/typescript": "^5.7.1", "@rollup/plugin-commonjs": "^23.0.2", "@rollup/plugin-node-resolve": "^15.0.1", "@rollup/plugin-terser": "^0.1.0", "@types/google-protobuf": "^3.15.6", + "@types/jasmine": "^4.3.1", + "@types/node": "^18.11.11", "@types/offscreencanvas": "^2019.7.0", "google-protobuf": "^3.21.2", + "jasmine": "^4.5.0", + "jasmine-core": "^4.5.0", "protobufjs": "^7.1.2", "protobufjs-cli": "^1.0.2", "rollup": "^2.3.0", diff --git a/tsconfig.json b/tsconfig.json index c17b1902e..970246dbb 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -10,7 +10,7 @@ "inlineSourceMap": true, "inlineSources": true, "strict": true, - "types": ["@types/offscreencanvas"], + "types": ["@types/offscreencanvas", "@types/jasmine", "node"], "rootDirs": [ ".", "./bazel-out/host/bin", diff --git a/yarn.lock b/yarn.lock index 91b50456e..9c4d91d30 100644 --- a/yarn.lock +++ b/yarn.lock @@ -3,34 +3,52 @@ "@babel/parser@^7.9.4": - version "7.20.3" - resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.20.3.tgz#5358cf62e380cf69efcb87a7bb922ff88bfac6e2" - integrity sha512-OP/s5a94frIPXwjzEcv5S/tpQfc6XhxYUnmWpgdqMWGgYCuErA3SzozaRAMQgSZWKeTJxht9aWAkUY+0UzvOFg== + version "7.20.5" + resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.20.5.tgz#7f3c7335fe417665d929f34ae5dceae4c04015e8" + integrity sha512-r27t/cy/m9uKLXQNWWebeCUHgnAZq0CpG1OwKRxzJMP1vpSU4bSIK2hq+/cp0bQxetkXx38n09rNu8jVkcK/zA== + +"@bazel/jasmine@^5.7.2": + version "5.7.2" + resolved 
"https://registry.yarnpkg.com/@bazel/jasmine/-/jasmine-5.7.2.tgz#438f272e66e939106cbdd58db709cd6aa008131b" + integrity sha512-RJruOB6S9e0efTNIE2JVdaslguUXh5KcmLUCq/xLCt0zENP44ssp9OooDIrZ8H+Sp4mLDNBX7CMMA9WTsbsxTQ== + dependencies: + c8 "~7.5.0" + jasmine-reporters "~2.5.0" "@bazel/rollup@^5.7.1": - version "5.7.1" - resolved "https://registry.yarnpkg.com/@bazel/rollup/-/rollup-5.7.1.tgz#6f644c2d493a5bd9cd3724a6f239e609585c6e37" - integrity sha512-LLNogoK2Qx9GIJVywQ+V/czjud8236mnaRX//g7qbOyXoWZDQvAEgsxRHq+lS/XX9USbh+zJJlfb+Dfp/PXx4A== + version "5.7.2" + resolved "https://registry.yarnpkg.com/@bazel/rollup/-/rollup-5.7.2.tgz#9953b06e3de52794791cee4f89540c263b035fcf" + integrity sha512-yGWLheSKdMnJ/Y3/qg+zCDx/qkD04FBFp+BjRS8xP4yvlz9G4rW3zc45VzHHz3oOywgQaY1vhfKuZMCcjTGEyA== dependencies: - "@bazel/worker" "5.7.1" + "@bazel/worker" "5.7.2" "@bazel/typescript@^5.7.1": - version "5.7.1" - resolved "https://registry.yarnpkg.com/@bazel/typescript/-/typescript-5.7.1.tgz#e585bcdc54a4ccb23d99c3e1206abf4853cf0682" - integrity sha512-MAnAtFxA2znadm81+rbYXcyWX1DEF/urzZ1F4LBq+w27EQ4PGyqIqCM5om7JcoSZJwjjMoBJc3SflRsMrZZ6+g== + version "5.7.2" + resolved "https://registry.yarnpkg.com/@bazel/typescript/-/typescript-5.7.2.tgz#a341215dc93ce28794e8430b311756816140bd78" + integrity sha512-tarBJBEIirnq/YaeYu18vXcDxjzlq4xhCXvXUxA0lhHX5oArjEcAEn4tmO0jF+t/7cbkAdMT7daG6vIHSz0QAA== dependencies: - "@bazel/worker" "5.7.1" + "@bazel/worker" "5.7.2" semver "5.6.0" source-map-support "0.5.9" tsutils "3.21.0" -"@bazel/worker@5.7.1": - version "5.7.1" - resolved "https://registry.yarnpkg.com/@bazel/worker/-/worker-5.7.1.tgz#2c4a9bd0e0ef75e496aec9599ff64a87307e7dad" - integrity sha512-UndmQVRqK0t0NMNl8I1P5XmxzdPvMA0X6jufszpfwy5gyzjOxeiOIzmC0ALCOx78CuJqOB/8WOI1pwTRmhd0tg== +"@bazel/worker@5.7.2": + version "5.7.2" + resolved "https://registry.yarnpkg.com/@bazel/worker/-/worker-5.7.2.tgz#43d800dc1b5a3707340a4eb0102da81c53fc6f63" + integrity 
sha512-H+auDA0QKF4mtZxKkZ2OKJvD7hGXVsVKtvcf4lbb93ur0ldpb5k810PcDxngmIGBcIX5kmyxniNTIiGFNobWTg== dependencies: google-protobuf "^3.6.1" +"@bcoe/v8-coverage@^0.2.3": + version "0.2.3" + resolved "https://registry.yarnpkg.com/@bcoe/v8-coverage/-/v8-coverage-0.2.3.tgz#75a2e8b51cb758a7553d6804a5932d7aace75c39" + integrity sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw== + +"@istanbuljs/schema@^0.1.2": + version "0.1.3" + resolved "https://registry.yarnpkg.com/@istanbuljs/schema/-/schema-0.1.3.tgz#e45e384e4b8ec16bce2fd903af78450f6bf7ec98" + integrity sha512-ZXRY4jNvVgSVQ8DL3LTcakaAtXwTVUxE81hslsyD2AtoXW/wVob10HkOJ1X/pAlcI7D+2YoZKg5do8G/w6RYgA== + "@jridgewell/gen-mapping@^0.3.0": version "0.3.2" resolved "https://registry.yarnpkg.com/@jridgewell/gen-mapping/-/gen-mapping-0.3.2.tgz#c1aedc61e853f2bb9f5dfe6d4442d3b565b253b9" @@ -125,9 +143,9 @@ integrity sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw== "@rollup/plugin-commonjs@^23.0.2": - version "23.0.2" - resolved "https://registry.yarnpkg.com/@rollup/plugin-commonjs/-/plugin-commonjs-23.0.2.tgz#3a3a5b7b1b1cb29037eb4992edcaae997d7ebd92" - integrity sha512-e9ThuiRf93YlVxc4qNIurvv+Hp9dnD+4PjOqQs5vAYfcZ3+AXSrcdzXnVjWxcGQOa6KGJFcRZyUI3ktWLavFjg== + version "23.0.3" + resolved "https://registry.yarnpkg.com/@rollup/plugin-commonjs/-/plugin-commonjs-23.0.3.tgz#442cd8ccca1b7563a503da86fc84a1a7112b54bb" + integrity sha512-31HxrT5emGfTyIfAs1lDQHj6EfYxTXcwtX5pIIhq+B/xZBNIqQ179d/CkYxlpYmFCxT78AeU4M8aL8Iv/IBxFA== dependencies: "@rollup/pluginutils" "^5.0.1" commondir "^1.0.1" @@ -174,6 +192,21 @@ resolved "https://registry.yarnpkg.com/@types/google-protobuf/-/google-protobuf-3.15.6.tgz#674a69493ef2c849b95eafe69167ea59079eb504" integrity sha512-pYVNNJ+winC4aek+lZp93sIKxnXt5qMkuKmaqS3WGuTq0Bw1ZDYNBgzG5kkdtwcv+GmYJGo3yEg6z2cKKAiEdw== +"@types/is-windows@^1.0.0": + version "1.0.0" + resolved 
"https://registry.yarnpkg.com/@types/is-windows/-/is-windows-1.0.0.tgz#1011fa129d87091e2f6faf9042d6704cdf2e7be0" + integrity sha512-tJ1rq04tGKuIJoWIH0Gyuwv4RQ3+tIu7wQrC0MV47raQ44kIzXSSFKfrxFUOWVRvesoF7mrTqigXmqoZJsXwTg== + +"@types/istanbul-lib-coverage@^2.0.1": + version "2.0.4" + resolved "https://registry.yarnpkg.com/@types/istanbul-lib-coverage/-/istanbul-lib-coverage-2.0.4.tgz#8467d4b3c087805d63580480890791277ce35c44" + integrity sha512-z/QT1XN4K4KYuslS23k62yDIDLwLFkzxOuMplDtObz0+y7VqJCaO2o+SPwHCvLFZh7xazvvoor2tA/hPz9ee7g== + +"@types/jasmine@^4.3.1": + version "4.3.1" + resolved "https://registry.yarnpkg.com/@types/jasmine/-/jasmine-4.3.1.tgz#2d8ab5601c2fe7d9673dcb157e03f128ab5c5fff" + integrity sha512-Vu8l+UGcshYmV1VWwULgnV/2RDbBaO6i2Ptx7nd//oJPIZGhoI1YLST4VKagD2Pq/Bc2/7zvtvhM7F3p4SN7kQ== + "@types/linkify-it@*": version "3.0.2" resolved "https://registry.yarnpkg.com/@types/linkify-it/-/linkify-it-3.0.2.tgz#fd2cd2edbaa7eaac7e7f3c1748b52a19143846c9" @@ -192,10 +225,10 @@ resolved "https://registry.yarnpkg.com/@types/mdurl/-/mdurl-1.0.2.tgz#e2ce9d83a613bacf284c7be7d491945e39e1f8e9" integrity sha512-eC4U9MlIcu2q0KQmXszyn5Akca/0jrQmwDRgpAMJai7qBWq4amIQhZyNau4VYGtCeALvW1/NtjzJJ567aZxfKA== -"@types/node@>=13.7.0": - version "18.11.9" - resolved "https://registry.yarnpkg.com/@types/node/-/node-18.11.9.tgz#02d013de7058cea16d36168ef2fc653464cfbad4" - integrity sha512-CRpX21/kGdzjOpFsZSkcrXMGIBWMGNIHXXBVFSH+ggkftxg+XYP20TESbh+zFvFj3EQOl5byk0HTRn1IL6hbqg== +"@types/node@>=13.7.0", "@types/node@^18.11.11": + version "18.11.11" + resolved "https://registry.yarnpkg.com/@types/node/-/node-18.11.11.tgz#1d455ac0211549a8409d3cdb371cd55cc971e8dc" + integrity sha512-KJ021B1nlQUBLopzZmPBVuGU9un7WJd/W4ya7Ih02B4Uwky5Nja0yGYav2EfYIk0RR2Q9oVhf60S2XR1BCWJ2g== "@types/offscreencanvas@^2019.7.0": version "2019.7.0" @@ -207,6 +240,11 @@ resolved "https://registry.yarnpkg.com/@types/resolve/-/resolve-1.20.2.tgz#97d26e00cd4a0423b4af620abecf3e6f442b7975" integrity 
sha512-60BCwRFOZCQhDncwQdxxeOEEkbc5dIMccYLwbxsS4TUNeVECQ/pBJ0j09mrHOl/JJvpRPGwO9SvE4nR2Nb/a4Q== +"@xmldom/xmldom@^0.8.5": + version "0.8.6" + resolved "https://registry.yarnpkg.com/@xmldom/xmldom/-/xmldom-0.8.6.tgz#8a1524eb5bd5e965c1e3735476f0262469f71440" + integrity sha512-uRjjusqpoqfmRkTaNuLJ2VohVr67Q5YwDATW3VU7PfzTj6IRaihGrYI7zckGZjxQPBIp63nfvJbM+Yu5ICh0Bg== + acorn-jsx@^5.3.2: version "5.3.2" resolved "https://registry.yarnpkg.com/acorn-jsx/-/acorn-jsx-5.3.2.tgz#7ed5bb55908b3b2f1bc55c6af1653bada7f07937" @@ -217,7 +255,12 @@ acorn@^8.5.0, acorn@^8.8.0: resolved "https://registry.yarnpkg.com/acorn/-/acorn-8.8.1.tgz#0a3f9cbecc4ec3bea6f0a80b66ae8dd2da250b73" integrity sha512-7zFpHzhnqYKrkYdUjF1HI1bzd0VygEGX8lFk4k5zVMqHEoES+P+7TKI+EvLO9WVMJ8eekdO0aDEK044xTXwPPA== -ansi-styles@^4.1.0: +ansi-regex@^5.0.1: + version "5.0.1" + resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-5.0.1.tgz#082cb2c89c9fe8659a311a53bd6a4dc5301db304" + integrity sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ== + +ansi-styles@^4.0.0, ansi-styles@^4.1.0: version "4.3.0" resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-4.3.0.tgz#edd803628ae71c04c85ae7a0906edad34b648937" integrity sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg== @@ -264,6 +307,25 @@ builtin-modules@^3.3.0: resolved "https://registry.yarnpkg.com/builtin-modules/-/builtin-modules-3.3.0.tgz#cae62812b89801e9656336e46223e030386be7b6" integrity sha512-zhaCDicdLuWN5UbN5IMnFqNMhNfo919sH85y2/ea+5Yg9TsTkeZxpL+JLbp6cgYFS4sRLp3YV4S6yDuqVWHYOw== +c8@~7.5.0: + version "7.5.0" + resolved "https://registry.yarnpkg.com/c8/-/c8-7.5.0.tgz#a69439ab82848f344a74bb25dc5dd4e867764481" + integrity sha512-GSkLsbvDr+FIwjNSJ8OwzWAyuznEYGTAd1pzb/Kr0FMLuV4vqYJTyjboDTwmlUNAG6jAU3PFWzqIdKrOt1D8tw== + dependencies: + "@bcoe/v8-coverage" "^0.2.3" + "@istanbuljs/schema" "^0.1.2" + find-up "^5.0.0" + foreground-child "^2.0.0" + furi "^2.0.0" + 
istanbul-lib-coverage "^3.0.0" + istanbul-lib-report "^3.0.0" + istanbul-reports "^3.0.2" + rimraf "^3.0.0" + test-exclude "^6.0.0" + v8-to-istanbul "^7.1.0" + yargs "^16.0.0" + yargs-parser "^20.0.0" + catharsis@^0.9.0: version "0.9.0" resolved "https://registry.yarnpkg.com/catharsis/-/catharsis-0.9.0.tgz#40382a168be0e6da308c277d3a2b3eb40c7d2121" @@ -279,6 +341,15 @@ chalk@^4.0.0: ansi-styles "^4.1.0" supports-color "^7.1.0" +cliui@^7.0.2: + version "7.0.4" + resolved "https://registry.yarnpkg.com/cliui/-/cliui-7.0.4.tgz#a0265ee655476fc807aea9df3df8df7783808b4f" + integrity sha512-OcRE68cOsVMXp1Yvonl/fzkQOyjLSu/8bhPDfQt0e0/Eb283TKP20Fs2MqoPsr9SwA595rRCA+QMzYc9nBP+JQ== + dependencies: + string-width "^4.2.0" + strip-ansi "^6.0.0" + wrap-ansi "^7.0.0" + color-convert@^2.0.1: version "2.0.1" resolved "https://registry.yarnpkg.com/color-convert/-/color-convert-2.0.1.tgz#72d3a68d598c9bdb3af2ad1e84f21d896abd4de3" @@ -306,6 +377,20 @@ concat-map@0.0.1: resolved "https://registry.yarnpkg.com/concat-map/-/concat-map-0.0.1.tgz#d8a96bd77fd68df7793a73036a3ba0d5405d477b" integrity sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg== +convert-source-map@^1.6.0: + version "1.9.0" + resolved "https://registry.yarnpkg.com/convert-source-map/-/convert-source-map-1.9.0.tgz#7faae62353fb4213366d0ca98358d22e8368b05f" + integrity sha512-ASFBup0Mz1uyiIjANan1jzLQami9z1PoYSZCiiYW2FczPbenXc45FZdBZLzOT+r6+iciuEModtmCti+hjaAk0A== + +cross-spawn@^7.0.0: + version "7.0.3" + resolved "https://registry.yarnpkg.com/cross-spawn/-/cross-spawn-7.0.3.tgz#f73a85b9d5d41d045551c177e2882d4ac85728a6" + integrity sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w== + dependencies: + path-key "^3.1.0" + shebang-command "^2.0.0" + which "^2.0.1" + deep-is@~0.1.3: version "0.1.4" resolved "https://registry.yarnpkg.com/deep-is/-/deep-is-0.1.4.tgz#a6f2dce612fadd2ef1f519b73551f17e85199831" @@ -316,11 +401,21 @@ deepmerge@^4.2.2: 
resolved "https://registry.yarnpkg.com/deepmerge/-/deepmerge-4.2.2.tgz#44d2ea3679b8f4d4ffba33f03d865fc1e7bf4955" integrity sha512-FJ3UgI4gIl+PHZm53knsuSFpE+nESMr7M4v9QcgB7S63Kj/6WqMiFQJpBBYz1Pt+66bZpP3Q7Lye0Oo9MPKEdg== +emoji-regex@^8.0.0: + version "8.0.0" + resolved "https://registry.yarnpkg.com/emoji-regex/-/emoji-regex-8.0.0.tgz#e818fd69ce5ccfcb404594f842963bf53164cc37" + integrity sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A== + entities@~2.1.0: version "2.1.0" resolved "https://registry.yarnpkg.com/entities/-/entities-2.1.0.tgz#992d3129cf7df6870b96c57858c249a120f8b8b5" integrity sha512-hCx1oky9PFrJ611mf0ifBLBRW8lUUVRlFolb5gWRfIELabBlbp9xZvrqZLZAs+NxFnbfQoeGd8wDkygjg7U85w== +escalade@^3.1.1: + version "3.1.1" + resolved "https://registry.yarnpkg.com/escalade/-/escalade-3.1.1.tgz#d8cfdc7000965c5a0174b4a82eaa5c0552742e40" + integrity sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw== + escape-string-regexp@^2.0.0: version "2.0.0" resolved "https://registry.yarnpkg.com/escape-string-regexp/-/escape-string-regexp-2.0.0.tgz#a30304e99daa32e23b2fd20f51babd07cffca344" @@ -382,6 +477,22 @@ fast-levenshtein@~2.0.6: resolved "https://registry.yarnpkg.com/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz#3d8a5c66883a16a30ca8643e851f19baa7797917" integrity sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw== +find-up@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/find-up/-/find-up-5.0.0.tgz#4c92819ecb7083561e4f4a240a86be5198f536fc" + integrity sha512-78/PXT1wlLLDgTzDs7sjq9hzz0vXD+zn+7wypEe4fXQxCmdmqfGsEPQxmiCSQI3ajFV91bVSsvNtrJRiW6nGng== + dependencies: + locate-path "^6.0.0" + path-exists "^4.0.0" + +foreground-child@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/foreground-child/-/foreground-child-2.0.0.tgz#71b32800c9f15aa8f2f83f4a6bd9bff35d861a53" + integrity 
sha512-dCIq9FpEcyQyXKCkyzmlPTFNgrCzPudOe+mhvJU5zAtlBnGVy2yKxtfsxK2tQBThwq225jcvBjpw1Gr40uzZCA== + dependencies: + cross-spawn "^7.0.0" + signal-exit "^3.0.2" + fs.realpath@^1.0.0: version "1.0.0" resolved "https://registry.yarnpkg.com/fs.realpath/-/fs.realpath-1.0.0.tgz#1504ad2523158caa40db4a2787cb01411994ea4f" @@ -397,7 +508,20 @@ function-bind@^1.1.1: resolved "https://registry.yarnpkg.com/function-bind/-/function-bind-1.1.1.tgz#a56899d3ea3c9bab874bb9773b7c5ede92f4895d" integrity sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A== -glob@^7.1.3: +furi@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/furi/-/furi-2.0.0.tgz#13d85826a1af21acc691da6254b3888fc39f0b4a" + integrity sha512-uKuNsaU0WVaK/vmvj23wW1bicOFfyqSsAIH71bRZx8kA4Xj+YCHin7CJKJJjkIsmxYaPFLk9ljmjEyB7xF7WvQ== + dependencies: + "@types/is-windows" "^1.0.0" + is-windows "^1.0.2" + +get-caller-file@^2.0.5: + version "2.0.5" + resolved "https://registry.yarnpkg.com/get-caller-file/-/get-caller-file-2.0.5.tgz#4f94412a82db32f36e3b0b9741f8a97feb031f7e" + integrity sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg== + +glob@^7.1.3, glob@^7.1.4, glob@^7.1.6: version "7.2.3" resolved "https://registry.yarnpkg.com/glob/-/glob-7.2.3.tgz#b8df0fb802bbfa8e89bd1d938b4e16578ed44f2b" integrity sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q== @@ -442,6 +566,11 @@ has@^1.0.3: dependencies: function-bind "^1.1.1" +html-escaper@^2.0.0: + version "2.0.2" + resolved "https://registry.yarnpkg.com/html-escaper/-/html-escaper-2.0.2.tgz#dfd60027da36a36dfcbe236262c00a5822681453" + integrity sha512-H2iMtd0I4Mt5eYiapRdIDjp+XzelXQ0tFE4JS7YFwFevXXMmOp9myNrUvCg0D6ws8iqkRPBfKHgbwig1SmlLfg== + inflight@^1.0.4: version "1.0.6" resolved "https://registry.yarnpkg.com/inflight/-/inflight-1.0.6.tgz#49bd6331d7d02d0c09bc910a1075ba8165b56df9" @@ -469,6 +598,11 @@ is-core-module@^2.9.0: 
dependencies: has "^1.0.3" +is-fullwidth-code-point@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz#f116f8064fe90b3f7844a38997c0b75051269f1d" + integrity sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg== + is-module@^1.0.0: version "1.0.0" resolved "https://registry.yarnpkg.com/is-module/-/is-module-1.0.0.tgz#3258fb69f78c14d5b815d664336b4cffb6441591" @@ -481,6 +615,59 @@ is-reference@1.2.1: dependencies: "@types/estree" "*" +is-windows@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/is-windows/-/is-windows-1.0.2.tgz#d1850eb9791ecd18e6182ce12a30f396634bb19d" + integrity sha512-eXK1UInq2bPmjyX6e3VHIzMLobc4J94i4AWn+Hpq3OU5KkrRC96OAcR3PRJ/pGu6m8TRnBHP9dkXQVsT/COVIA== + +isexe@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/isexe/-/isexe-2.0.0.tgz#e8fbf374dc556ff8947a10dcb0572d633f2cfa10" + integrity sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw== + +istanbul-lib-coverage@^3.0.0: + version "3.2.0" + resolved "https://registry.yarnpkg.com/istanbul-lib-coverage/-/istanbul-lib-coverage-3.2.0.tgz#189e7909d0a39fa5a3dfad5b03f71947770191d3" + integrity sha512-eOeJ5BHCmHYvQK7xt9GkdHuzuCGS1Y6g9Gvnx3Ym33fz/HpLRYxiS0wHNr+m/MBC8B647Xt608vCDEvhl9c6Mw== + +istanbul-lib-report@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/istanbul-lib-report/-/istanbul-lib-report-3.0.0.tgz#7518fe52ea44de372f460a76b5ecda9ffb73d8a6" + integrity sha512-wcdi+uAKzfiGT2abPpKZ0hSU1rGQjUQnLvtY5MpQ7QCTahD3VODhcu4wcfY1YtkGaDD5yuydOLINXsfbus9ROw== + dependencies: + istanbul-lib-coverage "^3.0.0" + make-dir "^3.0.0" + supports-color "^7.1.0" + +istanbul-reports@^3.0.2: + version "3.1.5" + resolved "https://registry.yarnpkg.com/istanbul-reports/-/istanbul-reports-3.1.5.tgz#cc9a6ab25cb25659810e4785ed9d9fb742578bae" + integrity 
sha512-nUsEMa9pBt/NOHqbcbeJEgqIlY/K7rVWUX6Lql2orY5e9roQOthbR3vtY4zzf2orPELg80fnxxk9zUyPlgwD1w== + dependencies: + html-escaper "^2.0.0" + istanbul-lib-report "^3.0.0" + +jasmine-core@^4.5.0: + version "4.5.0" + resolved "https://registry.yarnpkg.com/jasmine-core/-/jasmine-core-4.5.0.tgz#1a6bd0bde3f60996164311c88a0995d67ceda7c3" + integrity sha512-9PMzyvhtocxb3aXJVOPqBDswdgyAeSB81QnLop4npOpbqnheaTEwPc9ZloQeVswugPManznQBjD8kWDTjlnHuw== + +jasmine-reporters@~2.5.0: + version "2.5.2" + resolved "https://registry.yarnpkg.com/jasmine-reporters/-/jasmine-reporters-2.5.2.tgz#b5dfa1d9c40b8020c5225e0e1e2b9953d66a4d69" + integrity sha512-qdewRUuFOSiWhiyWZX8Yx3YNQ9JG51ntBEO4ekLQRpktxFTwUHy24a86zD/Oi2BRTKksEdfWQZcQFqzjqIkPig== + dependencies: + "@xmldom/xmldom" "^0.8.5" + mkdirp "^1.0.4" + +jasmine@^4.5.0: + version "4.5.0" + resolved "https://registry.yarnpkg.com/jasmine/-/jasmine-4.5.0.tgz#8d3c0d0a33a61e4d05c9f9747ee5dfaf6f7b5d3d" + integrity sha512-9olGRvNZyADIwYL9XBNBst5BTU/YaePzuddK+YRslc7rI9MdTIE4r3xaBKbv2GEmzYYUfMOdTR8/i6JfLZaxSQ== + dependencies: + glob "^7.1.6" + jasmine-core "^4.5.0" + js2xmlparser@^4.0.2: version "4.0.2" resolved "https://registry.yarnpkg.com/js2xmlparser/-/js2xmlparser-4.0.2.tgz#2a1fdf01e90585ef2ae872a01bc169c6a8d5e60a" @@ -531,7 +718,14 @@ linkify-it@^3.0.1: dependencies: uc.micro "^1.0.1" -lodash@^4.17.14, lodash@^4.17.15: +locate-path@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/locate-path/-/locate-path-6.0.0.tgz#55321eb309febbc59c4801d931a72452a681d286" + integrity sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw== + dependencies: + p-locate "^5.0.0" + +lodash@^4.17.15, lodash@^4.17.21: version "4.17.21" resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.21.tgz#679591c564c3bffaae8454cf0b3df370c3d6911c" integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg== @@ -555,6 +749,13 @@ magic-string@^0.26.4: dependencies: 
sourcemap-codec "^1.4.8" +make-dir@^3.0.0: + version "3.1.0" + resolved "https://registry.yarnpkg.com/make-dir/-/make-dir-3.1.0.tgz#415e967046b3a7f1d185277d84aa58203726a13f" + integrity sha512-g3FeP20LNwhALb/6Cz6Dd4F2ngze0jz7tbzrD2wAV+o9FeNHe4rL+yK2md0J/fiSf1sa1ADhXqi5+oVwOM/eGw== + dependencies: + semver "^6.0.0" + markdown-it-anchor@^8.4.1: version "8.6.5" resolved "https://registry.yarnpkg.com/markdown-it-anchor/-/markdown-it-anchor-8.6.5.tgz#30c4bc5bbff327f15ce3c429010ec7ba75e7b5f8" @@ -572,16 +773,16 @@ markdown-it@^12.3.2: uc.micro "^1.0.5" marked@^4.0.10: - version "4.2.2" - resolved "https://registry.yarnpkg.com/marked/-/marked-4.2.2.tgz#1d2075ad6cdfe42e651ac221c32d949a26c0672a" - integrity sha512-JjBTFTAvuTgANXx82a5vzK9JLSMoV6V3LBVn4Uhdso6t7vXrGx7g1Cd2r6NYSsxrYbQGFCMqBDhFHyK5q2UvcQ== + version "4.2.3" + resolved "https://registry.yarnpkg.com/marked/-/marked-4.2.3.tgz#bd76a5eb510ff1d8421bc6c3b2f0b93488c15bea" + integrity sha512-slWRdJkbTZ+PjkyJnE30Uid64eHwbwa1Q25INCAYfZlK4o6ylagBy/Le9eWntqJFoFT93ikUKMv47GZ4gTwHkw== mdurl@^1.0.1: version "1.0.1" resolved "https://registry.yarnpkg.com/mdurl/-/mdurl-1.0.1.tgz#fe85b2ec75a59037f2adfec100fd6c601761152e" integrity sha512-/sKlQJCBYVY9Ers9hqzKou4H6V5UWc/M59TH2dvkt+84itfnq7uFOMLpOiOS4ujvHP4etln18fmIxA5R5fll0g== -minimatch@^3.1.1: +minimatch@^3.0.4, minimatch@^3.1.1: version "3.1.2" resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-3.1.2.tgz#19cd194bfd3e428f049a70817c038d89ab4be35b" integrity sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw== @@ -589,9 +790,9 @@ minimatch@^3.1.1: brace-expansion "^1.1.7" minimatch@^5.0.1: - version "5.1.0" - resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-5.1.0.tgz#1717b464f4971b144f6aabe8f2d0b8e4511e09c7" - integrity sha512-9TPBGGak4nHfGZsPBohm9AWg6NoT7QTCehS3BIJABslyZbzxfV78QM2Y6+i741OPZIafFAaiiEMh5OyIrJPgtg== + version "5.1.1" + resolved 
"https://registry.yarnpkg.com/minimatch/-/minimatch-5.1.1.tgz#6c9dffcf9927ff2a31e74b5af11adf8b9604b022" + integrity sha512-362NP+zlprccbEt/SkxKfRMHnNY85V74mVnpUpNyr3F35covl09Kec7/sEFLt3RA4oXmewtoaanoIf67SE5Y5g== dependencies: brace-expansion "^2.0.1" @@ -624,11 +825,35 @@ optionator@^0.8.1: type-check "~0.3.2" word-wrap "~1.2.3" +p-limit@^3.0.2: + version "3.1.0" + resolved "https://registry.yarnpkg.com/p-limit/-/p-limit-3.1.0.tgz#e1daccbe78d0d1388ca18c64fea38e3e57e3706b" + integrity sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ== + dependencies: + yocto-queue "^0.1.0" + +p-locate@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/p-locate/-/p-locate-5.0.0.tgz#83c8315c6785005e3bd021839411c9e110e6d834" + integrity sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw== + dependencies: + p-limit "^3.0.2" + +path-exists@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/path-exists/-/path-exists-4.0.0.tgz#513bdbe2d3b95d7762e8c1137efa195c6c61b5b3" + integrity sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w== + path-is-absolute@^1.0.0: version "1.0.1" resolved "https://registry.yarnpkg.com/path-is-absolute/-/path-is-absolute-1.0.1.tgz#174b9268735534ffbc7ace6bf53a5a9e1b5c5f5f" integrity sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg== +path-key@^3.1.0: + version "3.1.1" + resolved "https://registry.yarnpkg.com/path-key/-/path-key-3.1.1.tgz#581f6ade658cbba65a0d3380de7753295054f375" + integrity sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q== + path-parse@^1.0.7: version "1.0.7" resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.7.tgz#fbc114b60ca42b30d9daf5858e4bd68bbedb6735" @@ -678,12 +903,17 @@ protobufjs@^7.1.2: "@types/node" ">=13.7.0" long "^5.0.0" +require-directory@^2.1.1: + version "2.1.1" + 
resolved "https://registry.yarnpkg.com/require-directory/-/require-directory-2.1.1.tgz#8c64ad5fd30dab1c976e2344ffe7f792a6a6df42" + integrity sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q== + requizzle@^0.2.3: - version "0.2.3" - resolved "https://registry.yarnpkg.com/requizzle/-/requizzle-0.2.3.tgz#4675c90aacafb2c036bd39ba2daa4a1cb777fded" - integrity sha512-YanoyJjykPxGHii0fZP0uUPEXpvqfBDxWV7s6GKAiiOsiqhX6vHNyW3Qzdmqp/iq/ExbhaGbVrjB4ruEVSM4GQ== + version "0.2.4" + resolved "https://registry.yarnpkg.com/requizzle/-/requizzle-0.2.4.tgz#319eb658b28c370f0c20f968fa8ceab98c13d27c" + integrity sha512-JRrFk1D4OQ4SqovXOgdav+K8EAhSB/LJZqCz8tbX0KObcdeM15Ss59ozWMBWmmINMagCwmqn4ZNryUGpBsl6Jw== dependencies: - lodash "^4.17.14" + lodash "^4.17.21" resolve@^1.22.1: version "1.22.1" @@ -713,6 +943,11 @@ semver@5.6.0: resolved "https://registry.yarnpkg.com/semver/-/semver-5.6.0.tgz#7e74256fbaa49c75aa7c7a205cc22799cac80004" integrity sha512-RS9R6R35NYgQn++fkDWaOmqGoj4Ek9gGs+DPxNUZKuwE183xjJroKvyo1IzVFeXvUrvmALy6FWD5xrdJT25gMg== +semver@^6.0.0: + version "6.3.0" + resolved "https://registry.yarnpkg.com/semver/-/semver-6.3.0.tgz#ee0a64c8af5e8ceea67687b133761e1becbd1d3d" + integrity sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw== + semver@^7.1.2: version "7.3.8" resolved "https://registry.yarnpkg.com/semver/-/semver-7.3.8.tgz#07a78feafb3f7b32347d725e33de7e2a2df67798" @@ -720,6 +955,23 @@ semver@^7.1.2: dependencies: lru-cache "^6.0.0" +shebang-command@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/shebang-command/-/shebang-command-2.0.0.tgz#ccd0af4f8835fbdc265b82461aaf0c36663f34ea" + integrity sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA== + dependencies: + shebang-regex "^3.0.0" + +shebang-regex@^3.0.0: + version "3.0.0" + resolved 
"https://registry.yarnpkg.com/shebang-regex/-/shebang-regex-3.0.0.tgz#ae16f1644d873ecad843b0307b143362d4c42172" + integrity sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A== + +signal-exit@^3.0.2: + version "3.0.7" + resolved "https://registry.yarnpkg.com/signal-exit/-/signal-exit-3.0.7.tgz#a9a1767f8af84155114eaabd73f99273c8f59ad9" + integrity sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ== + source-map-support@0.5.9: version "0.5.9" resolved "https://registry.yarnpkg.com/source-map-support/-/source-map-support-0.5.9.tgz#41bc953b2534267ea2d605bccfa7bfa3111ced5f" @@ -741,11 +993,32 @@ source-map@^0.6.0, source-map@~0.6.1: resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.6.1.tgz#74722af32e9614e9c287a8d0bbde48b5e2f1a263" integrity sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g== +source-map@^0.7.3: + version "0.7.4" + resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.7.4.tgz#a9bbe705c9d8846f4e08ff6765acf0f1b0898656" + integrity sha512-l3BikUxvPOcn5E74dZiq5BGsTb5yEwhaTSzccU6t4sDOH8NWJCstKO5QT2CvtFoK6F0saL7p9xHAqHOlCPJygA== + sourcemap-codec@^1.4.8: version "1.4.8" resolved "https://registry.yarnpkg.com/sourcemap-codec/-/sourcemap-codec-1.4.8.tgz#ea804bd94857402e6992d05a38ef1ae35a9ab4c4" integrity sha512-9NykojV5Uih4lgo5So5dtw+f0JgJX30KCNI8gwhz2J9A15wD0Ml6tjHKwf6fTSa6fAdVBdZeNOs9eJ71qCk8vA== +string-width@^4.1.0, string-width@^4.2.0: + version "4.2.3" + resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010" + integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g== + dependencies: + emoji-regex "^8.0.0" + is-fullwidth-code-point "^3.0.0" + strip-ansi "^6.0.1" + +strip-ansi@^6.0.0, strip-ansi@^6.0.1: + version "6.0.1" + resolved 
"https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-6.0.1.tgz#9e26c63d30f53443e9489495b2105d37b67a85d9" + integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A== + dependencies: + ansi-regex "^5.0.1" + strip-json-comments@^3.1.0: version "3.1.1" resolved "https://registry.yarnpkg.com/strip-json-comments/-/strip-json-comments-3.1.1.tgz#31f1281b3832630434831c310c01cccda8cbe006" @@ -769,15 +1042,24 @@ taffydb@2.6.2: integrity sha512-y3JaeRSplks6NYQuCOj3ZFMO3j60rTwbuKCvZxsAraGYH2epusatvZ0baZYA01WsGqJBq/Dl6vOrMUJqyMj8kA== terser@^5.15.1: - version "5.15.1" - resolved "https://registry.yarnpkg.com/terser/-/terser-5.15.1.tgz#8561af6e0fd6d839669c73b92bdd5777d870ed6c" - integrity sha512-K1faMUvpm/FBxjBXud0LWVAGxmvoPbZbfTCYbSgaaYQaIXI3/TdI7a7ZGA73Zrou6Q8Zmz3oeUTsp/dj+ag2Xw== + version "5.16.1" + resolved "https://registry.yarnpkg.com/terser/-/terser-5.16.1.tgz#5af3bc3d0f24241c7fb2024199d5c461a1075880" + integrity sha512-xvQfyfA1ayT0qdK47zskQgRZeWLoOQ8JQ6mIgRGVNwZKdQMU+5FkCBjmv4QjcrTzyZquRw2FVtlJSRUmMKQslw== dependencies: "@jridgewell/source-map" "^0.3.2" acorn "^8.5.0" commander "^2.20.0" source-map-support "~0.5.20" +test-exclude@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/test-exclude/-/test-exclude-6.0.0.tgz#04a8698661d805ea6fa293b6cb9e63ac044ef15e" + integrity sha512-cAGWPIyOHU6zlmg88jwm7VRyXnMN7iV68OGAbYDk/Mh/xC/pzVPlQtY6ngoIH/5/tciuhGfvESU8GrHrcxD56w== + dependencies: + "@istanbuljs/schema" "^0.1.2" + glob "^7.1.4" + minimatch "^3.0.4" + tmp@^0.2.1: version "0.2.1" resolved "https://registry.yarnpkg.com/tmp/-/tmp-0.2.1.tgz#8457fc3037dcf4719c251367a1af6500ee1ccf14" @@ -812,9 +1094,9 @@ type-check@~0.3.2: prelude-ls "~1.1.2" typescript@^4.8.4: - version "4.8.4" - resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.8.4.tgz#c464abca159669597be5f96b8943500b238e60e6" - integrity sha512-QCh+85mCy+h0IGff8r5XWzOVSbBO+KfeYrMQh7NJ58QujwcE22u+NUSmUxqF+un70P9GXKxa2HCNiTTMJknyjQ== + version 
"4.9.3" + resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.9.3.tgz#3aea307c1746b8c384435d8ac36b8a2e580d85db" + integrity sha512-CIfGzTelbKNEnLpLdGFgdyKhG23CKdKgQPOBc+OUNrkJ2vr+KSzsSV5kq5iWhEQbok+quxgGzrAtGWCyU7tHnA== uc.micro@^1.0.1, uc.micro@^1.0.5: version "1.0.6" @@ -831,11 +1113,36 @@ underscore@~1.13.2: resolved "https://registry.yarnpkg.com/underscore/-/underscore-1.13.6.tgz#04786a1f589dc6c09f761fc5f45b89e935136441" integrity sha512-+A5Sja4HP1M08MaXya7p5LvjuM7K6q/2EaC0+iovj/wOcMsTzMvDFbasi/oSapiwOlt252IqsKqPjCl7huKS0A== +v8-to-istanbul@^7.1.0: + version "7.1.2" + resolved "https://registry.yarnpkg.com/v8-to-istanbul/-/v8-to-istanbul-7.1.2.tgz#30898d1a7fa0c84d225a2c1434fb958f290883c1" + integrity sha512-TxNb7YEUwkLXCQYeudi6lgQ/SZrzNO4kMdlqVxaZPUIUjCv6iSSypUQX70kNBSERpQ8fk48+d61FXk+tgqcWow== + dependencies: + "@types/istanbul-lib-coverage" "^2.0.1" + convert-source-map "^1.6.0" + source-map "^0.7.3" + +which@^2.0.1: + version "2.0.2" + resolved "https://registry.yarnpkg.com/which/-/which-2.0.2.tgz#7c6a8dd0a636a0327e10b59c9286eee93f3f51b1" + integrity sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA== + dependencies: + isexe "^2.0.0" + word-wrap@~1.2.3: version "1.2.3" resolved "https://registry.yarnpkg.com/word-wrap/-/word-wrap-1.2.3.tgz#610636f6b1f703891bd34771ccb17fb93b47079c" integrity sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ== +wrap-ansi@^7.0.0: + version "7.0.0" + resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-7.0.0.tgz#67e145cff510a6a6984bdf1152911d69d2eb9e43" + integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q== + dependencies: + ansi-styles "^4.0.0" + string-width "^4.1.0" + strip-ansi "^6.0.0" + wrappy@1: version "1.0.2" resolved "https://registry.yarnpkg.com/wrappy/-/wrappy-1.0.2.tgz#b5243d8f3ec1aa35f1364605bc0d1036e30ab69f" @@ -846,7 +1153,35 @@ xmlcreate@^2.0.4: resolved 
"https://registry.yarnpkg.com/xmlcreate/-/xmlcreate-2.0.4.tgz#0c5ab0f99cdd02a81065fa9cd8f8ae87624889be" integrity sha512-nquOebG4sngPmGPICTS5EnxqhKbCmz5Ox5hsszI2T6U5qdrJizBc+0ilYSEjTSzU0yZcmvppztXe/5Al5fUwdg== +y18n@^5.0.5: + version "5.0.8" + resolved "https://registry.yarnpkg.com/y18n/-/y18n-5.0.8.tgz#7f4934d0f7ca8c56f95314939ddcd2dd91ce1d55" + integrity sha512-0pfFzegeDWJHJIAmTLRP2DwHjdF5s7jo9tuztdQxAhINCdvS+3nGINqPd00AphqJR/0LhANUS6/+7SCb98YOfA== + yallist@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/yallist/-/yallist-4.0.0.tgz#9bb92790d9c0effec63be73519e11a35019a3a72" integrity sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A== + +yargs-parser@^20.0.0, yargs-parser@^20.2.2: + version "20.2.9" + resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-20.2.9.tgz#2eb7dc3b0289718fc295f362753845c41a0c94ee" + integrity sha512-y11nGElTIV+CT3Zv9t7VKl+Q3hTQoT9a1Qzezhhl6Rp21gJ/IVTW7Z3y9EWXhuUBC2Shnf+DX0antecpAwSP8w== + +yargs@^16.0.0: + version "16.2.0" + resolved "https://registry.yarnpkg.com/yargs/-/yargs-16.2.0.tgz#1c82bf0f6b6a66eafce7ef30e376f49a12477f66" + integrity sha512-D1mvvtDG0L5ft/jGWkLpG1+m0eQxOfaBvTNELraWj22wSVUMWxZUvYgJYcKh6jGGIkJFhH4IZPQhR4TKpc8mBw== + dependencies: + cliui "^7.0.2" + escalade "^3.1.1" + get-caller-file "^2.0.5" + require-directory "^2.1.1" + string-width "^4.2.0" + y18n "^5.0.5" + yargs-parser "^20.2.2" + +yocto-queue@^0.1.0: + version "0.1.0" + resolved "https://registry.yarnpkg.com/yocto-queue/-/yocto-queue-0.1.0.tgz#0294eb3dee05028d31ee1a5fa2c556a6aaf10a1b" + integrity sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q== From 955f090f9f0e69300c3bc331c52a426b0dec5dab Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 7 Dec 2022 13:06:57 -0800 Subject: [PATCH 29/64] Retire the visibility group "//mediapipe/framework:mediapipe_internal". 
PiperOrigin-RevId: 493687025 --- mediapipe/framework/profiler/BUILD | 4 +--- mediapipe/framework/tool/BUILD | 19 +++++++++---------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD index 2947b9844..3b6976fc8 100644 --- a/mediapipe/framework/profiler/BUILD +++ b/mediapipe/framework/profiler/BUILD @@ -291,9 +291,7 @@ cc_library( "-ObjC++", ], }), - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], + visibility = ["//visibility:private"], deps = [ "@com_google_absl//absl/flags:flag", "//mediapipe/framework/port:logging", diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 52d04b4b1..453b5a0e8 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -90,7 +90,7 @@ mediapipe_proto_library( name = "packet_generator_wrapper_calculator_proto", srcs = ["packet_generator_wrapper_calculator.proto"], def_py_proto = False, - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:packet_generator_proto", @@ -120,13 +120,13 @@ cc_library( name = "fill_packet_set", srcs = ["fill_packet_set.cc"], hdrs = ["fill_packet_set.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [ + ":status_util", "//mediapipe/framework:packet_set", "//mediapipe/framework:packet_type", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", - "//mediapipe/framework/tool:status_util", "@com_google_absl//absl/memory", ], ) @@ -162,7 +162,6 @@ cc_library( cc_test( name = "executor_util_test", srcs = ["executor_util_test.cc"], - visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":executor_util", "//mediapipe/framework/port:gtest_main", @@ -173,7 +172,7 @@ cc_test( cc_library( name = 
"options_map", hdrs = ["options_map.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//mediapipe:__subpackages__"], deps = [ ":type_util", "//mediapipe/framework:calculator_cc_proto", @@ -193,7 +192,7 @@ cc_library( name = "options_field_util", srcs = ["options_field_util.cc"], hdrs = ["options_field_util.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//visibility:private"], deps = [ ":field_data_cc_proto", ":name_util", @@ -216,7 +215,7 @@ cc_library( name = "options_syntax_util", srcs = ["options_syntax_util.cc"], hdrs = ["options_syntax_util.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//visibility:private"], deps = [ ":name_util", ":options_field_util", @@ -235,8 +234,9 @@ cc_library( name = "options_util", srcs = ["options_util.cc"], hdrs = ["options_util.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//visibility:public"], deps = [ + ":name_util", ":options_field_util", ":options_map", ":options_registry", @@ -254,7 +254,6 @@ cc_library( "//mediapipe/framework/port:advanced_proto", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:name_util", "@com_google_absl//absl/strings", ], ) @@ -323,7 +322,7 @@ mediapipe_cc_test( cc_library( name = "packet_generator_wrapper_calculator", srcs = ["packet_generator_wrapper_calculator.cc"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//mediapipe/framework:__subpackages__"], deps = [ ":packet_generator_wrapper_calculator_cc_proto", "//mediapipe/framework:calculator_base", From ea74db86dd2926278c9d2486bf58e23caa1a97a6 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Wed, 7 Dec 2022 14:04:37 -0800 Subject: [PATCH 30/64] Tensor: clang tidy fixes. 
PiperOrigin-RevId: 493703073 --- mediapipe/framework/formats/tensor.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mediapipe/framework/formats/tensor.cc b/mediapipe/framework/formats/tensor.cc index c31eba350..9e1406dbb 100644 --- a/mediapipe/framework/formats/tensor.cc +++ b/mediapipe/framework/formats/tensor.cc @@ -551,7 +551,7 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const { }); } else #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 - + { // Transfer data from texture if not transferred from SSBO/MTLBuffer // yet. if (valid_ & kValidOpenGlTexture2d) { @@ -582,6 +582,7 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const { } }); } + } #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 valid_ |= kValidCpu; } From 7faee517c4606e647ae63ae4296fae54d08f6abb Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Wed, 7 Dec 2022 14:31:02 -0800 Subject: [PATCH 31/64] Tensor: Move general CPU/SSBO tensor storage into Ahwb-backed CPU/SSBO storage. 
PiperOrigin-RevId: 493710495 --- mediapipe/framework/formats/tensor.h | 1 + mediapipe/framework/formats/tensor_ahwb.cc | 40 +++++++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index 3ed72c6fd..151aa299d 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -418,6 +418,7 @@ class Tensor { void ReleaseAhwbStuff(); void* MapAhwbToCpuRead() const; void* MapAhwbToCpuWrite() const; + void MoveCpuOrSsboToAhwb() const; #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 mutable std::shared_ptr gl_context_; diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index 90d89c40a..21bae9593 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -215,10 +215,15 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferReadView() const { CHECK(ahwb_ || !(valid_ & kValidOpenGlBuffer)) << "Interoperability bettween OpenGL buffer and AHardwareBuffer is not " "supported on targe system."; + bool transfer = !ahwb_; CHECK(AllocateAHardwareBuffer()) << "AHardwareBuffer is not supported on the target system."; valid_ |= kValidAHardwareBuffer; - if (valid_ & kValidOpenGlBuffer) CreateEglSyncAndFd(); + if (transfer) { + MoveCpuOrSsboToAhwb(); + } else { + if (valid_ & kValidOpenGlBuffer) CreateEglSyncAndFd(); + } return {ahwb_, ssbo_written_, &fence_fd_, // The FD is created for SSBO -> AHWB synchronization. @@ -303,6 +308,39 @@ bool Tensor::AllocateAhwbMapToSsbo() const { return false; } +// Moves Cpu/Ssbo resource under the Ahwb backed memory. 
+void Tensor::MoveCpuOrSsboToAhwb() const { + void* dest = nullptr; + if (__builtin_available(android 26, *)) { + auto error = AHardwareBuffer_lock( + ahwb_, AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY, -1, nullptr, &dest); + CHECK(error == 0) << "AHardwareBuffer_lock " << error; + } + if (valid_ & kValidOpenGlBuffer) { + gl_context_->Run([this, dest]() { + glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); + const void* src = glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(), + GL_MAP_READ_BIT); + std::memcpy(dest, src, bytes()); + glUnmapBuffer(GL_SHADER_STORAGE_BUFFER); + glDeleteBuffers(1, &opengl_buffer_); + }); + opengl_buffer_ = GL_INVALID_INDEX; + gl_context_ = nullptr; + } else if (valid_ & kValidCpu) { + std::memcpy(dest, cpu_buffer_, bytes()); + // Free CPU memory because next time AHWB is mapped instead. + free(cpu_buffer_); + cpu_buffer_ = nullptr; + } else { + LOG(FATAL) << "Can't convert tensor with mask " << valid_ << " into AHWB."; + } + if (__builtin_available(android 26, *)) { + auto error = AHardwareBuffer_unlock(ahwb_, nullptr); + CHECK(error == 0) << "AHardwareBuffer_unlock " << error; + } +} + // SSBO is created on top of AHWB. A fence is inserted into the GPU queue before // the GPU task that is going to read from the SSBO. When the writing into AHWB // is finished then the GPU reads from the SSBO. From ef1507ed5df48f00daaa2111518cc0a32faec3a6 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 7 Dec 2022 14:43:59 -0800 Subject: [PATCH 32/64] Retire the visibility group "//mediapipe/framework:mediapipe_internal". 
PiperOrigin-RevId: 493713823 --- mediapipe/tasks/cc/core/BUILD | 5 ++++- mediapipe/tasks/cc/text/tokenizers/BUILD | 2 +- mediapipe/tasks/testdata/text/BUILD | 5 ++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index f8004d257..d440271df 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -309,7 +309,10 @@ cc_library( name = "utils", srcs = ["utils.cc"], hdrs = ["utils.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = [ + "//mediapipe/calculators:__subpackages__", + "//mediapipe/tasks:internal", + ], deps = [ "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto", diff --git a/mediapipe/tasks/cc/text/tokenizers/BUILD b/mediapipe/tasks/cc/text/tokenizers/BUILD index 7f1ea2848..92fac8eaa 100644 --- a/mediapipe/tasks/cc/text/tokenizers/BUILD +++ b/mediapipe/tasks/cc/text/tokenizers/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-package(default_visibility = ["//mediapipe/framework:mediapipe_internal"]) +package(default_visibility = ["//mediapipe/calculators/tensor:__subpackages__"]) licenses(["notice"]) diff --git a/mediapipe/tasks/testdata/text/BUILD b/mediapipe/tasks/testdata/text/BUILD index 081e63c2c..a0131c056 100644 --- a/mediapipe/tasks/testdata/text/BUILD +++ b/mediapipe/tasks/testdata/text/BUILD @@ -18,7 +18,10 @@ load( ) package( - default_visibility = ["//mediapipe/framework:mediapipe_internal"], + default_visibility = [ + "//mediapipe/calculators/tensor:__subpackages__", + "//mediapipe/tasks:__subpackages__", + ], licenses = ["notice"], # Apache 2.0 ) From 91664eb254bb44adb03b1e4823e0a5250d1f3837 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 7 Dec 2022 14:52:58 -0800 Subject: [PATCH 33/64] Object Detector deduplication PiperOrigin-RevId: 493716159 --- mediapipe/calculators/util/BUILD | 17 +++ .../util/detections_deduplicate_calculator.cc | 114 ++++++++++++++++++ .../tasks/cc/vision/object_detector/BUILD | 1 + .../object_detector/object_detector_graph.cc | 7 +- 4 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 mediapipe/calculators/util/detections_deduplicate_calculator.cc diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 43eadd53b..1529ead8a 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -456,6 +456,23 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "detections_deduplicate_calculator", + srcs = [ + "detections_deduplicate_calculator.cc", + ], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + ], + alwayslink = 1, +) + cc_library( name = "rect_transformation_calculator", srcs = 
["rect_transformation_calculator.cc"], diff --git a/mediapipe/calculators/util/detections_deduplicate_calculator.cc b/mediapipe/calculators/util/detections_deduplicate_calculator.cc new file mode 100644 index 000000000..2dfa09028 --- /dev/null +++ b/mediapipe/calculators/util/detections_deduplicate_calculator.cc @@ -0,0 +1,114 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location_data.pb.h" + +namespace mediapipe { +namespace api2 { +namespace { + +struct BoundingBoxHash { + size_t operator()(const LocationData::BoundingBox& bbox) const { + return std::hash{}(bbox.xmin()) ^ std::hash{}(bbox.ymin()) ^ + std::hash{}(bbox.width()) ^ std::hash{}(bbox.height()); + } +}; + +struct BoundingBoxEq { + bool operator()(const LocationData::BoundingBox& lhs, + const LocationData::BoundingBox& rhs) const { + return lhs.xmin() == rhs.xmin() && lhs.ymin() == rhs.ymin() && + lhs.width() == rhs.width() && lhs.height() == rhs.height(); + } +}; + +} // namespace + +// This Calculator deduplicates the bunding 
boxes with exactly the same +// coordinates, and folds the labels into a single Detection proto. Note +// non-maximum-suppression remove the overlapping bounding boxes within a class, +// while the deduplication operation merges bounding boxes from different +// classes. + +// Example config: +// node { +// calculator: "DetectionsDeduplicateCalculator" +// input_stream: "detections" +// output_stream: "deduplicated_detections" +// } +class DetectionsDeduplicateCalculator : public Node { + public: + static constexpr Input> kIn{""}; + static constexpr Output> kOut{""}; + + MEDIAPIPE_NODE_CONTRACT(kIn, kOut); + + absl::Status Open(mediapipe::CalculatorContext* cc) { + cc->SetOffset(::mediapipe::TimestampDiff(0)); + return absl::OkStatus(); + } + + absl::Status Process(mediapipe::CalculatorContext* cc) { + const std::vector& raw_detections = kIn(cc).Get(); + absl::flat_hash_map + bbox_to_detections; + std::vector deduplicated_detections; + for (const auto& detection : raw_detections) { + if (!detection.has_location_data() || + !detection.location_data().has_bounding_box()) { + return absl::InvalidArgumentError( + "The location data of Detections must be BoundingBox."); + } + if (bbox_to_detections.contains( + detection.location_data().bounding_box())) { + // The bbox location already exists. Merge the detection labels into + // the existing detection proto. + Detection& deduplicated_detection = + *bbox_to_detections[detection.location_data().bounding_box()]; + deduplicated_detection.mutable_score()->MergeFrom(detection.score()); + deduplicated_detection.mutable_label()->MergeFrom(detection.label()); + deduplicated_detection.mutable_label_id()->MergeFrom( + detection.label_id()); + deduplicated_detection.mutable_display_name()->MergeFrom( + detection.display_name()); + } else { + // The bbox location appears first time. Add the detection to output + // detection vector. 
+ deduplicated_detections.push_back(detection); + bbox_to_detections[detection.location_data().bounding_box()] = + &deduplicated_detections.back(); + } + } + kOut(cc).Send(std::move(deduplicated_detections)); + return absl::OkStatus(); + } +}; + +MEDIAPIPE_REGISTER_NODE(DetectionsDeduplicateCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index c2dd9995d..224eca520 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -63,6 +63,7 @@ cc_library( "//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto", "//mediapipe/calculators/util:detection_projection_calculator", "//mediapipe/calculators/util:detection_transformation_calculator", + "//mediapipe/calculators/util:detections_deduplicate_calculator", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc index a1625c16c..fd95bb1ac 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc @@ -662,11 +662,16 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { detection_transformation.Out(kPixelDetectionsTag) >> detection_label_id_to_text.In(""); + // Deduplicate Detections with same bounding box coordinates. + auto& detections_deduplicate = + graph.AddNode("DetectionsDeduplicateCalculator"); + detection_label_id_to_text.Out("") >> detections_deduplicate.In(""); + // Outputs the labeled detections and the processed image as the subgraph // output streams. 
return {{ /* detections= */ - detection_label_id_to_text[Output>("")], + detections_deduplicate[Output>("")], /* image= */ preprocessing[Output(kImageTag)], }}; } From 5f97b29b3ba41d1a7765221ed242e0fab9a89751 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 7 Dec 2022 15:23:10 -0800 Subject: [PATCH 34/64] Update Bazel dependencies for Apple PiperOrigin-RevId: 493723833 --- WORKSPACE | 56 +++++++++++++++++++++++++------------------------------ 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index d43394883..bf5e4236b 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -320,12 +320,30 @@ http_archive( ], ) -# iOS basic build deps. +# Load Zlib before initializing TensorFlow and the iOS build rules to guarantee +# that the target @zlib//:mini_zlib is available +http_archive( + name = "zlib", + build_file = "//third_party:zlib.BUILD", + sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", + strip_prefix = "zlib-1.2.11", + urls = [ + "http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz", + "http://zlib.net/fossils/zlib-1.2.11.tar.gz", # 2017-01-15 + ], + patches = [ + "@//third_party:zlib.diff", + ], + patch_args = [ + "-p1", + ], +) +# iOS basic build deps. http_archive( name = "build_bazel_rules_apple", - sha256 = "77e8bf6fda706f420a55874ae6ee4df0c9d95da6c7838228b26910fc82eea5a2", - url = "https://github.com/bazelbuild/rules_apple/releases/download/0.32.0/rules_apple.0.32.0.tar.gz", + sha256 = "f94e6dddf74739ef5cb30f000e13a2a613f6ebfa5e63588305a71fce8a8a9911", + url = "https://github.com/bazelbuild/rules_apple/releases/download/1.1.3/rules_apple.1.1.3.tar.gz", patches = [ # Bypass checking ios unit test runner when building MP ios applications. 
"@//third_party:build_bazel_rules_apple_bypass_test_runner_check.diff" @@ -339,29 +357,24 @@ load( "@build_bazel_rules_apple//apple:repositories.bzl", "apple_rules_dependencies", ) - apple_rules_dependencies() load( "@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies", ) - swift_rules_dependencies() -http_archive( - name = "build_bazel_apple_support", - sha256 = "741366f79d900c11e11d8efd6cc6c66a31bfb2451178b58e0b5edc6f1db17b35", - urls = [ - "https://github.com/bazelbuild/apple_support/releases/download/0.10.0/apple_support.0.10.0.tar.gz" - ], +load( + "@build_bazel_rules_swift//swift:extras.bzl", + "swift_rules_extra_dependencies", ) +swift_rules_extra_dependencies() load( "@build_bazel_apple_support//lib:repositories.bzl", "apple_support_dependencies", ) - apple_support_dependencies() # More iOS deps. @@ -442,25 +455,6 @@ http_archive( ], ) -# Load Zlib before initializing TensorFlow to guarantee that the target -# @zlib//:mini_zlib is available -http_archive( - name = "zlib", - build_file = "//third_party:zlib.BUILD", - sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", - strip_prefix = "zlib-1.2.11", - urls = [ - "http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz", - "http://zlib.net/fossils/zlib-1.2.11.tar.gz", # 2017-01-15 - ], - patches = [ - "@//third_party:zlib.diff", - ], - patch_args = [ - "-p1", - ], -) - # TensorFlow repo should always go after the other external dependencies. # TF on 2022-08-10. _TENSORFLOW_GIT_COMMIT = "af1d5bc4fbb66d9e6cc1cf89503014a99233583b" From a59f0a99243a77c5f1cef684c5cd542c320c59f8 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 7 Dec 2022 15:49:08 -0800 Subject: [PATCH 35/64] Make java/C++/python tasks API public visible. 
PiperOrigin-RevId: 493730506 --- mediapipe/tasks/cc/audio/audio_classifier/BUILD | 4 +--- mediapipe/tasks/cc/audio/audio_embedder/BUILD | 4 +--- mediapipe/tasks/cc/vision/object_detector/BUILD | 4 +--- mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD | 2 +- .../com/google/mediapipe/tasks/components/containers/BUILD | 2 +- .../com/google/mediapipe/tasks/components/processors/BUILD | 2 +- .../java/com/google/mediapipe/tasks/components/utils/BUILD | 2 +- mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD | 2 +- mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD | 2 +- mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD | 2 +- mediapipe/tasks/python/audio/BUILD | 2 +- mediapipe/tasks/python/audio/core/BUILD | 2 +- mediapipe/tasks/python/components/containers/BUILD | 2 +- mediapipe/tasks/python/components/processors/BUILD | 2 +- mediapipe/tasks/python/components/utils/BUILD | 2 +- mediapipe/tasks/python/core/BUILD | 2 +- mediapipe/tasks/python/text/BUILD | 2 +- mediapipe/tasks/python/text/core/BUILD | 2 +- mediapipe/tasks/python/vision/BUILD | 2 +- mediapipe/tasks/python/vision/core/BUILD | 2 +- 20 files changed, 20 insertions(+), 26 deletions(-) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/BUILD index f61472413..c575caabe 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/BUILD @@ -22,9 +22,7 @@ cc_library( name = "audio_classifier", srcs = ["audio_classifier.cc"], hdrs = ["audio_classifier.h"], - visibility = [ - "//mediapipe/tasks:users", - ], + visibility = ["//visibility:public"], deps = [ ":audio_classifier_graph", "//mediapipe/framework/api2:builder", diff --git a/mediapipe/tasks/cc/audio/audio_embedder/BUILD b/mediapipe/tasks/cc/audio/audio_embedder/BUILD index 6a0f627b2..1dfdd6f1b 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/cc/audio/audio_embedder/BUILD @@ -22,9 +22,7 @@ 
cc_library( name = "audio_embedder", srcs = ["audio_embedder.cc"], hdrs = ["audio_embedder.h"], - visibility = [ - "//mediapipe/tasks:users", - ], + visibility = ["//visibility:public"], deps = [ ":audio_embedder_graph", "//mediapipe/framework/api2:builder", diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index 224eca520..77373303a 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -22,9 +22,7 @@ cc_library( name = "object_detector", srcs = ["object_detector.cc"], hdrs = ["object_detector.h"], - visibility = [ - "//mediapipe/tasks:users", - ], + visibility = ["//visibility:public"], deps = [ ":object_detector_graph", "//mediapipe/calculators/core:concatenate_vector_calculator", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD index 2d29ccf23..e5d472e8a 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD @@ -14,7 +14,7 @@ licenses(["notice"]) -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) android_library( name = "core", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index ad17d5552..4d302b950 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD index 1f99f1612..b4d453935 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD index b2d27bfa7..6c724106f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD index 01b1f653a..31f885267 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) android_library( name = "core", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD index 5b10e9aab..31cd2c89a 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -14,7 +14,7 @@ licenses(["notice"]) -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) # The native library of all MediaPipe text tasks. cc_binary( diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 6161fe032..f469aed0c 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -14,7 +14,7 @@ licenses(["notice"]) -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) android_library( name = "core", diff --git a/mediapipe/tasks/python/audio/BUILD b/mediapipe/tasks/python/audio/BUILD index ce7c5ce08..6dda7a53c 100644 --- a/mediapipe/tasks/python/audio/BUILD +++ b/mediapipe/tasks/python/audio/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/audio/core/BUILD b/mediapipe/tasks/python/audio/core/BUILD index 3cb9cb8e8..5b4203d7b 100644 --- a/mediapipe/tasks/python/audio/core/BUILD +++ b/mediapipe/tasks/python/audio/core/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. 
-package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 9d275e167..7108617ff 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/components/processors/BUILD b/mediapipe/tasks/python/components/processors/BUILD index f87a579b0..695f6df91 100644 --- a/mediapipe/tasks/python/components/processors/BUILD +++ b/mediapipe/tasks/python/components/processors/BUILD @@ -16,7 +16,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/components/utils/BUILD b/mediapipe/tasks/python/components/utils/BUILD index 31114f326..1a18531c6 100644 --- a/mediapipe/tasks/python/components/utils/BUILD +++ b/mediapipe/tasks/python/components/utils/BUILD @@ -16,7 +16,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/core/BUILD b/mediapipe/tasks/python/core/BUILD index fc0018ab1..447189d6f 100644 --- a/mediapipe/tasks/python/core/BUILD +++ b/mediapipe/tasks/python/core/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. 
-package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/text/BUILD b/mediapipe/tasks/python/text/BUILD index e2a51cdbd..9d5d23261 100644 --- a/mediapipe/tasks/python/text/BUILD +++ b/mediapipe/tasks/python/text/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/text/core/BUILD b/mediapipe/tasks/python/text/core/BUILD index 072a0c7d8..e76bd4b6d 100644 --- a/mediapipe/tasks/python/text/core/BUILD +++ b/mediapipe/tasks/python/text/core/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 241ca4341..5f4aa38ff 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/python/vision/core/BUILD b/mediapipe/tasks/python/vision/core/BUILD index e2b2b3dec..18df690a0 100644 --- a/mediapipe/tasks/python/vision/core/BUILD +++ b/mediapipe/tasks/python/vision/core/BUILD @@ -14,7 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. 
-package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) From a0efcb47f23666f84448d82fcede6dab9fdfbf55 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 7 Dec 2022 16:37:08 -0800 Subject: [PATCH 36/64] internal change PiperOrigin-RevId: 493742399 --- .../tasks/cc/components/containers/BUILD | 13 + .../components/containers/detection_result.cc | 73 ++++++ .../components/containers/detection_result.h | 52 ++++ .../tasks/cc/components/containers/rect.cc | 34 +++ .../tasks/cc/components/containers/rect.h | 29 ++- .../cc/vision/core/base_vision_task_api.h | 4 +- .../cc/vision/core/image_processing_options.h | 3 +- ...hand_landmarks_deduplication_calculator.cc | 14 +- .../hand_landmarker/hand_landmarker_test.cc | 4 +- .../image_classifier/image_classifier_test.cc | 19 +- .../image_embedder/image_embedder_test.cc | 6 +- .../image_segmenter/image_segmenter_test.cc | 4 +- .../tasks/cc/vision/object_detector/BUILD | 1 + .../vision/object_detector/object_detector.cc | 15 +- .../vision/object_detector/object_detector.h | 12 +- .../object_detector/object_detector_test.cc | 227 ++++++++++-------- .../tasks/cc/vision/utils/landmarks_utils.cc | 8 +- .../tasks/cc/vision/utils/landmarks_utils.h | 10 +- 18 files changed, 377 insertions(+), 151 deletions(-) create mode 100644 mediapipe/tasks/cc/components/containers/detection_result.cc create mode 100644 mediapipe/tasks/cc/components/containers/detection_result.h create mode 100644 mediapipe/tasks/cc/components/containers/rect.cc diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD index 35d3f4785..0750a1482 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/containers/BUILD @@ -18,6 +18,7 @@ licenses(["notice"]) cc_library( name = "rect", + srcs = ["rect.cc"], hdrs = ["rect.h"], ) @@ -41,6 +42,18 @@ cc_library( ], ) +cc_library( + name = 
"detection_result", + srcs = ["detection_result.cc"], + hdrs = ["detection_result.h"], + deps = [ + ":category", + ":rect", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", + ], +) + cc_library( name = "embedding_result", srcs = ["embedding_result.cc"], diff --git a/mediapipe/tasks/cc/components/containers/detection_result.cc b/mediapipe/tasks/cc/components/containers/detection_result.cc new file mode 100644 index 000000000..43c8ca0f5 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/detection_result.cc @@ -0,0 +1,73 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/components/containers/detection_result.h" + +#include + +#include +#include +#include + +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location_data.pb.h" +#include "mediapipe/tasks/cc/components/containers/category.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" + +namespace mediapipe::tasks::components::containers { + +constexpr int kDefaultCategoryIndex = -1; + +Detection ConvertToDetectionResult( + const mediapipe::Detection& detection_proto) { + Detection detection; + for (int idx = 0; idx < detection_proto.score_size(); ++idx) { + detection.categories.push_back( + {/* index= */ detection_proto.label_id_size() > idx + ? detection_proto.label_id(idx) + : kDefaultCategoryIndex, + /* score= */ detection_proto.score(idx), + /* category_name */ detection_proto.label_size() > idx + ? detection_proto.label(idx) + : "", + /* display_name */ detection_proto.display_name_size() > idx + ? 
detection_proto.display_name(idx) + : ""}); + } + Rect bounding_box; + if (detection_proto.location_data().has_bounding_box()) { + mediapipe::LocationData::BoundingBox bounding_box_proto = + detection_proto.location_data().bounding_box(); + bounding_box.left = bounding_box_proto.xmin(); + bounding_box.top = bounding_box_proto.ymin(); + bounding_box.right = bounding_box_proto.xmin() + bounding_box_proto.width(); + bounding_box.bottom = + bounding_box_proto.ymin() + bounding_box_proto.height(); + } + detection.bounding_box = bounding_box; + return detection; +} + +DetectionResult ConvertToDetectionResult( + std::vector detections_proto) { + DetectionResult detection_result; + detection_result.detections.reserve(detections_proto.size()); + for (const auto& detection_proto : detections_proto) { + detection_result.detections.push_back( + ConvertToDetectionResult(detection_proto)); + } + return detection_result; +} +} // namespace mediapipe::tasks::components::containers diff --git a/mediapipe/tasks/cc/components/containers/detection_result.h b/mediapipe/tasks/cc/components/containers/detection_result.h new file mode 100644 index 000000000..546f324d6 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/detection_result.h @@ -0,0 +1,52 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_ + +#include +#include +#include + +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/tasks/cc/components/containers/category.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" + +namespace mediapipe::tasks::components::containers { + +// Detection for a single bounding box. +struct Detection { + // A vector of detected categories. + std::vector categories; + // The bounding box location. + Rect bounding_box; +}; + +// Detection results of a model. +struct DetectionResult { + // A vector of Detections. + std::vector detections; +}; + +// Utility function to convert from Detection proto to Detection struct. +Detection ConvertToDetection(const mediapipe::Detection& detection_proto); + +// Utility function to convert from list of Detection proto to DetectionResult +// struct. +DetectionResult ConvertToDetectionResult( + std::vector detections_proto); + +} // namespace mediapipe::tasks::components::containers +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_DETECTION_RESULT_H_ diff --git a/mediapipe/tasks/cc/components/containers/rect.cc b/mediapipe/tasks/cc/components/containers/rect.cc new file mode 100644 index 000000000..4a94832a6 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/rect.cc @@ -0,0 +1,34 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/components/containers/rect.h" + +namespace mediapipe::tasks::components::containers { + +RectF ToRectF(const Rect& rect, int image_height, int image_width) { + return RectF{static_cast(rect.left) / image_width, + static_cast(rect.top) / image_height, + static_cast(rect.right) / image_width, + static_cast(rect.bottom) / image_height}; +} + +Rect ToRect(const RectF& rect, int image_height, int image_width) { + return Rect{static_cast(rect.left * image_width), + static_cast(rect.top * image_height), + static_cast(rect.right * image_width), + static_cast(rect.bottom * image_height)}; +} + +} // namespace mediapipe::tasks::components::containers diff --git a/mediapipe/tasks/cc/components/containers/rect.h b/mediapipe/tasks/cc/components/containers/rect.h index 3f5432cf2..551d91588 100644 --- a/mediapipe/tasks/cc/components/containers/rect.h +++ b/mediapipe/tasks/cc/components/containers/rect.h @@ -16,20 +16,47 @@ limitations under the License. #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ +#include + namespace mediapipe::tasks::components::containers { +constexpr float kRectFTolerance = 1e-4; + // Defines a rectangle, used e.g. as part of detection results or as input // region-of-interest. // +struct Rect { + int left; + int top; + int right; + int bottom; +}; + +inline bool operator==(const Rect& lhs, const Rect& rhs) { + return lhs.left == rhs.left && lhs.top == rhs.top && lhs.right == rhs.right && + lhs.bottom == rhs.bottom; +} + // The coordinates are normalized wrt the image dimensions, i.e. generally in // [0,1] but they may exceed these bounds if describing a region overlapping the // image. The origin is on the top-left corner of the image. 
-struct Rect { +struct RectF { float left; float top; float right; float bottom; }; +inline bool operator==(const RectF& lhs, const RectF& rhs) { + return abs(lhs.left - rhs.left) < kRectFTolerance && + abs(lhs.top - rhs.top) < kRectFTolerance && + abs(lhs.right - rhs.right) < kRectFTolerance && + abs(lhs.bottom - rhs.bottom) < kRectFTolerance; +} + +RectF ToRectF(const Rect& rect, int image_height, int image_width); + +Rect ToRect(const RectF& rect, int image_height, int image_width); + } // namespace mediapipe::tasks::components::containers #endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ diff --git a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h index c3c0a0261..a86b2cca8 100644 --- a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h +++ b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h @@ -129,13 +129,13 @@ class BaseVisionTaskApi : public tasks::core::BaseTaskApi { if (roi.left >= roi.right || roi.top >= roi.bottom) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, - "Expected Rect with left < right and top < bottom.", + "Expected RectF with left < right and top < bottom.", MediaPipeTasksStatus::kImageProcessingInvalidArgumentError); } if (roi.left < 0 || roi.top < 0 || roi.right > 1 || roi.bottom > 1) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, - "Expected Rect values to be in [0,1].", + "Expected RectF values to be in [0,1].", MediaPipeTasksStatus::kImageProcessingInvalidArgumentError); } normalized_rect.set_x_center((roi.left + roi.right) / 2.0); diff --git a/mediapipe/tasks/cc/vision/core/image_processing_options.h b/mediapipe/tasks/cc/vision/core/image_processing_options.h index 7e764c1fe..1983272fc 100644 --- a/mediapipe/tasks/cc/vision/core/image_processing_options.h +++ b/mediapipe/tasks/cc/vision/core/image_processing_options.h @@ -35,7 +35,8 @@ struct ImageProcessingOptions { // the full image is used. 
// // Coordinates must be in [0,1] with 'left' < 'right' and 'top' < bottom. - std::optional region_of_interest = std::nullopt; + std::optional region_of_interest = + std::nullopt; // The rotation to apply to the image (or cropped region-of-interest), in // degrees clockwise. diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc index 564184c64..266ce223f 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc @@ -44,7 +44,7 @@ namespace { using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::vision::utils::CalculateIOU; using ::mediapipe::tasks::vision::utils::DuplicatesFinder; @@ -126,7 +126,7 @@ absl::StatusOr HandBaselineDistance( return distance; } -Rect CalculateBound(const NormalizedLandmarkList& list) { +RectF CalculateBound(const NormalizedLandmarkList& list) { constexpr float kMinInitialValue = std::numeric_limits::max(); constexpr float kMaxInitialValue = std::numeric_limits::lowest(); @@ -144,10 +144,10 @@ Rect CalculateBound(const NormalizedLandmarkList& list) { } // Populate normalized non rotated face bounding box - return Rect{/*left=*/bounding_box_left, - /*top=*/bounding_box_top, - /*right=*/bounding_box_right, - /*bottom=*/bounding_box_bottom}; + return RectF{/*left=*/bounding_box_left, + /*top=*/bounding_box_top, + /*right=*/bounding_box_right, + /*bottom=*/bounding_box_bottom}; } // Uses IoU and distance of some corresponding hand landmarks to detect @@ -172,7 +172,7 @@ class HandDuplicatesFinder : public DuplicatesFinder { const int num = 
multi_landmarks.size(); std::vector baseline_distances; baseline_distances.reserve(num); - std::vector bounds; + std::vector bounds; bounds.reserve(num); for (const NormalizedLandmarkList& list : multi_landmarks) { ASSIGN_OR_RETURN(const float baseline_distance, diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc index fa49a4c1f..94d1b1c12 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc @@ -50,7 +50,7 @@ namespace { using ::file::Defaults; using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::EqualsProto; @@ -188,7 +188,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { options->running_mode = core::RunningMode::IMAGE; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr hand_landmarker, HandLandmarker::Create(std::move(options))); - Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; + RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; auto results = hand_landmarker->Detect(image, image_processing_options); diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc index 1144e9032..7aa2a148c 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc @@ -52,7 +52,7 @@ namespace { using ::mediapipe::file::JoinPath; using ::mediapipe::tasks::components::containers::Category; using ::mediapipe::tasks::components::containers::Classifications; -using 
::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -472,7 +472,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); // Region-of-interest around the soccer ball. - Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; + RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( @@ -526,7 +526,8 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); // Region-of-interest around the chair, with 90° anti-clockwise rotation. - Rect roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702, /*bottom=*/0.3049}; + RectF roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702, + /*bottom=*/0.3049}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/-90}; @@ -554,13 +555,13 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) { ImageClassifier::Create(std::move(options))); // Invalid: left > right. 
- Rect roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1}; + RectF roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; auto results = image_classifier->Classify(image, image_processing_options); EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(results.status().message(), - HasSubstr("Expected Rect with left < right and top < bottom")); + HasSubstr("Expected RectF with left < right and top < bottom")); EXPECT_THAT( results.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( @@ -573,7 +574,7 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) { results = image_classifier->Classify(image, image_processing_options); EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(results.status().message(), - HasSubstr("Expected Rect with left < right and top < bottom")); + HasSubstr("Expected RectF with left < right and top < bottom")); EXPECT_THAT( results.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( @@ -586,7 +587,7 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) { results = image_classifier->Classify(image, image_processing_options); EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(results.status().message(), - HasSubstr("Expected Rect values to be in [0,1]")); + HasSubstr("Expected RectF values to be in [0,1]")); EXPECT_THAT( results.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( @@ -695,7 +696,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) { ImageClassifier::Create(std::move(options))); // Crop around the soccer ball. // Region-of-interest around the soccer ball. 
- Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; + RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; for (int i = 0; i < iterations; ++i) { @@ -837,7 +838,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); // Crop around the soccer ball. - Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; + RectF roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; for (int i = 0; i < iterations; ++i) { diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc index 6098a9a70..dd602bef5 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc @@ -41,7 +41,7 @@ namespace image_embedder { namespace { using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -320,7 +320,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { Image crop, DecodeImageFromFile( JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); // Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg". - Rect roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1}; + RectF roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; // Extract both embeddings. 
@@ -388,7 +388,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger_rotated.jpg"))); // Region-of-interest corresponding to burger_crop.jpg. - Rect roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333}; + RectF roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/-90}; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index d5ea088a1..f9618c1b1 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -47,7 +47,7 @@ namespace { using ::mediapipe::Image; using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -299,7 +299,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; + RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; auto results = segmenter->Segment(image, image_processing_options); diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index 77373303a..5269796ae 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -33,6 +33,7 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", + 
"//mediapipe/tasks/cc/components/containers:detection_result", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc index dd19237ff..e0222dd70 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc @@ -30,6 +30,7 @@ limitations under the License. #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" @@ -56,6 +57,7 @@ constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.ObjectDetectorGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using ::mediapipe::tasks::components::containers::ConvertToDetectionResult; using ObjectDetectorOptionsProto = object_detector::proto::ObjectDetectorOptions; @@ -129,7 +131,8 @@ absl::StatusOr> ObjectDetector::Create( Packet detections_packet = status_or_packets.value()[kDetectionsOutStreamName]; Packet image_packet = status_or_packets.value()[kImageOutStreamName]; - result_callback(detections_packet.Get>(), + result_callback(ConvertToDetectionResult( + detections_packet.Get>()), image_packet.Get(), detections_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); @@ -144,7 +147,7 @@ absl::StatusOr> ObjectDetector::Create( std::move(packets_callback)); } -absl::StatusOr> ObjectDetector::Detect( +absl::StatusOr ObjectDetector::Detect( mediapipe::Image image, std::optional image_processing_options) { if (image.UsesGpu()) { @@ -161,10 +164,11 @@ absl::StatusOr> ObjectDetector::Detect( 
ProcessImageData( {{kImageInStreamName, MakePacket(std::move(image))}, {kNormRectName, MakePacket(std::move(norm_rect))}})); - return output_packets[kDetectionsOutStreamName].Get>(); + return ConvertToDetectionResult( + output_packets[kDetectionsOutStreamName].Get>()); } -absl::StatusOr> ObjectDetector::DetectForVideo( +absl::StatusOr ObjectDetector::DetectForVideo( mediapipe::Image image, int64 timestamp_ms, std::optional image_processing_options) { if (image.UsesGpu()) { @@ -185,7 +189,8 @@ absl::StatusOr> ObjectDetector::DetectForVideo( {kNormRectName, MakePacket(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); - return output_packets[kDetectionsOutStreamName].Get>(); + return ConvertToDetectionResult( + output_packets[kDetectionsOutStreamName].Get>()); } absl::Status ObjectDetector::DetectAsync( diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.h b/mediapipe/tasks/cc/vision/object_detector/object_detector.h index 44ce68ed9..249a2ebf5 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.h +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/image.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" @@ -36,6 +37,10 @@ namespace mediapipe { namespace tasks { namespace vision { +// Alias the shared DetectionResult struct as result type. +using ObjectDetectorResult = + ::mediapipe::tasks::components::containers::DetectionResult; + // The options for configuring a mediapipe object detector task.
struct ObjectDetectorOptions { // Base options for configuring MediaPipe Tasks, such as specifying the TfLite @@ -79,8 +84,7 @@ struct ObjectDetectorOptions { // The user-defined result callback for processing live stream data. // The result callback should only be specified when the running mode is set // to RunningMode::LIVE_STREAM. - std::function>, - const Image&, int64)> + std::function, const Image&, int64)> result_callback = nullptr; }; @@ -165,7 +169,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // underlying image data. // TODO: Describes the output bounding boxes for gpu input // images after enabling the gpu support in MediaPipe Tasks. - absl::StatusOr> Detect( + absl::StatusOr Detect( mediapipe::Image image, std::optional image_processing_options = std::nullopt); @@ -188,7 +192,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // unrotated input frame of reference coordinates system, i.e. in `[0, // image_width) x [0, image_height)`, which are the dimensions of the // underlying image data. - absl::StatusOr> DetectForVideo( + absl::StatusOr DetectForVideo( mediapipe::Image image, int64 timestamp_ms, std::optional image_processing_options = std::nullopt); diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc index 1747685dd..798e3f238 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc @@ -35,6 +35,7 @@ limitations under the License. 
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" #include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" @@ -65,10 +66,14 @@ namespace vision { namespace { using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::ConvertToDetectionResult; +using ::mediapipe::tasks::components::containers::Detection; +using ::mediapipe::tasks::components::containers::DetectionResult; +using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; +using DetectionProto = mediapipe::Detection; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; constexpr char kMobileSsdWithMetadata[] = @@ -83,47 +88,45 @@ constexpr char kEfficientDetWithMetadata[] = // Checks that the two provided `Detection` proto vectors are equal, with a // tolerancy on floating-point scores to account for numerical instabilities. // If the proto definition changes, please also change this function. 
-void ExpectApproximatelyEqual(const std::vector& actual, - const std::vector& expected) { +void ExpectApproximatelyEqual(const ObjectDetectorResult& actual, + const ObjectDetectorResult& expected) { const float kPrecision = 1e-6; - EXPECT_EQ(actual.size(), expected.size()); - for (int i = 0; i < actual.size(); ++i) { - const Detection& a = actual[i]; - const Detection& b = expected[i]; - EXPECT_THAT(a.location_data().bounding_box(), - EqualsProto(b.location_data().bounding_box())); - EXPECT_EQ(a.label_size(), 1); - EXPECT_EQ(b.label_size(), 1); - EXPECT_EQ(a.label(0), b.label(0)); - EXPECT_EQ(a.score_size(), 1); - EXPECT_EQ(b.score_size(), 1); - EXPECT_NEAR(a.score(0), b.score(0), kPrecision); + EXPECT_EQ(actual.detections.size(), expected.detections.size()); + for (int i = 0; i < actual.detections.size(); ++i) { + const Detection& a = actual.detections[i]; + const Detection& b = expected.detections[i]; + EXPECT_EQ(a.bounding_box, b.bounding_box); + EXPECT_EQ(a.categories.size(), 1); + EXPECT_EQ(b.categories.size(), 1); + EXPECT_EQ(a.categories[0].category_name, b.categories[0].category_name); + EXPECT_NEAR(a.categories[0].score, b.categories[0].score, kPrecision); } } -std::vector GenerateMobileSsdNoImageResizingFullExpectedResults() { - return {ParseTextProtoOrDie(R"pb( +std::vector +GenerateMobileSsdNoImageResizingFullExpectedResults() { + return {ParseTextProtoOrDie(R"pb( label: "cat" score: 0.6328125 location_data { format: BOUNDING_BOX bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 } })pb"), - ParseTextProtoOrDie(R"pb( + ParseTextProtoOrDie(R"pb( label: "cat" score: 0.59765625 location_data { format: BOUNDING_BOX bounding_box { xmin: 151 ymin: 78 width: 104 height: 223 } })pb"), - ParseTextProtoOrDie(R"pb( + ParseTextProtoOrDie(R"pb( label: "cat" score: 0.5 location_data { format: BOUNDING_BOX bounding_box { xmin: 65 ymin: 199 width: 41 height: 101 } })pb"), - ParseTextProtoOrDie(R"pb( + ParseTextProtoOrDie(R"pb( label: "dog" score: 0.48828125 
location_data { @@ -263,8 +266,8 @@ TEST_F(CreateFromOptionsTest, FailsWithIllegalCallbackInImageOrVideoMode) { JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->running_mode = running_mode; options->result_callback = - [](absl::StatusOr> detections, - const Image& image, int64 timestamp_ms) {}; + [](absl::StatusOr detections, const Image& image, + int64 timestamp_ms) {}; absl::StatusOr> object_detector = ObjectDetector::Create(std::move(options)); EXPECT_EQ(object_detector.status().code(), @@ -340,34 +343,36 @@ TEST_F(ImageModeTest, Succeeds) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, {ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.69921875 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 608 ymin: 161 width: 381 height: 439 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.64453125 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 60 ymin: 398 width: 386 height: 196 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.51171875 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 256 ymin: 395 width: 173 height: 202 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.48828125 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 362 ymin: 191 width: 325 height: 419 } - })pb")}); + results, + ConvertToDetectionResult( + {ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.69921875 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 608 ymin: 161 width: 381 height: 439 } + })pb"), + ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.64453125 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 60 ymin: 398 width: 386 height: 196 } + })pb"), + ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.51171875 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 256 ymin: 395 width: 173 height: 202 } + })pb"), + 
ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.48828125 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 362 ymin: 191 width: 325 height: 419 } + })pb")})); } TEST_F(ImageModeTest, SucceedsEfficientDetModel) { @@ -383,34 +388,36 @@ TEST_F(ImageModeTest, SucceedsEfficientDetModel) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, {ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.7578125 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 858 ymin: 408 width: 225 height: 187 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.72265625 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 67 ymin: 401 width: 399 height: 192 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.6289063 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 368 ymin: 210 width: 272 height: 385 } - })pb"), - ParseTextProtoOrDie(R"pb( - label: "cat" - score: 0.5859375 - location_data { - format: BOUNDING_BOX - bounding_box { xmin: 601 ymin: 166 width: 298 height: 437 } - })pb")}); + results, + ConvertToDetectionResult( + {ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.7578125 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 858 ymin: 408 width: 225 height: 187 } + })pb"), + ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.72265625 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 67 ymin: 401 width: 399 height: 192 } + })pb"), + ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.6289063 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 368 ymin: 210 width: 272 height: 385 } + })pb"), + ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.5859375 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 601 ymin: 166 width: 298 height: 437 } + })pb")})); } TEST_F(ImageModeTest, SucceedsWithoutImageResizing) { @@ -426,7 +433,8 @@ TEST_F(ImageModeTest, 
SucceedsWithoutImageResizing) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, GenerateMobileSsdNoImageResizingFullExpectedResults()); + results, ConvertToDetectionResult( + GenerateMobileSsdNoImageResizingFullExpectedResults())); } TEST_F(ImageModeTest, SucceedsWithScoreCalibration) { @@ -442,13 +450,14 @@ TEST_F(ImageModeTest, SucceedsWithScoreCalibration) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, {ParseTextProtoOrDie(R"pb( + results, + ConvertToDetectionResult({ParseTextProtoOrDie(R"pb( label: "cat" score: 0.6531269142 location_data { format: BOUNDING_BOX bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 } - })pb")}); + })pb")})); } TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) { @@ -463,11 +472,13 @@ TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) { ObjectDetector::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); - ExpectApproximatelyEqual(results, - {full_expected_results[0], full_expected_results[1], - full_expected_results[2]}); + + ExpectApproximatelyEqual( + results, ConvertToDetectionResult({full_expected_results[0], + full_expected_results[1], + full_expected_results[2]})); } TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) { @@ -482,10 +493,11 @@ TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) { ObjectDetector::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); 
ExpectApproximatelyEqual( - results, {full_expected_results[0], full_expected_results[1]}); + results, ConvertToDetectionResult( + {full_expected_results[0], full_expected_results[1]})); } TEST_F(ImageModeTest, SucceedsWithAllowlistOption) { @@ -501,9 +513,10 @@ TEST_F(ImageModeTest, SucceedsWithAllowlistOption) { ObjectDetector::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); - ExpectApproximatelyEqual(results, {full_expected_results[3]}); + ExpectApproximatelyEqual( + results, ConvertToDetectionResult({full_expected_results[3]})); } TEST_F(ImageModeTest, SucceedsWithDenylistOption) { @@ -519,9 +532,10 @@ TEST_F(ImageModeTest, SucceedsWithDenylistOption) { ObjectDetector::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); MP_ASSERT_OK(object_detector->Close()); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); - ExpectApproximatelyEqual(results, {full_expected_results[3]}); + ExpectApproximatelyEqual( + results, ConvertToDetectionResult({full_expected_results[3]})); } TEST_F(ImageModeTest, SucceedsWithRotation) { @@ -541,13 +555,14 @@ TEST_F(ImageModeTest, SucceedsWithRotation) { auto results, object_detector->Detect(image, image_processing_options)); MP_ASSERT_OK(object_detector->Close()); ExpectApproximatelyEqual( - results, {ParseTextProtoOrDie(R"pb( + results, + ConvertToDetectionResult({ParseTextProtoOrDie(R"pb( label: "cat" score: 0.7109375 location_data { format: BOUNDING_BOX bounding_box { xmin: 0 ymin: 622 width: 436 height: 276 } - })pb")}); + })pb")})); } TEST_F(ImageModeTest, FailsWithRegionOfInterest) { @@ -560,7 +575,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { JoinPath("./", 
kTestDataDirectory, kMobileSsdWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); - Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; + RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; auto results = object_detector->Detect(image, image_processing_options); @@ -619,10 +634,11 @@ TEST_F(VideoModeTest, Succeeds) { for (int i = 0; i < iterations; ++i) { MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->DetectForVideo(image, i)); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); ExpectApproximatelyEqual( - results, {full_expected_results[0], full_expected_results[1]}); + results, ConvertToDetectionResult( + {full_expected_results[0], full_expected_results[1]})); } MP_ASSERT_OK(object_detector->Close()); } @@ -637,9 +653,8 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->running_mode = core::RunningMode::LIVE_STREAM; - options->result_callback = - [](absl::StatusOr> detections, const Image& image, - int64 timestamp_ms) {}; + options->result_callback = [](absl::StatusOr detections, + const Image& image, int64 timestamp_ms) {}; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); @@ -669,9 +684,8 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) { options->running_mode = core::RunningMode::LIVE_STREAM; options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); - options->result_callback = - [](absl::StatusOr> detections, const Image& image, - int64 timestamp_ms) {}; + options->result_callback = [](absl::StatusOr detections, + const Image& image, int64 timestamp_ms) {}; 
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); MP_ASSERT_OK(object_detector->DetectAsync(image, 1)); @@ -695,14 +709,14 @@ TEST_F(LiveStreamModeTest, Succeeds) { auto options = std::make_unique(); options->max_results = 2; options->running_mode = core::RunningMode::LIVE_STREAM; - std::vector> detection_results; + std::vector detection_results; std::vector> image_sizes; std::vector timestamps; options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->result_callback = [&detection_results, &image_sizes, ×tamps]( - absl::StatusOr> detections, const Image& image, + absl::StatusOr detections, const Image& image, int64 timestamp_ms) { MP_ASSERT_OK(detections.status()); detection_results.push_back(std::move(detections).value()); @@ -719,11 +733,12 @@ TEST_F(LiveStreamModeTest, Succeeds) { // number of iterations. ASSERT_LE(detection_results.size(), iterations); ASSERT_GT(detection_results.size(), 0); - std::vector full_expected_results = + std::vector full_expected_results = GenerateMobileSsdNoImageResizingFullExpectedResults(); for (const auto& detection_result : detection_results) { ExpectApproximatelyEqual( - detection_result, {full_expected_results[0], full_expected_results[1]}); + detection_result, ConvertToDetectionResult({full_expected_results[0], + full_expected_results[1]})); } for (const auto& image_size : image_sizes) { EXPECT_EQ(image_size.first, image.width()); diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc b/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc index 2ce9e2454..fe4e63824 100644 --- a/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc +++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc @@ -22,13 +22,13 @@ limitations under the License. 
namespace mediapipe::tasks::vision::utils { -using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::RectF; -float CalculateArea(const Rect& rect) { +float CalculateArea(const RectF& rect) { return (rect.right - rect.left) * (rect.bottom - rect.top); } -float CalculateIntersectionArea(const Rect& a, const Rect& b) { +float CalculateIntersectionArea(const RectF& a, const RectF& b) { const float intersection_left = std::max(a.left, b.left); const float intersection_top = std::max(a.top, b.top); const float intersection_right = std::min(a.right, b.right); @@ -38,7 +38,7 @@ float CalculateIntersectionArea(const Rect& a, const Rect& b) { std::max(intersection_right - intersection_left, 0.0); } -float CalculateIOU(const Rect& a, const Rect& b) { +float CalculateIOU(const RectF& a, const RectF& b) { const float area_a = CalculateArea(a); const float area_b = CalculateArea(b); if (area_a <= 0 || area_b <= 0) return 0.0; diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils.h b/mediapipe/tasks/cc/vision/utils/landmarks_utils.h index 73114d2ef..4d1fac62f 100644 --- a/mediapipe/tasks/cc/vision/utils/landmarks_utils.h +++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils.h @@ -27,15 +27,15 @@ limitations under the License. namespace mediapipe::tasks::vision::utils { // Calculates intersection over union for two bounds. 
-float CalculateIOU(const components::containers::Rect& a, - const components::containers::Rect& b); +float CalculateIOU(const components::containers::RectF& a, + const components::containers::RectF& b); // Calculates area for face bound -float CalculateArea(const components::containers::Rect& rect); +float CalculateArea(const components::containers::RectF& rect); // Calucates intersection area of two face bounds -float CalculateIntersectionArea(const components::containers::Rect& a, - const components::containers::Rect& b); +float CalculateIntersectionArea(const components::containers::RectF& a, + const components::containers::RectF& b); } // namespace mediapipe::tasks::vision::utils #endif // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_ From 700c7b4b2258d3a01bf8424146a4cf94e8ca7282 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 7 Dec 2022 18:54:34 -0800 Subject: [PATCH 37/64] Internal refactoring for TextEmbedder. PiperOrigin-RevId: 493766612 --- .../tasks/cc/components/processors/BUILD | 3 + .../cc/components/processors/proto/BUILD | 6 + .../processors/proto/text_model_type.proto | 31 +++++ .../text_preprocessing_graph_options.proto | 15 +-- .../processors/text_preprocessing_graph.cc | 126 +++++------------- mediapipe/tasks/cc/text/utils/BUILD | 40 ++++++ .../tasks/cc/text/utils/text_model_utils.cc | 119 +++++++++++++++++ .../tasks/cc/text/utils/text_model_utils.h | 33 +++++ .../cc/text/utils/text_model_utils_test.cc | 108 +++++++++++++++ 9 files changed, 375 insertions(+), 106 deletions(-) create mode 100644 mediapipe/tasks/cc/components/processors/proto/text_model_type.proto create mode 100644 mediapipe/tasks/cc/text/utils/text_model_utils.cc create mode 100644 mediapipe/tasks/cc/text/utils/text_model_utils.h create mode 100644 mediapipe/tasks/cc/text/utils/text_model_utils_test.cc diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD index 185bf231b..cec44a9e3 100644 --- 
a/mediapipe/tasks/cc/components/processors/BUILD +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -150,9 +150,12 @@ cc_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/cc/text/utils:text_model_utils", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/mediapipe/tasks/cc/components/processors/proto/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD index f48c4bad8..816ba47e3 100644 --- a/mediapipe/tasks/cc/components/processors/proto/BUILD +++ b/mediapipe/tasks/cc/components/processors/proto/BUILD @@ -60,10 +60,16 @@ mediapipe_proto_library( ], ) +mediapipe_proto_library( + name = "text_model_type_proto", + srcs = ["text_model_type.proto"], +) + mediapipe_proto_library( name = "text_preprocessing_graph_options_proto", srcs = ["text_preprocessing_graph_options.proto"], deps = [ + ":text_model_type_proto", "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], diff --git a/mediapipe/tasks/cc/components/processors/proto/text_model_type.proto b/mediapipe/tasks/cc/components/processors/proto/text_model_type.proto new file mode 100644 index 000000000..7ffc0db07 --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/proto/text_model_type.proto @@ -0,0 +1,31 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.components.processors.proto; + +message TextModelType { + // TFLite text models supported by MediaPipe tasks. + enum ModelType { + UNSPECIFIED_MODEL = 0; + // A BERT-based model. + BERT_MODEL = 1; + // A model expecting input passed through a regex-based tokenizer. + REGEX_MODEL = 2; + // A model taking a string tensor input. + STRING_MODEL = 3; + } +} diff --git a/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto b/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto index a67cfd8a9..b610f7757 100644 --- a/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto @@ -18,25 +18,16 @@ syntax = "proto2"; package mediapipe.tasks.components.processors.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/components/processors/proto/text_model_type.proto"; message TextPreprocessingGraphOptions { extend mediapipe.CalculatorOptions { optional TextPreprocessingGraphOptions ext = 476978751; } - // The type of text preprocessor required for the TFLite model. - enum PreprocessorType { - UNSPECIFIED_PREPROCESSOR = 0; - // Used for the BertPreprocessorCalculator. - BERT_PREPROCESSOR = 1; - // Used for the RegexPreprocessorCalculator. - REGEX_PREPROCESSOR = 2; - // Used for the TextToTensorCalculator. 
- STRING_PREPROCESSOR = 3; - } - optional PreprocessorType preprocessor_type = 1; + optional TextModelType.ModelType model_type = 1; // The maximum input sequence length for the TFLite model. Used with - // BERT_PREPROCESSOR and REGEX_PREPROCESSOR. + // BERT_MODEL and REGEX_MODEL. optional int32 max_seq_len = 2; } diff --git a/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc index de16375bd..f6c15c441 100644 --- a/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc @@ -25,15 +25,14 @@ limitations under the License. #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/subgraph.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "mediapipe/tasks/cc/text/utils/text_model_utils.h" -namespace mediapipe { -namespace tasks { -namespace components { -namespace processors { - +namespace mediapipe::tasks::components::processors { namespace { using ::mediapipe::api2::Input; @@ -42,91 +41,35 @@ using ::mediapipe::api2::SideInput; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::SideSource; using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::processors::proto::TextModelType; using ::mediapipe::tasks::components::processors::proto:: TextPreprocessingGraphOptions; using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; +using ::mediapipe::tasks::text::utils::GetModelType; constexpr char kTextTag[] = "TEXT"; constexpr char 
kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; constexpr char kTensorsTag[] = "TENSORS"; -constexpr int kNumInputTensorsForBert = 3; -constexpr int kNumInputTensorsForRegex = 1; - -// Gets the name of the MediaPipe calculator associated with -// `preprocessor_type`. -absl::StatusOr GetCalculatorNameFromPreprocessorType( - TextPreprocessingGraphOptions::PreprocessorType preprocessor_type) { - switch (preprocessor_type) { - case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: +// Gets the name of the MediaPipe preprocessor calculator associated with +// `model_type`. +absl::StatusOr GetCalculatorNameFromModelType( + TextModelType::ModelType model_type) { + switch (model_type) { + case TextModelType::UNSPECIFIED_MODEL: return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, "Unspecified preprocessor type", + absl::StatusCode::kInvalidArgument, "Unspecified model type", MediaPipeTasksStatus::kInvalidArgumentError); - case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: + case TextModelType::BERT_MODEL: return "BertPreprocessorCalculator"; - case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: + case TextModelType::REGEX_MODEL: return "RegexPreprocessorCalculator"; - case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: + case TextModelType::STRING_MODEL: return "TextToTensorCalculator"; } } -// Determines the PreprocessorType for the model based on its metadata as well -// as its input tensors' type and count. Returns an error if there is no -// compatible preprocessor. 
-absl::StatusOr -GetPreprocessorType(const ModelResources& model_resources) { - const tflite::SubGraph& model_graph = - *(*model_resources.GetTfLiteModel()->subgraphs())[0]; - bool all_int32_tensors = - absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) { - return (*model_graph.tensors())[i]->type() == tflite::TensorType_INT32; - }); - bool all_string_tensors = - absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) { - return (*model_graph.tensors())[i]->type() == tflite::TensorType_STRING; - }); - if (!all_int32_tensors && !all_string_tensors) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "All input tensors should have type int32 or all should have type " - "string", - MediaPipeTasksStatus::kInvalidInputTensorTypeError); - } - if (all_string_tensors) { - return TextPreprocessingGraphOptions::STRING_PREPROCESSOR; - } - - // Otherwise, all tensors should have type int32 - const ModelMetadataExtractor* metadata_extractor = - model_resources.GetMetadataExtractor(); - if (metadata_extractor->GetModelMetadata() == nullptr || - metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "Text models with int32 input tensors require TFLite Model " - "Metadata but none was found", - MediaPipeTasksStatus::kMetadataNotFoundError); - } - - if (model_graph.inputs()->size() == kNumInputTensorsForBert) { - return TextPreprocessingGraphOptions::BERT_PREPROCESSOR; - } - - if (model_graph.inputs()->size() == kNumInputTensorsForRegex) { - return TextPreprocessingGraphOptions::REGEX_PREPROCESSOR; - } - - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::Substitute("Models with int32 input tensors should take exactly $0 " - "or $1 input tensors, but found $2", - kNumInputTensorsForBert, kNumInputTensorsForRegex, - model_graph.inputs()->size()), - MediaPipeTasksStatus::kInvalidNumInputTensorsError); -} - // Returns the maximum 
input sequence length accepted by the TFLite // model that owns `model graph` or returns an error if the model's input // tensors' shape is invalid for text preprocessing. This util assumes that the @@ -181,17 +124,16 @@ absl::Status ConfigureTextPreprocessingGraph( MediaPipeTasksStatus::kInvalidArgumentError); } - ASSIGN_OR_RETURN( - TextPreprocessingGraphOptions::PreprocessorType preprocessor_type, - GetPreprocessorType(model_resources)); - options.set_preprocessor_type(preprocessor_type); - switch (preprocessor_type) { - case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: - case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: { + ASSIGN_OR_RETURN(TextModelType::ModelType model_type, + GetModelType(model_resources)); + options.set_model_type(model_type); + switch (model_type) { + case TextModelType::UNSPECIFIED_MODEL: + case TextModelType::STRING_MODEL: { break; } - case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: - case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: { + case TextModelType::BERT_MODEL: + case TextModelType::REGEX_MODEL: { ASSIGN_OR_RETURN( int max_seq_len, GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0])); @@ -239,23 +181,22 @@ class TextPreprocessingGraph : public mediapipe::Subgraph { absl::StatusOr>> BuildTextPreprocessing( const TextPreprocessingGraphOptions& options, Source text_in, SideSource metadata_extractor_in, Graph& graph) { - ASSIGN_OR_RETURN( - std::string preprocessor_name, - GetCalculatorNameFromPreprocessorType(options.preprocessor_type())); + ASSIGN_OR_RETURN(std::string preprocessor_name, + GetCalculatorNameFromModelType(options.model_type())); auto& text_preprocessor = graph.AddNode(preprocessor_name); - switch (options.preprocessor_type()) { - case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: - case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: { + switch (options.model_type()) { + case TextModelType::UNSPECIFIED_MODEL: + case TextModelType::STRING_MODEL: { break; } 
- case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: { + case TextModelType::BERT_MODEL: { text_preprocessor.GetOptions() .set_bert_max_seq_len(options.max_seq_len()); metadata_extractor_in >> text_preprocessor.SideIn(kMetadataExtractorTag); break; } - case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: { + case TextModelType::REGEX_MODEL: { text_preprocessor.GetOptions() .set_max_seq_len(options.max_seq_len()); metadata_extractor_in >> @@ -270,7 +211,4 @@ class TextPreprocessingGraph : public mediapipe::Subgraph { REGISTER_MEDIAPIPE_GRAPH( ::mediapipe::tasks::components::processors::TextPreprocessingGraph); -} // namespace processors -} // namespace components -} // namespace tasks -} // namespace mediapipe +} // namespace mediapipe::tasks::components::processors diff --git a/mediapipe/tasks/cc/text/utils/BUILD b/mediapipe/tasks/cc/text/utils/BUILD index 710e8a984..092a7d450 100644 --- a/mediapipe/tasks/cc/text/utils/BUILD +++ b/mediapipe/tasks/cc/text/utils/BUILD @@ -43,3 +43,43 @@ cc_test( "@com_google_absl//absl/container:node_hash_map", ], ) + +cc_library( + name = "text_model_utils", + srcs = ["text_model_utils.cc"], + hdrs = ["text_model_utils.h"], + deps = [ + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], +) + +cc_test( + name = "text_model_utils_test", + srcs = ["text_model_utils_test.cc"], + data = [ + "//mediapipe/tasks/testdata/text:bert_text_classifier_models", + "//mediapipe/tasks/testdata/text:mobilebert_embedding_model", + "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + deps = 
[ + ":text_model_utils", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + ], +) diff --git a/mediapipe/tasks/cc/text/utils/text_model_utils.cc b/mediapipe/tasks/cc/text/utils/text_model_utils.cc new file mode 100644 index 000000000..9d0005ec1 --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/text_model_utils.cc @@ -0,0 +1,119 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/text/utils/text_model_utils.h" + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/substitute.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace mediapipe::tasks::text::utils { +namespace { + +using ::mediapipe::tasks::components::processors::proto::TextModelType; +using ::mediapipe::tasks::core::ModelResources; +using ::mediapipe::tasks::metadata::ModelMetadataExtractor; + +constexpr int kNumInputTensorsForBert = 3; +constexpr int kNumInputTensorsForRegex = 1; +constexpr int kNumInputTensorsForStringPreprocessor = 1; + +// Determines the ModelType for a model with int32 input tensors based +// on the number of input tensors. Returns an error if there is missing metadata +// or an invalid number of input tensors. 
+absl::StatusOr GetIntTensorModelType( + const ModelResources& model_resources, int num_input_tensors) { + const ModelMetadataExtractor* metadata_extractor = + model_resources.GetMetadataExtractor(); + if (metadata_extractor->GetModelMetadata() == nullptr || + metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Text models with int32 input tensors require TFLite Model " + "Metadata but none was found", + MediaPipeTasksStatus::kMetadataNotFoundError); + } + + if (num_input_tensors == kNumInputTensorsForBert) { + return TextModelType::BERT_MODEL; + } + + if (num_input_tensors == kNumInputTensorsForRegex) { + return TextModelType::REGEX_MODEL; + } + + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::Substitute("Models with int32 input tensors should take exactly $0 " + "or $1 input tensors, but found $2", + kNumInputTensorsForBert, kNumInputTensorsForRegex, + num_input_tensors), + MediaPipeTasksStatus::kInvalidNumInputTensorsError); +} + +// Determines the ModelType for a model with string input tensors based +// on the number of input tensors. Returns an error if there is an invalid +// number of input tensors. 
+absl::StatusOr GetStringTensorModelType( + const ModelResources& model_resources, int num_input_tensors) { + if (num_input_tensors == kNumInputTensorsForStringPreprocessor) { + return TextModelType::STRING_MODEL; + } + + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::Substitute("Models with string input tensors should take exactly " + "$0 tensors, but found $1", + kNumInputTensorsForStringPreprocessor, + num_input_tensors), + MediaPipeTasksStatus::kInvalidNumInputTensorsError); +} +} // namespace + +absl::StatusOr GetModelType( + const ModelResources& model_resources) { + const tflite::SubGraph& model_graph = + *(*model_resources.GetTfLiteModel()->subgraphs())[0]; + bool all_int32_tensors = + absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) { + return (*model_graph.tensors())[i]->type() == tflite::TensorType_INT32; + }); + bool all_string_tensors = + absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) { + return (*model_graph.tensors())[i]->type() == tflite::TensorType_STRING; + }); + if (!all_int32_tensors && !all_string_tensors) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "All input tensors should have type int32 or all should have type " + "string", + MediaPipeTasksStatus::kInvalidInputTensorTypeError); + } + if (all_string_tensors) { + return GetStringTensorModelType(model_resources, + model_graph.inputs()->size()); + } + + // Otherwise, all tensors should have type int32 + return GetIntTensorModelType(model_resources, model_graph.inputs()->size()); +} + +} // namespace mediapipe::tasks::text::utils diff --git a/mediapipe/tasks/cc/text/utils/text_model_utils.h b/mediapipe/tasks/cc/text/utils/text_model_utils.h new file mode 100644 index 000000000..da8783d33 --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/text_model_utils.h @@ -0,0 +1,33 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_ + +#include "absl/status/statusor.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" + +namespace mediapipe::tasks::text::utils { + +// Determines the ModelType for the model based on its metadata as well +// as its input tensors' type and count. Returns an error if there is no +// compatible model type. +absl::StatusOr +GetModelType(const core::ModelResources& model_resources); + +} // namespace mediapipe::tasks::text::utils + +#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_ diff --git a/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc b/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc new file mode 100644 index 000000000..c02f8eca5 --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc @@ -0,0 +1,108 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/utils/text_model_utils.h" + +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" + +namespace mediapipe::tasks::text::utils { + +namespace { + +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::processors::proto::TextModelType; +using ::mediapipe::tasks::core::ModelResources; +using ::mediapipe::tasks::core::proto::ExternalFile; + +constexpr absl::string_view kTestModelResourcesTag = "test_model_resources"; + +constexpr absl::string_view kTestDataDirectory = + "/mediapipe/tasks/testdata/text/"; +// Classification model with BERT preprocessing. +constexpr absl::string_view kBertClassifierPath = "bert_text_classifier.tflite"; +// Embedding model with BERT preprocessing. +constexpr absl::string_view kMobileBert = + "mobilebert_embedding_with_metadata.tflite"; +// Classification model with regex preprocessing. 
+constexpr absl::string_view kRegexClassifierPath = + "test_model_text_classifier_with_regex_tokenizer.tflite"; +// Embedding model with regex preprocessing. +constexpr absl::string_view kRegexOneEmbeddingModel = + "regex_one_embedding_with_metadata.tflite"; +// Classification model that takes a string tensor and outputs a bool tensor. +constexpr absl::string_view kStringToBoolModelPath = + "test_model_text_classifier_bool_output.tflite"; + +std::string GetFullPath(absl::string_view file_name) { + return JoinPath("./", kTestDataDirectory, file_name); +} + +absl::StatusOr GetModelTypeFromFile( + absl::string_view file_name) { + auto model_file = std::make_unique(); + model_file->set_file_name(GetFullPath(file_name)); + ASSIGN_OR_RETURN(auto model_resources, + ModelResources::Create(std::string(kTestModelResourcesTag), + std::move(model_file))); + return GetModelType(*model_resources); +} + +} // namespace + +class TextModelUtilsTest : public tflite_shims::testing::Test {}; + +TEST_F(TextModelUtilsTest, BertClassifierModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, + GetModelTypeFromFile(kBertClassifierPath)); + ASSERT_EQ(model_type, TextModelType::BERT_MODEL); +} + +TEST_F(TextModelUtilsTest, BertEmbedderModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, GetModelTypeFromFile(kMobileBert)); + ASSERT_EQ(model_type, TextModelType::BERT_MODEL); +} + +TEST_F(TextModelUtilsTest, RegexClassifierModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, + GetModelTypeFromFile(kRegexClassifierPath)); + ASSERT_EQ(model_type, TextModelType::REGEX_MODEL); +} + +TEST_F(TextModelUtilsTest, RegexEmbedderModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, + GetModelTypeFromFile(kRegexOneEmbeddingModel)); + ASSERT_EQ(model_type, TextModelType::REGEX_MODEL); +} + +TEST_F(TextModelUtilsTest, StringInputModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, + GetModelTypeFromFile(kStringToBoolModelPath)); + ASSERT_EQ(model_type, TextModelType::STRING_MODEL); +} + +} // 
namespace mediapipe::tasks::text::utils From 24c8fa97e9aeb75ac6344957aff2a2d5b953061b Mon Sep 17 00:00:00 2001 From: Adam Cozzette Date: Wed, 7 Dec 2022 19:04:31 -0800 Subject: [PATCH 38/64] Internal change PiperOrigin-RevId: 493768013 --- mediapipe/examples/ios/faceeffect/BUILD | 4 ++-- mediapipe/examples/ios/facemeshgpu/BUILD | 2 +- mediapipe/examples/ios/handtrackinggpu/BUILD | 2 +- mediapipe/examples/ios/iristrackinggpu/BUILD | 2 +- mediapipe/examples/ios/posetrackinggpu/BUILD | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mediapipe/examples/ios/faceeffect/BUILD b/mediapipe/examples/ios/faceeffect/BUILD index e0c3abb86..7d3a75cc6 100644 --- a/mediapipe/examples/ios/faceeffect/BUILD +++ b/mediapipe/examples/ios/faceeffect/BUILD @@ -74,10 +74,12 @@ objc_library( ], features = ["-layering_check"], deps = [ + "//mediapipe/framework/formats:matrix_data_cc_proto", "//third_party/apple_frameworks:AVFoundation", "//third_party/apple_frameworks:CoreGraphics", "//third_party/apple_frameworks:CoreMedia", "//third_party/apple_frameworks:UIKit", + "//mediapipe/modules/face_geometry/protos:face_geometry_cc_proto", "//mediapipe/objc:mediapipe_framework_ios", "//mediapipe/objc:mediapipe_input_sources_ios", "//mediapipe/objc:mediapipe_layer_renderer", @@ -85,9 +87,7 @@ objc_library( "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ - "//mediapipe/framework/formats:matrix_data_cc_proto", "//mediapipe/graphs/face_effect:face_effect_gpu_deps", - "//mediapipe/modules/face_geometry/protos:face_geometry_cc_proto", ], }), ) diff --git a/mediapipe/examples/ios/facemeshgpu/BUILD b/mediapipe/examples/ios/facemeshgpu/BUILD index 02103ce2f..6caf8c09c 100644 --- a/mediapipe/examples/ios/facemeshgpu/BUILD +++ b/mediapipe/examples/ios/facemeshgpu/BUILD @@ -67,12 +67,12 @@ objc_library( ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + "//mediapipe/framework/formats:landmark_cc_proto", ] + select({ 
"//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ "//mediapipe/graphs/face_mesh:mobile_calculators", - "//mediapipe/framework/formats:landmark_cc_proto", ], }), ) diff --git a/mediapipe/examples/ios/handtrackinggpu/BUILD b/mediapipe/examples/ios/handtrackinggpu/BUILD index 647b7670a..c5b8e7b58 100644 --- a/mediapipe/examples/ios/handtrackinggpu/BUILD +++ b/mediapipe/examples/ios/handtrackinggpu/BUILD @@ -68,12 +68,12 @@ objc_library( ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + "//mediapipe/framework/formats:landmark_cc_proto", ] + select({ "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ "//mediapipe/graphs/hand_tracking:mobile_calculators", - "//mediapipe/framework/formats:landmark_cc_proto", ], }), ) diff --git a/mediapipe/examples/ios/iristrackinggpu/BUILD b/mediapipe/examples/ios/iristrackinggpu/BUILD index 056447d63..646d2e5a2 100644 --- a/mediapipe/examples/ios/iristrackinggpu/BUILD +++ b/mediapipe/examples/ios/iristrackinggpu/BUILD @@ -68,12 +68,12 @@ objc_library( ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + "//mediapipe/framework/formats:landmark_cc_proto", ] + select({ "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ "//mediapipe/graphs/iris_tracking:iris_tracking_gpu_deps", - "//mediapipe/framework/formats:landmark_cc_proto", ], }), ) diff --git a/mediapipe/examples/ios/posetrackinggpu/BUILD b/mediapipe/examples/ios/posetrackinggpu/BUILD index 86b41ed36..4fbc2280c 100644 --- a/mediapipe/examples/ios/posetrackinggpu/BUILD +++ b/mediapipe/examples/ios/posetrackinggpu/BUILD @@ -67,12 +67,12 @@ objc_library( ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + "//mediapipe/framework/formats:landmark_cc_proto", ] + select({ "//mediapipe:ios_i386": [], "//mediapipe:ios_x86_64": [], "//conditions:default": [ "//mediapipe/graphs/pose_tracking:pose_tracking_gpu_deps", - 
"//mediapipe/framework/formats:landmark_cc_proto", ], }), ) From 9ae2e43b70188cd73fd478364b71d32410f9c21c Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 7 Dec 2022 19:17:14 -0800 Subject: [PATCH 39/64] Open Source the remaining MediaPipe Tasks tests for Web PiperOrigin-RevId: 493769657 --- .../audio_classifier_graph_options.proto | 1 + .../proto/audio_embedder_graph_options.proto | 1 + .../proto/text_classifier_graph_options.proto | 1 + .../proto/text_embedder_graph_options.proto | 1 + .../gesture_classifier_graph_options.proto | 1 + .../gesture_embedder_graph_options.proto | 1 + .../gesture_recognizer_graph_options.proto | 1 + ...and_gesture_recognizer_graph_options.proto | 1 + .../proto/hand_detector_graph_options.proto | 1 + .../proto/hand_landmarker_graph_options.proto | 1 + ...and_landmarks_detector_graph_options.proto | 1 + .../image_classifier_graph_options.proto | 1 + .../proto/image_embedder_graph_options.proto | 1 + .../proto/image_segmenter_graph_options.proto | 1 + .../proto/object_detector_options.proto | 1 + .../tasks/web/audio/audio_classifier/BUILD | 21 ++ .../audio_classifier/audio_classifier_test.ts | 208 ++++++++++++ .../tasks/web/audio/audio_embedder/BUILD | 21 ++ .../audio_embedder/audio_embedder_test.ts | 185 +++++++++++ .../tasks/web/text/text_classifier/BUILD | 22 ++ .../text_classifier/text_classifier_test.ts | 152 +++++++++ mediapipe/tasks/web/text/text_embedder/BUILD | 21 ++ .../text/text_embedder/text_embedder_test.ts | 165 ++++++++++ mediapipe/tasks/web/vision/core/BUILD | 18 + .../vision/core/vision_task_runner.test.ts | 99 ++++++ .../tasks/web/vision/gesture_recognizer/BUILD | 25 ++ .../gesture_recognizer_test.ts | 307 ++++++++++++++++++ .../tasks/web/vision/hand_landmarker/BUILD | 25 ++ .../hand_landmarker/hand_landmarker_test.ts | 251 ++++++++++++++ .../tasks/web/vision/image_classifier/BUILD | 24 ++ .../image_classifier/image_classifier_test.ts | 150 +++++++++ .../tasks/web/vision/image_embedder/BUILD | 21 ++ 
.../image_embedder/image_embedder_test.ts | 158 +++++++++ .../tasks/web/vision/object_detector/BUILD | 24 ++ .../object_detector/object_detector_test.ts | 229 +++++++++++++ 35 files changed, 2141 insertions(+) create mode 100644 mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts create mode 100644 mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts create mode 100644 mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts create mode 100644 mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts create mode 100644 mediapipe/tasks/web/vision/core/vision_task_runner.test.ts create mode 100644 mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts create mode 100644 mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts create mode 100644 mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts create mode 100644 mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts create mode 100644 mediapipe/tasks/web/vision/object_detector/object_detector_test.ts diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto index 5d4ba3296..cc26b3070 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.audio.audio_classifier.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto b/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto index 25c5d5474..367a1bf26 
100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.audio.audio_embedder.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto b/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto index 8f4d7eea6..41f87b519 100644 --- a/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.text.text_classifier.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto b/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto index e7e3a63c7..fc8e02858 100644 --- a/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.text.text_embedder.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git 
a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto index dcefa075f..edbabc018 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto index bff4e0a9c..df909a6db 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.gesturerecognizer.proto"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto index 57d8a3746..fef22c07c 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package 
mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto"; import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto index 7df2fed37..ae85509da 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto"; diff --git a/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto index a009f2365..bede70da5 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.hand_detector.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; option java_package = 
"com.google.mediapipe.tasks.vision.handdetector.proto"; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto index 51e4e129a..d0edf99c0 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.hand_landmarker.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto"; import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto"; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto index 195f6e5cc..a2d520963 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.hand_landmarker.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.handlandmarker.proto"; diff --git a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto index 76315e230..24b126a35 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto +++ 
b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_classifier.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto index 72b3e7ee3..24ee866f2 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_embedder.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto index 4d8100842..5c7d2ec71 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_segmenter.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto"; diff --git a/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto 
b/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto index cba58ace8..3f6932f8f 100644 --- a/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto +++ b/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto @@ -18,6 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.object_detector.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.objectdetector.proto"; diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index dc82a4a24..24ef31feb 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -2,6 +2,7 @@ # # This task takes audio data and outputs the classification result. +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -44,3 +45,23 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/core:classifier_options", ], ) + +mediapipe_ts_library( + name = "audio_classifier_test_lib", + testonly = True, + srcs = [ + "audio_classifier_test.ts", + ], + deps = [ + ":audio_classifier", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "audio_classifier_test", + deps = [":audio_classifier_test_lib"], +) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts new file mode 100644 index 000000000..d5c0a9429 
--- /dev/null +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts @@ -0,0 +1,208 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {AudioClassifier} from './audio_classifier'; + +// The OSS JS API does not support the builder pattern. 
+// tslint:disable:jspb-use-builder-pattern + +class AudioClassifierFake extends AudioClassifier implements + MediapipeTasksFake { + lastSampleRate: number|undefined; + calculatorName = + 'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + + private protoVectorListener: ((binaryProtos: Uint8Array[]) => void)|undefined; + private resultProtoVector: ClassificationResult[] = []; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('timestamped_classifications'); + this.protoVectorListener = listener; + }); + spyOn(this.graphRunner, 'addDoubleToStream') + .and.callFake((sampleRate, streamName, timestamp) => { + if (streamName === 'sample_rate') { + this.lastSampleRate = sampleRate; + } + }); + spyOn(this.graphRunner, 'addAudioToStreamWithShape') + .and.callFake( + (audioData, numChannels, numSamples, streamName, timestamp) => { + expect(numChannels).toBe(1); + }); + spyOn(this.graphRunner, 'finishProcessing').and.callFake(() => { + if (!this.protoVectorListener) return; + this.protoVectorListener(this.resultProtoVector.map( + classificationResult => classificationResult.serializeBinary())); + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + } + + /** Sets the Protobuf that will be sent to the API.
*/ + setResults(results: ClassificationResult[]): void { + this.resultProtoVector = results; + } +} + +describe('AudioClassifier', () => { + let audioClassifier: AudioClassifierFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + audioClassifier = new AudioClassifierFake(); + await audioClassifier.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(audioClassifier); + verifyListenersRegistered(audioClassifier); + }); + + it('reloads graph when settings are changed', async () => { + await audioClassifier.setOptions({maxResults: 1}); + verifyGraph(audioClassifier, [['classifierOptions', 'maxResults'], 1]); + verifyListenersRegistered(audioClassifier); + + await audioClassifier.setOptions({maxResults: 5}); + verifyGraph(audioClassifier, [['classifierOptions', 'maxResults'], 5]); + verifyListenersRegistered(audioClassifier); + }); + + it('merges options', async () => { + await audioClassifier.setOptions({maxResults: 1}); + await audioClassifier.setOptions({displayNamesLocale: 'en'}); + verifyGraph(audioClassifier, [ + 'classifierOptions', { + maxResults: 1, + displayNamesLocale: 'en', + scoreThreshold: undefined, + categoryAllowlistList: [], + categoryDenylistList: [] + } + ]); + }); + + it('uses a sample rate of 48000 by default', async () => { + audioClassifier.classify(new Float32Array([])); + expect(audioClassifier.lastSampleRate).toEqual(48000); + }); + + it('uses default sample rate if none provided', async () => { + audioClassifier.setDefaultSampleRate(16000); + audioClassifier.classify(new Float32Array([])); + expect(audioClassifier.lastSampleRate).toEqual(16000); + }); + + it('uses custom sample rate if provided', async () => { + audioClassifier.setDefaultSampleRate(16000); + audioClassifier.classify(new Float32Array([]), 44100); + expect(audioClassifier.lastSampleRate).toEqual(44100); + }); + + it('transforms results', async () => { + const resultProtoVector: ClassificationResult[] = []; + 
+ let classificationResult = new ClassificationResult(); + classificationResult.setTimestampMs(0); + let classifcations = new Classifications(); + classifcations.setHeadIndex(1); + classifcations.setHeadName('headName'); + let classificationList = new ClassificationList(); + let clasification = new Classification(); + clasification.setIndex(1); + clasification.setScore(0.2); + clasification.setDisplayName('displayName'); + clasification.setLabel('categoryName'); + classificationList.addClassification(clasification); + classifcations.setClassificationList(classificationList); + classificationResult.addClassifications(classifcations); + resultProtoVector.push(classificationResult); + + classificationResult = new ClassificationResult(); + classificationResult.setTimestampMs(1); + classifcations = new Classifications(); + classificationList = new ClassificationList(); + clasification = new Classification(); + clasification.setIndex(2); + clasification.setScore(0.3); + classificationList.addClassification(clasification); + classifcations.setClassificationList(classificationList); + classificationResult.addClassifications(classifcations); + resultProtoVector.push(classificationResult); + + // Invoke the audio classifier + audioClassifier.setResults(resultProtoVector); + const results = audioClassifier.classify(new Float32Array([])); + expect(results.length).toEqual(2); + expect(results[0]).toEqual({ + classifications: [{ + categories: [{ + index: 1, + score: 0.2, + displayName: 'displayName', + categoryName: 'categoryName' + }], + headIndex: 1, + headName: 'headName' + }], + timestampMs: 0 + }); + expect(results[1]).toEqual({ + classifications: [{ + categories: [{index: 2, score: 0.3, displayName: '', categoryName: ''}], + headIndex: 0, + headName: '' + }], + timestampMs: 1 + }); + }); + + it('clears results between invocations', async () => { + const classificationResult = new ClassificationResult(); + const classifcations = new Classifications(); + const 
classificationList = new ClassificationList(); + const clasification = new Classification(); + classificationList.addClassification(clasification); + classifcations.setClassificationList(classificationList); + classificationResult.addClassifications(classifcations); + + audioClassifier.setResults([classificationResult]); + + // Invoke the audio classifier twice + const classifications1 = audioClassifier.classify(new Float32Array([])); + const classifications2 = audioClassifier.classify(new Float32Array([])); + + // Verify that classifications2 is not a concatenation of all previously + // returned classifications. + expect(classifications1).toEqual(classifications2); + }); +}); diff --git a/mediapipe/tasks/web/audio/audio_embedder/BUILD b/mediapipe/tasks/web/audio/audio_embedder/BUILD index dc84d0cd6..0817776c5 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/web/audio/audio_embedder/BUILD @@ -3,6 +3,7 @@ # This task takes audio input and performs embedding. load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -43,3 +44,23 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/core:embedder_options", ], ) + +mediapipe_ts_library( + name = "audio_embedder_test_lib", + testonly = True, + srcs = [ + "audio_embedder_test.ts", + ], + deps = [ + ":audio_embedder", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "audio_embedder_test", + deps = [":audio_embedder_test_lib"], +) diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts new file mode 100644 index 000000000..2f605ff98 --- /dev/null +++
b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts @@ -0,0 +1,185 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Embedding, EmbeddingResult as EmbeddingResultProto, FloatEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {AudioEmbedder, AudioEmbedderResult} from './audio_embedder'; + + +// The OSS JS API does not support the builder pattern. 
+// tslint:disable:jspb-use-builder-pattern + +class AudioEmbedderFake extends AudioEmbedder implements MediapipeTasksFake { + lastSampleRate: number|undefined; + calculatorName = 'mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph'; + graph: CalculatorGraphConfig|undefined; + attachListenerSpies: jasmine.Spy[] = []; + fakeWasmModule: SpyWasmModule; + + protoListener: ((binaryProto: Uint8Array) => void)|undefined; + protoVectorListener: ((binaryProtos: Uint8Array[]) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('embeddings_out'); + this.protoListener = listener; + }); + this.attachListenerSpies[1] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('timestamped_embeddings_out'); + this.protoVectorListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addDoubleToStream').and.callFake(sampleRate => { + this.lastSampleRate = sampleRate; + }); + spyOn(this.graphRunner, 'addAudioToStreamWithShape'); + } +} + +describe('AudioEmbedder', () => { + let audioEmbedder: AudioEmbedderFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + audioEmbedder = new AudioEmbedderFake(); + await audioEmbedder.setOptions({}); // Initialize graph + }); + + it('initializes graph', () => { + verifyGraph(audioEmbedder); + verifyListenersRegistered(audioEmbedder); + }); + + it('reloads graph when settings are changed', async () => { + await audioEmbedder.setOptions({quantize: true}); + verifyGraph(audioEmbedder, [['embedderOptions', 'quantize'], true]); + 
verifyListenersRegistered(audioEmbedder); + + await audioEmbedder.setOptions({quantize: undefined}); + verifyGraph(audioEmbedder, [['embedderOptions', 'quantize'], undefined]); + verifyListenersRegistered(audioEmbedder); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await audioEmbedder.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + audioEmbedder, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */[ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('combines options', async () => { + await audioEmbedder.setOptions({quantize: true}); + await audioEmbedder.setOptions({l2Normalize: true}); + verifyGraph( + audioEmbedder, + ['embedderOptions', {'quantize': true, 'l2Normalize': true}]); + }); + + it('uses a sample rate of 48000 by default', async () => { + audioEmbedder.embed(new Float32Array([])); + expect(audioEmbedder.lastSampleRate).toEqual(48000); + }); + + it('uses default sample rate if none provided', async () => { + audioEmbedder.setDefaultSampleRate(16000); + audioEmbedder.embed(new Float32Array([])); + expect(audioEmbedder.lastSampleRate).toEqual(16000); + }); + + it('uses custom sample rate if provided', async () => { + audioEmbedder.setDefaultSampleRate(16000); + audioEmbedder.embed(new Float32Array([]), 44100); + expect(audioEmbedder.lastSampleRate).toEqual(44100); + }); + + describe('transforms results', () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const floatEmbedding = new FloatEmbedding(); + floatEmbedding.setValuesList([0.1, 0.9]); + + embedding.setFloatEmbedding(floatEmbedding); + const resultProto = new EmbeddingResultProto(); + resultProto.addEmbeddings(embedding); + + function 
validateEmbeddingResult( + expectedEmbeddingResult: AudioEmbedderResult[]) { + expect(expectedEmbeddingResult.length).toEqual(1); + + const [embeddingResult] = expectedEmbeddingResult; + expect(embeddingResult.embeddings.length).toEqual(1); + expect(embeddingResult.embeddings[0]) + .toEqual( + {floatEmbedding: [0.1, 0.9], headIndex: 1, headName: 'headName'}); + } + + it('from embeddings stream', async () => { + audioEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(audioEmbedder); + // Pass the test data to our listener + audioEmbedder.protoListener!(resultProto.serializeBinary()); + }); + + // Invoke the audio embedder + const embeddingResults = audioEmbedder.embed(new Float32Array([])); + validateEmbeddingResult(embeddingResults); + }); + + it('from timestamped embeddings stream', async () => { + audioEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(audioEmbedder); + // Pass the test data to our listener + audioEmbedder.protoVectorListener!([resultProto.serializeBinary()]); + }); + + // Invoke the audio embedder + const embeddingResults = audioEmbedder.embed(new Float32Array([]), 42); + validateEmbeddingResult(embeddingResults); + }); + }); +}); diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD index 07f78ac20..fd97c3db4 100644 --- a/mediapipe/tasks/web/text/text_classifier/BUILD +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -4,6 +4,7 @@ # BERT-based text classification).
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -45,3 +46,24 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/core:classifier_options", ], ) + +mediapipe_ts_library( + name = "text_classifier_test_lib", + testonly = True, + srcs = [ + "text_classifier_test.ts", + ], + deps = [ + ":text_classifier", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "text_classifier_test", + deps = [":text_classifier_test_lib"], +) diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts new file mode 100644 index 000000000..841bf8c48 --- /dev/null +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts @@ -0,0 +1,152 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {TextClassifier} from './text_classifier'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +class TextClassifierFake extends TextClassifier implements MediapipeTasksFake { + calculatorName = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + fakeWasmModule: SpyWasmModule; + protoListener: ((binaryProto: Uint8Array) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('classifications_out'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + } +} + +describe('TextClassifier', () => { + let textClassifier: TextClassifierFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + textClassifier = new TextClassifierFake(); + await textClassifier.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(textClassifier); + verifyListenersRegistered(textClassifier); + }); + + 
it('reloads graph when settings are changed', async () => { + await textClassifier.setOptions({maxResults: 1}); + verifyGraph(textClassifier, [['classifierOptions', 'maxResults'], 1]); + verifyListenersRegistered(textClassifier); + + await textClassifier.setOptions({maxResults: 5}); + verifyGraph(textClassifier, [['classifierOptions', 'maxResults'], 5]); + verifyListenersRegistered(textClassifier); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await textClassifier.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + textClassifier, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */ + [ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('merges options', async () => { + await textClassifier.setOptions({maxResults: 1}); + await textClassifier.setOptions({displayNamesLocale: 'en'}); + verifyGraph(textClassifier, [ + 'classifierOptions', { + maxResults: 1, + displayNamesLocale: 'en', + scoreThreshold: undefined, + categoryAllowlistList: [], + categoryDenylistList: [] + } + ]); + }); + + it('transforms results', async () => { + const classificationResult = new ClassificationResult(); + const classifcations = new Classifications(); + classifcations.setHeadIndex(1); + classifcations.setHeadName('headName'); + const classificationList = new ClassificationList(); + const clasification = new Classification(); + clasification.setIndex(1); + clasification.setScore(0.2); + clasification.setDisplayName('displayName'); + clasification.setLabel('categoryName'); + classificationList.addClassification(clasification); + classifcations.setClassificationList(classificationList); + classificationResult.addClassifications(classifcations); + + // Pass the test data to our listener + 
textClassifier.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(textClassifier); + textClassifier.protoListener!(classificationResult.serializeBinary()); + }); + + // Invoke the text classifier + const result = textClassifier.classify('foo'); + + expect(textClassifier.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(result).toEqual({ + classifications: [{ + categories: [{ + index: 1, + score: 0.2, + displayName: 'displayName', + categoryName: 'categoryName' + }], + headIndex: 1, + headName: 'headName' + }] + }); + }); +}); diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD index 7d796fb7e..1514944bf 100644 --- a/mediapipe/tasks/web/text/text_embedder/BUILD +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -4,6 +4,7 @@ # load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -44,3 +45,23 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/core:embedder_options", ], ) + +mediapipe_ts_library( + name = "text_embedder_test_lib", + testonly = True, + srcs = [ + "text_embedder_test.ts", + ], + deps = [ + ":text_embedder", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "text_embedder_test", + deps = [":text_embedder_test_lib"], +) diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts new file mode 100644 index 000000000..04a9b371a --- /dev/null +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts @@ -0,0 +1,165 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Embedding, EmbeddingResult, FloatEmbedding, QuantizedEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {TextEmbedder} from './text_embedder'; + +// The OSS JS API does not support the builder pattern. 
+// tslint:disable:jspb-use-builder-pattern + +class TextEmbedderFake extends TextEmbedder implements MediapipeTasksFake { + calculatorName = 'mediapipe.tasks.text.text_embedder.TextEmbedderGraph'; + graph: CalculatorGraphConfig|undefined; + attachListenerSpies: jasmine.Spy[] = []; + fakeWasmModule: SpyWasmModule; + protoListener: ((binaryProtos: Uint8Array) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('embeddings_out'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + } +} + +describe('TextEmbedder', () => { + let textEmbedder: TextEmbedderFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + textEmbedder = new TextEmbedderFake(); + await textEmbedder.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(textEmbedder); + verifyListenersRegistered(textEmbedder); + }); + + it('reloads graph when settings are changed', async () => { + await textEmbedder.setOptions({quantize: true}); + verifyGraph(textEmbedder, [['embedderOptions', 'quantize'], true]); + verifyListenersRegistered(textEmbedder); + + await textEmbedder.setOptions({quantize: undefined}); + verifyGraph(textEmbedder, [['embedderOptions', 'quantize'], undefined]); + verifyListenersRegistered(textEmbedder); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await textEmbedder.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + textEmbedder, + /* expectedCalculatorOptions= */ 
undefined, + /* expectedBaseOptions= */[ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('combines options', async () => { + await textEmbedder.setOptions({quantize: true}); + await textEmbedder.setOptions({l2Normalize: true}); + verifyGraph( + textEmbedder, + ['embedderOptions', {'quantize': true, 'l2Normalize': true}]); + }); + + it('transforms results', async () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const floatEmbedding = new FloatEmbedding(); + floatEmbedding.setValuesList([0.1, 0.9]); + + embedding.setFloatEmbedding(floatEmbedding); + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + + // Pass the test data to our listener + textEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(textEmbedder); + textEmbedder.protoListener!(resultProto.serializeBinary()); + }); + + // Invoke the text embedder + const embeddingResult = textEmbedder.embed('foo'); + + expect(textEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(embeddingResult.embeddings.length).toEqual(1); + expect(embeddingResult.embeddings[0]) + .toEqual( + {floatEmbedding: [0.1, 0.9], headIndex: 1, headName: 'headName'}); + }); + + it('transforms custom quantized values', async () => { + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + + const quantizedEmbedding = new QuantizedEmbedding(); + const quantizedValues = new Uint8Array([1, 2, 3]); + quantizedEmbedding.setValues(quantizedValues); + + embedding.setQuantizedEmbedding(quantizedEmbedding); + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + + // Pass the test data to our listener + textEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(textEmbedder); + 
textEmbedder.protoListener!(resultProto.serializeBinary()); + }); + + // Invoke the text embedder + const embeddingsResult = textEmbedder.embed('foo'); + + expect(textEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(embeddingsResult.embeddings.length).toEqual(1); + expect(embeddingsResult.embeddings[0]).toEqual({ + quantizedEmbedding: new Uint8Array([1, 2, 3]), + headIndex: 1, + headName: 'headName' + }); + }); +}); diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index b389a9b01..e4ea3036f 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -1,5 +1,6 @@ # This package contains options shared by all MediaPipe Vision Tasks for Web. +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -22,3 +23,20 @@ mediapipe_ts_library( "//mediapipe/web/graph_runner:graph_runner_ts", ], ) + +mediapipe_ts_library( + name = "vision_task_runner_test_lib", + testonly = True, + srcs = ["vision_task_runner.test.ts"], + deps = [ + ":vision_task_runner", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/core:task_runner_test_utils", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +jasmine_node_test( + name = "vision_task_runner_test", + deps = [":vision_task_runner_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts new file mode 100644 index 000000000..6cc9ea328 --- /dev/null +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts @@ -0,0 +1,99 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils'; +import {ImageSource} from '../../../../web/graph_runner/graph_runner'; + +import {VisionTaskRunner} from './vision_task_runner'; + +class VisionTaskRunnerFake extends VisionTaskRunner { + baseOptions = new BaseOptionsProto(); + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + } + + protected override process(): void {} + + override processImageData(image: ImageSource): void { + super.processImageData(image); + } + + override processVideoData(imageFrame: ImageSource, timestamp: number): void { + super.processVideoData(imageFrame, timestamp); + } +} + +describe('VisionTaskRunner', () => { + const streamMode = { + modelAsset: undefined, + useStreamMode: true, + acceleration: undefined, + }; + + const imageMode = { + modelAsset: undefined, + useStreamMode: false, + acceleration: undefined, + }; + + let visionTaskRunner: VisionTaskRunnerFake; + + beforeEach(() => { + visionTaskRunner = new VisionTaskRunnerFake(); + }); + + it('can enable image mode', async () => { + await visionTaskRunner.setOptions({runningMode: 'image'}); + expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode); + }); + + it('can enable video mode', async () => { + await visionTaskRunner.setOptions({runningMode: 'video'}); + expect(visionTaskRunner.baseOptions.toObject()).toEqual(streamMode); + }); + + it('can clear running mode', async 
() => { + await visionTaskRunner.setOptions({runningMode: 'video'}); + + // Clear running mode + await visionTaskRunner.setOptions({runningMode: undefined}); + expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode); + }); + + it('cannot process images with video mode', async () => { + await visionTaskRunner.setOptions({runningMode: 'video'}); + expect(() => { + visionTaskRunner.processImageData({} as HTMLImageElement); + }).toThrowError(/Task is not initialized with image mode./); + }); + + it('cannot process video with image mode', async () => { + // Use default for `useStreamMode` + expect(() => { + visionTaskRunner.processVideoData({} as HTMLImageElement, 42); + }).toThrowError(/Task is not initialized with video mode./); + + // Explicitly set to image mode + await visionTaskRunner.setOptions({runningMode: 'image'}); + expect(() => { + visionTaskRunner.processVideoData({} as HTMLImageElement, 42); + }).toThrowError(/Task is not initialized with video mode./); + }); +}); diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD index 6e2e56196..aa2f9c366 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -4,6 +4,7 @@ # the detection results for one or more gesture categories, using Gesture Recognizer. 
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -52,3 +53,27 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) + +mediapipe_ts_library( + name = "gesture_recognizer_test_lib", + testonly = True, + srcs = [ + "gesture_recognizer_test.ts", + ], + deps = [ + ":gesture_recognizer", + ":gesture_recognizer_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/framework/formats:landmark_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "gesture_recognizer_test", + tags = ["nomsan"], + deps = [":gesture_recognizer_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts new file mode 100644 index 000000000..c0f0d1554 --- /dev/null +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts @@ -0,0 +1,307 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +import 'jasmine'; + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {Landmark, LandmarkList, NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; +import {GraphRunnerImageLib} from '../../../../tasks/web/core/task_runner'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {GestureRecognizer, GestureRecognizerOptions} from './gesture_recognizer'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +type ProtoListener = ((binaryProtos: Uint8Array[]) => void); + +function createHandednesses(): Uint8Array[] { + const handsProto = new ClassificationList(); + const classification = new Classification(); + classification.setScore(0.1); + classification.setIndex(1); + classification.setLabel('handedness_label'); + classification.setDisplayName('handedness_display_name'); + handsProto.addClassification(classification); + return [handsProto.serializeBinary()]; +} + +function createGestures(): Uint8Array[] { + const gesturesProto = new ClassificationList(); + const classification = new Classification(); + classification.setScore(0.2); + classification.setIndex(2); + classification.setLabel('gesture_label'); + classification.setDisplayName('gesture_display_name'); + gesturesProto.addClassification(classification); + return [gesturesProto.serializeBinary()]; +} + +function createLandmarks(): Uint8Array[] { + const handLandmarksProto = new NormalizedLandmarkList(); + const landmark = new NormalizedLandmark(); + landmark.setX(0.3); + landmark.setY(0.4); + landmark.setZ(0.5); + handLandmarksProto.addLandmark(landmark); + return [handLandmarksProto.serializeBinary()]; +} + +function 
createWorldLandmarks(): Uint8Array[] { + const handLandmarksProto = new LandmarkList(); + const landmark = new Landmark(); + landmark.setX(21); + landmark.setY(22); + landmark.setZ(23); + handLandmarksProto.addLandmark(landmark); + return [handLandmarksProto.serializeBinary()]; +} + +class GestureRecognizerFake extends GestureRecognizer implements + MediapipeTasksFake { + calculatorName = + 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + fakeWasmModule: SpyWasmModule; + listeners = new Map(); + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toMatch( + /(hand_landmarks|world_hand_landmarks|handedness|hand_gestures)/); + this.listeners.set(stream, listener); + }); + + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + spyOn(this.graphRunner, 'addProtoToStream'); + } + + getGraphRunner(): GraphRunnerImageLib { + return this.graphRunner; + } +} + +describe('GestureRecognizer', () => { + let gestureRecognizer: GestureRecognizerFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + gestureRecognizer = new GestureRecognizerFake(); + await gestureRecognizer.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(gestureRecognizer); + verifyListenersRegistered(gestureRecognizer); + }); + + it('reloads graph when settings are changed', async () => { + await gestureRecognizer.setOptions({numHands: 1}); + verifyGraph(gestureRecognizer, [ + ['handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'], 1 + 
]); + verifyListenersRegistered(gestureRecognizer); + + await gestureRecognizer.setOptions({numHands: 5}); + verifyGraph(gestureRecognizer, [ + ['handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'], 5 + ]); + verifyListenersRegistered(gestureRecognizer); + }); + + it('merges options', async () => { + await gestureRecognizer.setOptions({numHands: 1}); + await gestureRecognizer.setOptions({minHandDetectionConfidence: 0.5}); + verifyGraph(gestureRecognizer, [ + ['handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'], 1 + ]); + verifyGraph(gestureRecognizer, [ + [ + 'handLandmarkerGraphOptions', 'handDetectorGraphOptions', + 'minDetectionConfidence' + ], + 0.5 + ]); + }); + + describe('setOptions() ', () => { + interface TestCase { + optionPath: [keyof GestureRecognizerOptions, ...string[]]; + fieldPath: string[]; + customValue: unknown; + defaultValue: unknown; + } + + const testCases: TestCase[] = [ + { + optionPath: ['numHands'], + fieldPath: [ + 'handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands' + ], + customValue: 5, + defaultValue: 1 + }, + { + optionPath: ['minHandDetectionConfidence'], + fieldPath: [ + 'handLandmarkerGraphOptions', 'handDetectorGraphOptions', + 'minDetectionConfidence' + ], + customValue: 0.1, + defaultValue: 0.5 + }, + { + optionPath: ['minHandPresenceConfidence'], + fieldPath: [ + 'handLandmarkerGraphOptions', 'handLandmarksDetectorGraphOptions', + 'minDetectionConfidence' + ], + customValue: 0.2, + defaultValue: 0.5 + }, + { + optionPath: ['minTrackingConfidence'], + fieldPath: ['handLandmarkerGraphOptions', 'minTrackingConfidence'], + customValue: 0.3, + defaultValue: 0.5 + }, + { + optionPath: ['cannedGesturesClassifierOptions', 'scoreThreshold'], + fieldPath: [ + 'handGestureRecognizerGraphOptions', + 'cannedGestureClassifierGraphOptions', 'classifierOptions', + 'scoreThreshold' + ], + customValue: 0.4, + defaultValue: undefined + }, + { + optionPath: ['customGesturesClassifierOptions', 
'scoreThreshold'], + fieldPath: [ + 'handGestureRecognizerGraphOptions', + 'customGestureClassifierGraphOptions', 'classifierOptions', + 'scoreThreshold' + ], + customValue: 0.5, + defaultValue: undefined, + }, + ]; + + /** Creates an options object that can be passed to setOptions() */ + function createOptions( + path: string[], value: unknown): GestureRecognizerOptions { + const options: Record = {}; + let currentLevel = options; + for (const element of path.slice(0, -1)) { + currentLevel[element] = {}; + currentLevel = currentLevel[element] as Record; + } + currentLevel[path[path.length - 1]] = value; + return options; + } + + for (const testCase of testCases) { + it(`uses default value for ${testCase.optionPath[0]}`, async () => { + verifyGraph( + gestureRecognizer, [testCase.fieldPath, testCase.defaultValue]); + }); + + it(`can set ${testCase.optionPath[0]}`, async () => { + await gestureRecognizer.setOptions( + createOptions(testCase.optionPath, testCase.customValue)); + verifyGraph( + gestureRecognizer, [testCase.fieldPath, testCase.customValue]); + }); + + it(`can clear ${testCase.optionPath[0]}`, async () => { + await gestureRecognizer.setOptions( + createOptions(testCase.optionPath, testCase.customValue)); + verifyGraph( + gestureRecognizer, [testCase.fieldPath, testCase.customValue]); + + await gestureRecognizer.setOptions( + createOptions(testCase.optionPath, undefined)); + verifyGraph( + gestureRecognizer, [testCase.fieldPath, testCase.defaultValue]); + }); + } + }); + + it('transforms results', async () => { + // Pass the test data to our listener + gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(gestureRecognizer); + gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks()); + gestureRecognizer.listeners.get('world_hand_landmarks')! 
+ (createWorldLandmarks()); + gestureRecognizer.listeners.get('handedness')!(createHandednesses()); + gestureRecognizer.listeners.get('hand_gestures')!(createGestures()); + }); + + // Invoke the gesture recognizer + const gestures = gestureRecognizer.recognize({} as HTMLImageElement); + expect(gestureRecognizer.getGraphRunner().addProtoToStream) + .toHaveBeenCalledTimes(1); + expect(gestureRecognizer.getGraphRunner().addGpuBufferAsImageToStream) + .toHaveBeenCalledTimes(1); + expect(gestureRecognizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + + expect(gestures).toEqual({ + 'gestures': [[{ + 'score': 0.2, + 'index': 2, + 'categoryName': 'gesture_label', + 'displayName': 'gesture_display_name' + }]], + 'landmarks': [[{'x': 0.3, 'y': 0.4, 'z': 0.5}]], + 'worldLandmarks': [[{'x': 21, 'y': 22, 'z': 23}]], + 'handednesses': [[{ + 'score': 0.1, + 'index': 1, + 'categoryName': 'handedness_label', + 'displayName': 'handedness_display_name' + }]] + }); + }); + + it('clears results between invocations', async () => { + // Pass the test data to our listener + gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { + gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks()); + gestureRecognizer.listeners.get('world_hand_landmarks')! + (createWorldLandmarks()); + gestureRecognizer.listeners.get('handedness')!(createHandednesses()); + gestureRecognizer.listeners.get('hand_gestures')!(createGestures()); + }); + + // Invoke the gesture recognizer twice + const gestures1 = gestureRecognizer.recognize({} as HTMLImageElement); + const gestures2 = gestureRecognizer.recognize({} as HTMLImageElement); + + // Verify that gestures2 is not a concatenation of all previously returned + // gestures.
+ expect(gestures2).toEqual(gestures1); + }); +}); diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD index 520898e34..d1f1e48f3 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD @@ -4,6 +4,7 @@ # the detection results for one or more hand categories, using Hand Landmarker. load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -47,3 +48,27 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) + +mediapipe_ts_library( + name = "hand_landmarker_test_lib", + testonly = True, + srcs = [ + "hand_landmarker_test.ts", + ], + deps = [ + ":hand_landmarker", + ":hand_landmarker_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/framework/formats:landmark_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "hand_landmarker_test", + tags = ["nomsan"], + deps = [":hand_landmarker_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts new file mode 100644 index 000000000..fc26680e0 --- /dev/null +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts @@ -0,0 +1,251 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import 'jasmine'; + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {Landmark, LandmarkList, NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; +import {GraphRunnerImageLib} from '../../../../tasks/web/core/task_runner'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {HandLandmarker} from './hand_landmarker'; +import {HandLandmarkerOptions} from './hand_landmarker_options'; + +// The OSS JS API does not support the builder pattern. 
+// tslint:disable:jspb-use-builder-pattern + +type ProtoListener = ((binaryProtos: Uint8Array[]) => void); + +function createHandednesses(): Uint8Array[] { + const handsProto = new ClassificationList(); + const classification = new Classification(); + classification.setScore(0.1); + classification.setIndex(1); + classification.setLabel('handedness_label'); + classification.setDisplayName('handedness_display_name'); + handsProto.addClassification(classification); + return [handsProto.serializeBinary()]; +} + +function createLandmarks(): Uint8Array[] { + const handLandmarksProto = new NormalizedLandmarkList(); + const landmark = new NormalizedLandmark(); + landmark.setX(0.3); + landmark.setY(0.4); + landmark.setZ(0.5); + handLandmarksProto.addLandmark(landmark); + return [handLandmarksProto.serializeBinary()]; +} + +function createWorldLandmarks(): Uint8Array[] { + const handLandmarksProto = new LandmarkList(); + const landmark = new Landmark(); + landmark.setX(21); + landmark.setY(22); + landmark.setZ(23); + handLandmarksProto.addLandmark(landmark); + return [handLandmarksProto.serializeBinary()]; +} + +class HandLandmarkerFake extends HandLandmarker implements MediapipeTasksFake { + calculatorName = 'mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + fakeWasmModule: SpyWasmModule; + listeners = new Map(); + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toMatch( + /(hand_landmarks|world_hand_landmarks|handedness|hand_hands)/); + this.listeners.set(stream, listener); + }); + + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + 
spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + spyOn(this.graphRunner, 'addProtoToStream'); + } + + getGraphRunner(): GraphRunnerImageLib { + return this.graphRunner; + } +} + +describe('HandLandmarker', () => { + let handLandmarker: HandLandmarkerFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + handLandmarker = new HandLandmarkerFake(); + await handLandmarker.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(handLandmarker); + verifyListenersRegistered(handLandmarker); + }); + + it('reloads graph when settings are changed', async () => { + verifyListenersRegistered(handLandmarker); + + await handLandmarker.setOptions({numHands: 1}); + verifyGraph(handLandmarker, [['handDetectorGraphOptions', 'numHands'], 1]); + verifyListenersRegistered(handLandmarker); + + await handLandmarker.setOptions({numHands: 5}); + verifyGraph(handLandmarker, [['handDetectorGraphOptions', 'numHands'], 5]); + verifyListenersRegistered(handLandmarker); + }); + + it('merges options', async () => { + await handLandmarker.setOptions({numHands: 1}); + await handLandmarker.setOptions({minHandDetectionConfidence: 0.5}); + verifyGraph(handLandmarker, [ + 'handDetectorGraphOptions', + {numHands: 1, baseOptions: undefined, minDetectionConfidence: 0.5} + ]); + }); + + describe('setOptions() ', () => { + interface TestCase { + optionPath: [keyof HandLandmarkerOptions, ...string[]]; + fieldPath: string[]; + customValue: unknown; + defaultValue: unknown; + } + + const testCases: TestCase[] = [ + { + optionPath: ['numHands'], + fieldPath: ['handDetectorGraphOptions', 'numHands'], + customValue: 5, + defaultValue: 1 + }, + { + optionPath: ['minHandDetectionConfidence'], + fieldPath: ['handDetectorGraphOptions', 'minDetectionConfidence'], + customValue: 0.1, + defaultValue: 0.5 + }, + { + optionPath: ['minHandPresenceConfidence'], + fieldPath: + ['handLandmarksDetectorGraphOptions', 'minDetectionConfidence'], + 
customValue: 0.2, + defaultValue: 0.5 + }, + { + optionPath: ['minTrackingConfidence'], + fieldPath: ['minTrackingConfidence'], + customValue: 0.3, + defaultValue: 0.5 + }, + ]; + + /** Creates an options object that can be passed to setOptions() */ + function createOptions( + path: string[], value: unknown): HandLandmarkerOptions { + const options: Record = {}; + let currentLevel = options; + for (const element of path.slice(0, -1)) { + currentLevel[element] = {}; + currentLevel = currentLevel[element] as Record; + } + currentLevel[path[path.length - 1]] = value; + return options; + } + + for (const testCase of testCases) { + it(`uses default value for ${testCase.optionPath[0]}`, async () => { + verifyGraph( + handLandmarker, [testCase.fieldPath, testCase.defaultValue]); + }); + + it(`can set ${testCase.optionPath[0]}`, async () => { + await handLandmarker.setOptions( + createOptions(testCase.optionPath, testCase.customValue)); + verifyGraph(handLandmarker, [testCase.fieldPath, testCase.customValue]); + }); + + it(`can clear ${testCase.optionPath[0]}`, async () => { + await handLandmarker.setOptions( + createOptions(testCase.optionPath, testCase.customValue)); + verifyGraph(handLandmarker, [testCase.fieldPath, testCase.customValue]); + + await handLandmarker.setOptions( + createOptions(testCase.optionPath, undefined)); + verifyGraph( + handLandmarker, [testCase.fieldPath, testCase.defaultValue]); + }); + } + }); + + it('transforms results', async () => { + // Pass the test data to our listener + handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(handLandmarker); + handLandmarker.listeners.get('hand_landmarks')!(createLandmarks()); + handLandmarker.listeners.get('world_hand_landmarks')! 
+ (createWorldLandmarks()); + handLandmarker.listeners.get('handedness')!(createHandednesses()); + }); + + // Invoke the hand landmarker + const landmarks = handLandmarker.detect({} as HTMLImageElement); + expect(handLandmarker.getGraphRunner().addProtoToStream) + .toHaveBeenCalledTimes(1); + expect(handLandmarker.getGraphRunner().addGpuBufferAsImageToStream) + .toHaveBeenCalledTimes(1); + expect(handLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + + expect(landmarks).toEqual({ + 'landmarks': [[{'x': 0.3, 'y': 0.4, 'z': 0.5}]], + 'worldLandmarks': [[{'x': 21, 'y': 22, 'z': 23}]], + 'handednesses': [[{ + 'score': 0.1, + 'index': 1, + 'categoryName': 'handedness_label', + 'displayName': 'handedness_display_name' + }]] + }); + }); + + it('clears results between invocations', async () => { + // Pass the test data to our listener + handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => { + handLandmarker.listeners.get('hand_landmarks')!(createLandmarks()); + handLandmarker.listeners.get('world_hand_landmarks')! + (createWorldLandmarks()); + handLandmarker.listeners.get('handedness')!(createHandednesses()); + }); + + // Invoke the hand landmarker twice + const landmarks1 = handLandmarker.detect({} as HTMLImageElement); + const landmarks2 = handLandmarker.detect({} as HTMLImageElement); + + // Verify that landmarks2 is not a concatenation of all previously returned + // landmarks. + expect(landmarks1).toEqual(landmarks2); + }); +}); diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD index 848c162ae..310575964 100644 --- a/mediapipe/tasks/web/vision/image_classifier/BUILD +++ b/mediapipe/tasks/web/vision/image_classifier/BUILD @@ -3,6 +3,7 @@ # This task takes video or image frames and outputs the classification result.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -44,3 +45,26 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) + +mediapipe_ts_library( + name = "image_classifier_test_lib", + testonly = True, + srcs = [ + "image_classifier_test.ts", + ], + deps = [ + ":image_classifier", + ":image_classifier_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "image_classifier_test", + tags = ["nomsan"], + deps = [":image_classifier_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts new file mode 100644 index 000000000..2041a0cef --- /dev/null +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts @@ -0,0 +1,150 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {ImageClassifier} from './image_classifier'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +class ImageClassifierFake extends ImageClassifier implements + MediapipeTasksFake { + calculatorName = + 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + + fakeWasmModule: SpyWasmModule; + protoListener: ((binaryProto: Uint8Array) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('classifications'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + } +} + +describe('ImageClassifier', () => { + let imageClassifier: ImageClassifierFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + imageClassifier = new ImageClassifierFake(); + await imageClassifier.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + 
verifyGraph(imageClassifier); + verifyListenersRegistered(imageClassifier); + }); + + it('reloads graph when settings are changed', async () => { + await imageClassifier.setOptions({maxResults: 1}); + verifyGraph(imageClassifier, [['classifierOptions', 'maxResults'], 1]); + verifyListenersRegistered(imageClassifier); + + await imageClassifier.setOptions({maxResults: 5}); + verifyGraph(imageClassifier, [['classifierOptions', 'maxResults'], 5]); + verifyListenersRegistered(imageClassifier); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await imageClassifier.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + imageClassifier, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */[ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('merges options', async () => { + await imageClassifier.setOptions({maxResults: 1}); + await imageClassifier.setOptions({displayNamesLocale: 'en'}); + verifyGraph(imageClassifier, [['classifierOptions', 'maxResults'], 1]); + verifyGraph( + imageClassifier, [['classifierOptions', 'displayNamesLocale'], 'en']); + }); + + it('transforms results', async () => { + const classificationResult = new ClassificationResult(); + const classifcations = new Classifications(); + classifcations.setHeadIndex(1); + classifcations.setHeadName('headName'); + const classificationList = new ClassificationList(); + const clasification = new Classification(); + clasification.setIndex(1); + clasification.setScore(0.2); + clasification.setDisplayName('displayName'); + clasification.setLabel('categoryName'); + classificationList.addClassification(clasification); + classifcations.setClassificationList(classificationList); + classificationResult.addClassifications(classifcations); + + 
// Pass the test data to our listener + imageClassifier.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(imageClassifier); + imageClassifier.protoListener!(classificationResult.serializeBinary()); + }); + + // Invoke the image classifier + const result = imageClassifier.classify({} as HTMLImageElement); + + expect(imageClassifier.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(result).toEqual({ + classifications: [{ + categories: [{ + index: 1, + score: 0.2, + displayName: 'displayName', + categoryName: 'categoryName' + }], + headIndex: 1, + headName: 'headName' + }] + }); + }); +}); diff --git a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD index 6c9d80fb1..de4785e6c 100644 --- a/mediapipe/tasks/web/vision/image_embedder/BUILD +++ b/mediapipe/tasks/web/vision/image_embedder/BUILD @@ -3,6 +3,7 @@ # This task performs embedding extraction on images. load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -45,3 +46,23 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) + +mediapipe_ts_library( + name = "image_embedder_test_lib", + testonly = True, + srcs = [ + "image_embedder_test.ts", + ], + deps = [ + ":image_embedder", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "image_embedder_test", + deps = [":image_embedder_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts new file mode 100644 index 000000000..cafe0f3d8 --- /dev/null +++ 
b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts @@ -0,0 +1,158 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Embedding, EmbeddingResult, FloatEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {ImageEmbedder} from './image_embedder'; + +// The OSS JS API does not support the builder pattern. 
+// tslint:disable:jspb-use-builder-pattern + +class ImageEmbedderFake extends ImageEmbedder implements MediapipeTasksFake { + calculatorName = 'mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph'; + graph: CalculatorGraphConfig|undefined; + attachListenerSpies: jasmine.Spy[] = []; + fakeWasmModule: SpyWasmModule; + protoListener: ((binaryProtos: Uint8Array) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('embeddings_out'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + } +} + +describe('ImageEmbedder', () => { + let imageEmbedder: ImageEmbedderFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + imageEmbedder = new ImageEmbedderFake(); + await imageEmbedder.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(imageEmbedder); + verifyListenersRegistered(imageEmbedder); + }); + + it('reloads graph when settings are changed', async () => { + verifyListenersRegistered(imageEmbedder); + + await imageEmbedder.setOptions({quantize: true}); + verifyGraph(imageEmbedder, [['embedderOptions', 'quantize'], true]); + verifyListenersRegistered(imageEmbedder); + + await imageEmbedder.setOptions({quantize: undefined}); + verifyGraph(imageEmbedder, [['embedderOptions', 'quantize'], undefined]); + verifyListenersRegistered(imageEmbedder); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await imageEmbedder.setOptions({ + 
baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + imageEmbedder, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */[ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('overrides options', async () => { + await imageEmbedder.setOptions({quantize: true}); + await imageEmbedder.setOptions({l2Normalize: true}); + verifyGraph( + imageEmbedder, + ['embedderOptions', {'quantize': true, 'l2Normalize': true}]); + }); + + describe('transforms result', () => { + beforeEach(() => { + const floatEmbedding = new FloatEmbedding(); + floatEmbedding.setValuesList([0.1, 0.9]); + + const embedding = new Embedding(); + embedding.setHeadIndex(1); + embedding.setHeadName('headName'); + embedding.setFloatEmbedding(floatEmbedding); + + const resultProto = new EmbeddingResult(); + resultProto.addEmbeddings(embedding); + resultProto.setTimestampMs(42); + + // Pass the test data to our listener + imageEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(imageEmbedder); + imageEmbedder.protoListener!(resultProto.serializeBinary()); + }); + }); + + it('for image mode', async () => { + // Invoke the image embedder + const embeddingResult = imageEmbedder.embed({} as HTMLImageElement); + + expect(imageEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(embeddingResult).toEqual({ + embeddings: + [{headIndex: 1, headName: 'headName', floatEmbedding: [0.1, 0.9]}], + timestampMs: 42 + }); + }); + + it('for video mode', async () => { + await imageEmbedder.setOptions({runningMode: 'video'}); + + // Invoke the video embedder + const embeddingResult = + imageEmbedder.embedForVideo({} as HTMLImageElement, 42); + + expect(imageEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(embeddingResult).toEqual({ + embeddings: + [{headIndex: 1, headName: 'headName', floatEmbedding: [0.1, 0.9]}], 
+ timestampMs: 42 + }); + }); + }); +}); diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index f73790895..fc206a2d7 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -4,6 +4,7 @@ # the detection results for one or more object categories, using Object Detector. load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -41,3 +42,26 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) + +mediapipe_ts_library( + name = "object_detector_test_lib", + testonly = True, + srcs = [ + "object_detector_test.ts", + ], + deps = [ + ":object_detector", + ":object_detector_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:detection_jspb_proto", + "//mediapipe/framework/formats:location_data_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "object_detector_test", + tags = ["nomsan"], + deps = [":object_detector_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts new file mode 100644 index 000000000..fff1a1c48 --- /dev/null +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts @@ -0,0 +1,229 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb'; +import {LocationData} from '../../../../framework/formats/location_data_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {ObjectDetector} from './object_detector'; +import {ObjectDetectorOptions} from './object_detector_options'; + +// The OSS JS API does not support the builder pattern. 
+// tslint:disable:jspb-use-builder-pattern + +class ObjectDetectorFake extends ObjectDetector implements MediapipeTasksFake { + lastSampleRate: number|undefined; + calculatorName = 'mediapipe.tasks.vision.ObjectDetectorGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + + fakeWasmModule: SpyWasmModule; + protoListener: ((binaryProtos: Uint8Array[]) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('detections'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + } +} + +describe('ObjectDetector', () => { + let objectDetector: ObjectDetectorFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + objectDetector = new ObjectDetectorFake(); + await objectDetector.setOptions({}); // Initialize graph + }); + + it('initializes graph', async () => { + verifyGraph(objectDetector); + verifyListenersRegistered(objectDetector); + }); + + it('reloads graph when settings are changed', async () => { + await objectDetector.setOptions({maxResults: 1}); + verifyGraph(objectDetector, ['maxResults', 1]); + verifyListenersRegistered(objectDetector); + + await objectDetector.setOptions({maxResults: 5}); + verifyGraph(objectDetector, ['maxResults', 5]); + verifyListenersRegistered(objectDetector); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await objectDetector.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + 
verifyGraph( + objectDetector, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */ + [ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('merges options', async () => { + await objectDetector.setOptions({maxResults: 1}); + await objectDetector.setOptions({displayNamesLocale: 'en'}); + verifyGraph(objectDetector, ['maxResults', 1]); + verifyGraph(objectDetector, ['displayNamesLocale', 'en']); + }); + + describe('setOptions() ', () => { + interface TestCase { + optionName: keyof ObjectDetectorOptions; + protoName: string; + customValue: unknown; + defaultValue: unknown; + } + + const testCases: TestCase[] = [ + { + optionName: 'maxResults', + protoName: 'maxResults', + customValue: 5, + defaultValue: -1 + }, + { + optionName: 'displayNamesLocale', + protoName: 'displayNamesLocale', + customValue: 'en', + defaultValue: 'en' + }, + { + optionName: 'scoreThreshold', + protoName: 'scoreThreshold', + customValue: 0.1, + defaultValue: undefined + }, + { + optionName: 'categoryAllowlist', + protoName: 'categoryAllowlistList', + customValue: ['foo'], + defaultValue: [] + }, + { + optionName: 'categoryDenylist', + protoName: 'categoryDenylistList', + customValue: ['bar'], + defaultValue: [] + }, + ]; + + for (const testCase of testCases) { + it(`can set ${testCase.optionName}`, async () => { + await objectDetector.setOptions( + {[testCase.optionName]: testCase.customValue}); + verifyGraph(objectDetector, [testCase.protoName, testCase.customValue]); + }); + + it(`can clear ${testCase.optionName}`, async () => { + await objectDetector.setOptions( + {[testCase.optionName]: testCase.customValue}); + verifyGraph(objectDetector, [testCase.protoName, testCase.customValue]); + await objectDetector.setOptions({[testCase.optionName]: undefined}); + verifyGraph( + objectDetector, [testCase.protoName, testCase.defaultValue]); + }); + } + }); + + it('transforms 
results', async () => { + const detectionProtos: Uint8Array[] = []; + + // Add a detection with all optional properties + let detection = new DetectionProto(); + detection.addScore(0.1); + detection.addLabelId(1); + detection.addLabel('foo'); + detection.addDisplayName('bar'); + let locationData = new LocationData(); + let boundingBox = new LocationData.BoundingBox(); + boundingBox.setXmin(1); + boundingBox.setYmin(2); + boundingBox.setWidth(3); + boundingBox.setHeight(4); + locationData.setBoundingBox(boundingBox); + detection.setLocationData(locationData); + detectionProtos.push(detection.serializeBinary()); + + // Add a detection without optional properties + detection = new DetectionProto(); + detection.addScore(0.2); + locationData = new LocationData(); + boundingBox = new LocationData.BoundingBox(); + locationData.setBoundingBox(boundingBox); + detection.setLocationData(locationData); + detectionProtos.push(detection.serializeBinary()); + + // Pass the test data to our listener + objectDetector.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(objectDetector); + objectDetector.protoListener!(detectionProtos); + }); + + // Invoke the object detector + const detections = objectDetector.detect({} as HTMLImageElement); + + expect(objectDetector.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(detections.length).toEqual(2); + expect(detections[0]).toEqual({ + categories: [{ + score: 0.1, + index: 1, + categoryName: 'foo', + displayName: 'bar', + }], + boundingBox: {originX: 1, originY: 2, width: 3, height: 4} + }); + expect(detections[1]).toEqual({ + categories: [{ + score: 0.2, + index: -1, + categoryName: '', + displayName: '', + }], + boundingBox: {originX: 0, originY: 0, width: 0, height: 0} + }); + }); +}); From d1820320b15893a0f0b947ed208bdcfb630bb938 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 8 Dec 2022 10:23:53 +0530 Subject: [PATCH 40/64] Added base options --- mediapipe/tasks/ios/core/BUILD | 33 
++++++++++++ .../tasks/ios/core/sources/MPPBaseOptions.h | 51 +++++++++++++++++++ .../tasks/ios/core/sources/MPPBaseOptions.m | 36 +++++++++++++ .../tasks/ios/core/sources/MPPExternalFile.h | 28 ++++++++++ .../tasks/ios/core/sources/MPPExternalFile.m | 27 ++++++++++ 5 files changed, 175 insertions(+) create mode 100644 mediapipe/tasks/ios/core/BUILD create mode 100644 mediapipe/tasks/ios/core/sources/MPPBaseOptions.h create mode 100644 mediapipe/tasks/ios/core/sources/MPPBaseOptions.m create mode 100644 mediapipe/tasks/ios/core/sources/MPPExternalFile.h create mode 100644 mediapipe/tasks/ios/core/sources/MPPExternalFile.m diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD new file mode 100644 index 000000000..9b8ad7bec --- /dev/null +++ b/mediapipe/tasks/ios/core/BUILD @@ -0,0 +1,33 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPExternalFile", + srcs = ["sources/MPPExternalFile.m"], + hdrs = ["sources/MPPExternalFile.h"], +) + +objc_library( + name = "MPPBaseOptions", + srcs = ["sources/MPPBaseOptions.m"], + hdrs = ["sources/MPPBaseOptions.h"], + deps = [ + ":MPPExternalFile", + + ], +) diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h new file mode 100644 index 000000000..87b6826df --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h @@ -0,0 +1,51 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import +#import "mediapipe/tasks/ios/core/sources/MPPExternalFile.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * MediaPipe Tasks delegate. + */ +typedef NS_ENUM(NSUInteger, MPPDelegate) { + + /** CPU. */ + MPPDelegateCPU, + + /** GPU. */ + MPPDelegateGPU +} NS_SWIFT_NAME(Delegate); + +/** + * Holds the base options that are used for the creation of any type of task. It has fields with + * important information such as acceleration configuration, TFLite model source, etc. + */ +NS_SWIFT_NAME(BaseOptions) +@interface MPPBaseOptions : NSObject + +/** + * The external model file, as a single standalone TFLite file. It could be packed with TFLite Model + * Metadata[1] and associated files if they exist.
Failure to provide the necessary metadata and associated + * files might result in errors. + */ +@property(nonatomic, copy) MPPExternalFile *modelAssetFile; + +/** + * Device delegate to run the MediaPipe pipeline. If the delegate is not set, the default + * delegate CPU is used. + */ +@property(nonatomic) MPPDelegate delegate; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m new file mode 100644 index 000000000..4c25b80e8 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m @@ -0,0 +1,36 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License.
+ ==============================================================================*/ +#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h" + +@implementation MPPBaseOptions + +- (instancetype)init { + self = [super init]; + if (self) { + self.modelAssetFile = [[MPPExternalFile alloc] init]; + } + return self; +} + +- (id)copyWithZone:(NSZone *)zone { + MPPBaseOptions *baseOptions = [[MPPBaseOptions alloc] init]; + + baseOptions.modelAssetFile = self.modelAssetFile; + baseOptions.delegate = self.delegate; + + return baseOptions; +} + +@end diff --git a/mediapipe/tasks/ios/core/sources/MPPExternalFile.h b/mediapipe/tasks/ios/core/sources/MPPExternalFile.h new file mode 100644 index 000000000..a97802002 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPExternalFile.h @@ -0,0 +1,28 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import + +NS_ASSUME_NONNULL_BEGIN + +/** + * Holds information about an external file. + */ +NS_SWIFT_NAME(ExternalFile) +@interface MPPExternalFile : NSObject + +/** Path to the file in bundle. */ +@property(nonatomic, copy) NSString *filePath; +/// Add provision for other sources in future. 
+ +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPExternalFile.m b/mediapipe/tasks/ios/core/sources/MPPExternalFile.m new file mode 100644 index 000000000..70d85657c --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPExternalFile.m @@ -0,0 +1,27 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "mediapipe/tasks/ios/core/sources/MPPExternalFile.h" + +@implementation MPPExternalFile + +- (id)copyWithZone:(NSZone *)zone { + MPPExternalFile *externalFile = [[MPPExternalFile alloc] init]; + + externalFile.filePath = self.filePath; + + return externalFile; +} + +@end From 66dbd9969a0aae7f71ce7096135a3b436ea76473 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 8 Dec 2022 10:25:01 +0530 Subject: [PATCH 41/64] Updated license text --- .../tasks/ios/core/sources/MPPBaseOptions.h | 25 +++++++++++-------- .../tasks/ios/core/sources/MPPExternalFile.h | 25 +++++++++++-------- 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h index 87b6826df..258b49b3b 100644 --- a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h +++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h @@ -1,14 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
- Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #import #import "mediapipe/tasks/ios/core/sources/MPPExternalFile.h" diff --git a/mediapipe/tasks/ios/core/sources/MPPExternalFile.h b/mediapipe/tasks/ios/core/sources/MPPExternalFile.h index a97802002..300fd4778 100644 --- a/mediapipe/tasks/ios/core/sources/MPPExternalFile.h +++ b/mediapipe/tasks/ios/core/sources/MPPExternalFile.h @@ -1,14 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #import NS_ASSUME_NONNULL_BEGIN From 13f8fa51393a2883ec825ad717bfffb693d59376 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 8 Dec 2022 07:59:46 -0800 Subject: [PATCH 42/64] Retire the visibility group "//mediapipe/framework:mediapipe_internal" in the "mediapipe/calculators/tensor" dir. 
PiperOrigin-RevId: 493895834 --- mediapipe/calculators/tensor/BUILD | 76 ++++-------------------------- 1 file changed, 8 insertions(+), 68 deletions(-) diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 577ac4111..dec68deac 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -24,7 +24,7 @@ load("//mediapipe/framework:encode_binary_proto.bzl", "encode_binary_proto") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) exports_files( glob(["testdata/image_to_tensor/*"]), @@ -44,9 +44,6 @@ selects.config_setting_group( mediapipe_proto_library( name = "audio_to_tensor_calculator_proto", srcs = ["audio_to_tensor_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -64,9 +61,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":audio_to_tensor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -113,9 +107,6 @@ cc_test( mediapipe_proto_library( name = "tensors_to_audio_calculator_proto", srcs = ["tensors_to_audio_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -125,9 +116,6 @@ mediapipe_proto_library( cc_library( name = "tensors_to_audio_calculator", srcs = ["tensors_to_audio_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":tensors_to_audio_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -164,9 +152,6 @@ cc_test( mediapipe_proto_library( name = "feedback_tensors_calculator_proto", srcs = ["feedback_tensors_calculator.proto"], - visibility = [ - 
"//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -184,9 +169,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":feedback_tensors_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -216,9 +198,6 @@ cc_test( mediapipe_proto_library( name = "bert_preprocessor_calculator_proto", srcs = ["bert_preprocessor_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -228,9 +207,6 @@ mediapipe_proto_library( cc_library( name = "bert_preprocessor_calculator", srcs = ["bert_preprocessor_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":bert_preprocessor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -274,9 +250,6 @@ cc_test( mediapipe_proto_library( name = "regex_preprocessor_calculator_proto", srcs = ["regex_preprocessor_calculator.proto"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -286,9 +259,6 @@ mediapipe_proto_library( cc_library( name = "regex_preprocessor_calculator", srcs = ["regex_preprocessor_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ ":regex_preprocessor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -330,9 +300,6 @@ cc_test( cc_library( name = "text_to_tensor_calculator", srcs = ["text_to_tensor_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", @@ -366,9 +333,6 @@ cc_test( cc_library( name = 
"universal_sentence_encoder_preprocessor_calculator", srcs = ["universal_sentence_encoder_preprocessor_calculator.cc"], - visibility = [ - "//mediapipe/framework:mediapipe_internal", - ], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", @@ -408,7 +372,6 @@ cc_test( mediapipe_proto_library( name = "inference_calculator_proto", srcs = ["inference_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -435,7 +398,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":inference_calculator_cc_proto", ":inference_calculator_options_lib", @@ -460,7 +422,6 @@ cc_library( name = "inference_calculator_gl", srcs = ["inference_calculator_gl.cc"], tags = ["nomac"], # config problem with cpuinfo via TF - visibility = ["//visibility:public"], deps = [ ":inference_calculator_cc_proto", ":inference_calculator_interface", @@ -478,7 +439,6 @@ cc_library( name = "inference_calculator_gl_advanced", srcs = ["inference_calculator_gl_advanced.cc"], tags = ["nomac"], - visibility = ["//visibility:public"], deps = [ ":inference_calculator_interface", "@com_google_absl//absl/memory", @@ -509,7 +469,6 @@ cc_library( "-framework MetalKit", ], tags = ["ios"], - visibility = ["//visibility:public"], deps = [ "inference_calculator_interface", "//mediapipe/gpu:MPPMetalHelper", @@ -538,7 +497,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework/formats:tensor", @@ -558,7 +516,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":inference_runner", "//mediapipe/framework:mediapipe_profiling", @@ -588,7 +545,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ 
":inference_calculator_interface", ":inference_calculator_utils", @@ -635,7 +591,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":inference_calculator_interface", ":inference_calculator_utils", @@ -651,7 +606,6 @@ cc_library( cc_library( name = "inference_calculator_gl_if_compute_shader_available", - visibility = ["//visibility:public"], deps = selects.with_or({ ":compute_shader_unavailable": [], "//conditions:default": [ @@ -667,7 +621,6 @@ cc_library( # inference_calculator_interface. cc_library( name = "inference_calculator", - visibility = ["//visibility:public"], deps = [ ":inference_calculator_interface", ":inference_calculator_cpu", @@ -681,7 +634,6 @@ cc_library( mediapipe_proto_library( name = "tensor_converter_calculator_proto", srcs = ["tensor_converter_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -706,7 +658,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensor_converter_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -725,6 +676,7 @@ cc_library( cc_library( name = "tensor_converter_calculator_gpu_deps", + visibility = ["//visibility:private"], deps = select({ "//mediapipe:android": [ "//mediapipe/gpu:gl_calculator_helper", @@ -769,7 +721,6 @@ cc_test( mediapipe_proto_library( name = "tensors_to_detections_calculator_proto", srcs = ["tensors_to_detections_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -794,7 +745,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_detections_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", @@ -817,6 +767,7 @@ cc_library( cc_library( name = 
"tensors_to_detections_calculator_gpu_deps", + visibility = ["//visibility:private"], deps = select({ "//mediapipe:ios": [ "//mediapipe/gpu:MPPMetalUtil", @@ -832,7 +783,6 @@ cc_library( mediapipe_proto_library( name = "tensors_to_landmarks_calculator_proto", srcs = ["tensors_to_landmarks_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -849,7 +799,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_landmarks_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -864,7 +813,6 @@ cc_library( mediapipe_proto_library( name = "landmarks_to_tensor_calculator_proto", srcs = ["landmarks_to_tensor_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -882,7 +830,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":landmarks_to_tensor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -915,7 +862,6 @@ cc_test( mediapipe_proto_library( name = "tensors_to_floats_calculator_proto", srcs = ["tensors_to_floats_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -932,7 +878,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_floats_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -970,7 +915,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_classification_calculator_cc_proto", "@com_google_absl//absl/container:node_hash_map", @@ -1001,7 +945,6 @@ cc_library( mediapipe_proto_library( name = "tensors_to_classification_calculator_proto", srcs = 
["tensors_to_classification_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1039,7 +982,6 @@ cc_library( "//conditions:default": [], }), features = ["-layering_check"], # allow depending on image_to_tensor_calculator_gpu_deps - visibility = ["//visibility:public"], deps = [ ":image_to_tensor_calculator_cc_proto", ":image_to_tensor_converter", @@ -1068,6 +1010,7 @@ cc_library( cc_library( name = "image_to_tensor_calculator_gpu_deps", + visibility = ["//visibility:private"], deps = selects.with_or({ "//mediapipe:android": [ ":image_to_tensor_converter_gl_buffer", @@ -1091,7 +1034,6 @@ cc_library( mediapipe_proto_library( name = "image_to_tensor_calculator_proto", srcs = ["image_to_tensor_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1154,7 +1096,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":image_to_tensor_utils", "//mediapipe/framework/formats:image", @@ -1174,7 +1115,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":image_to_tensor_converter", ":image_to_tensor_utils", @@ -1194,6 +1134,7 @@ cc_library( name = "image_to_tensor_converter_gl_buffer", srcs = ["image_to_tensor_converter_gl_buffer.cc"], hdrs = ["image_to_tensor_converter_gl_buffer.h"], + visibility = ["//visibility:private"], deps = ["//mediapipe/framework:port"] + selects.with_or({ "//mediapipe:apple": [], "//conditions:default": [ @@ -1227,6 +1168,7 @@ cc_library( name = "image_to_tensor_converter_gl_texture", srcs = ["image_to_tensor_converter_gl_texture.cc"], hdrs = ["image_to_tensor_converter_gl_texture.h"], + visibility = ["//visibility:private"], deps = ["//mediapipe/framework:port"] + select({ "//mediapipe/gpu:disable_gpu": [], "//conditions:default": [ @@ 
-1251,6 +1193,7 @@ cc_library( name = "image_to_tensor_converter_gl_utils", srcs = ["image_to_tensor_converter_gl_utils.cc"], hdrs = ["image_to_tensor_converter_gl_utils.h"], + visibility = ["//visibility:private"], deps = ["//mediapipe/framework:port"] + select({ "//mediapipe/gpu:disable_gpu": [], "//conditions:default": [ @@ -1280,6 +1223,7 @@ cc_library( ], "//conditions:default": [], }), + visibility = ["//visibility:private"], deps = ["//mediapipe/framework:port"] + select({ "//mediapipe:apple": [ ":image_to_tensor_converter", @@ -1311,7 +1255,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":image_to_tensor_calculator_cc_proto", "@com_google_absl//absl/status", @@ -1354,7 +1297,6 @@ selects.config_setting_group( mediapipe_proto_library( name = "tensors_to_segmentation_calculator_proto", srcs = ["tensors_to_segmentation_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1372,7 +1314,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensors_to_segmentation_calculator_cc_proto", "@com_google_absl//absl/strings:str_format", @@ -1430,7 +1371,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", From a641ea12e15ec8e3ff552647ca569dc1ee9f59bc Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 8 Dec 2022 11:30:39 -0800 Subject: [PATCH 43/64] Update gesture recognizer to new mediapipe tasks pipeline PiperOrigin-RevId: 493950564 --- .../python/vision/gesture_recognizer/BUILD | 14 ++-- .../vision/gesture_recognizer/dataset.py | 62 ++++++++++------- .../vision/gesture_recognizer/dataset_test.py | 67 ++++++++----------- .../gesture_recognizer/metadata_writer.py | 62 +++++++++++++---- .../metadata_writer_test.py | 17 
+++++ mediapipe/tasks/python/core/BUILD | 8 +-- mediapipe/tasks/python/vision/BUILD | 4 ++ 7 files changed, 147 insertions(+), 87 deletions(-) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD index 256447a8d..9123e36b0 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD @@ -35,20 +35,21 @@ py_library( srcs = ["constants.py"], ) -# TODO: Change to py_library after migrating the MediaPipe hand solution -# library to MediaPipe hand task library. py_library( name = "dataset", srcs = ["dataset.py"], deps = [ ":constants", + ":metadata_writer", "//mediapipe/model_maker/python/core/data:classification_dataset", - "//mediapipe/model_maker/python/core/data:data_util", "//mediapipe/model_maker/python/core/utils:model_util", - "//mediapipe/python/solutions:hands", + "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/vision:hand_landmarker", ], ) +# TODO: Remove notsan tag once tasks no longer has race condition issue py_test( name = "dataset_test", srcs = ["dataset_test.py"], @@ -56,10 +57,11 @@ py_test( ":testdata", "//mediapipe/model_maker/models/gesture_recognizer:models", ], + tags = ["notsan"], deps = [ ":dataset", - "//mediapipe/python/solutions:hands", "//mediapipe/tasks/python/test:test_utils", + "//mediapipe/tasks/python/vision:hand_landmarker", ], ) @@ -131,6 +133,7 @@ py_library( ], ) +# TODO: Remove notsan tag once tasks no longer has race condition issue py_test( name = "gesture_recognizer_test", size = "large", @@ -140,6 +143,7 @@ py_test( "//mediapipe/model_maker/models/gesture_recognizer:models", ], shard_count = 2, + tags = ["notsan"], deps = [ ":gesture_recognizer_import", "//mediapipe/model_maker/python/core/utils:test_util", diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py 
b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py index 256f26fd6..6a2c878c0 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py @@ -16,16 +16,22 @@ import dataclasses import os import random -from typing import List, NamedTuple, Optional +from typing import List, Optional -import cv2 import tensorflow as tf from mediapipe.model_maker.python.core.data import classification_dataset -from mediapipe.model_maker.python.core.data import data_util from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.vision.gesture_recognizer import constants -from mediapipe.python.solutions import hands as mp_hands +from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.vision import hand_landmarker as hand_landmarker_module + +_Image = image_module.Image +_HandLandmarker = hand_landmarker_module.HandLandmarker +_HandLandmarkerOptions = hand_landmarker_module.HandLandmarkerOptions +_HandLandmarkerResult = hand_landmarker_module.HandLandmarkerResult @dataclasses.dataclass @@ -59,7 +65,7 @@ class HandData: handedness: List[float] -def _validate_data_sample(data: NamedTuple) -> bool: +def _validate_data_sample(data: _HandLandmarkerResult) -> bool: """Validates the input hand data sample. Args: @@ -70,19 +76,17 @@ def _validate_data_sample(data: NamedTuple) -> bool: 'multi_hand_landmarks' or 'multi_hand_world_landmarks' or 'multi_handedness' or any of these attributes' values are none. Otherwise, True. 
""" - if (not hasattr(data, 'multi_hand_landmarks') or - data.multi_hand_landmarks is None): + if data.hand_landmarks is None or not data.hand_landmarks: return False - if (not hasattr(data, 'multi_hand_world_landmarks') or - data.multi_hand_world_landmarks is None): + if data.hand_world_landmarks is None or not data.hand_world_landmarks: return False - if not hasattr(data, 'multi_handedness') or data.multi_handedness is None: + if data.handedness is None or not data.handedness: return False return True def _get_hand_data(all_image_paths: List[str], - min_detection_confidence: float) -> Optional[HandData]: + min_detection_confidence: float) -> List[Optional[HandData]]: """Computes hand data (landmarks and handedness) in the input image. Args: @@ -93,28 +97,36 @@ def _get_hand_data(all_image_paths: List[str], A HandData object. Returns None if no hand is detected. """ hand_data_result = [] - with mp_hands.Hands( - static_image_mode=True, - max_num_hands=1, - min_detection_confidence=min_detection_confidence) as hands: + hand_detector_model_buffer = model_util.load_tflite_model_buffer( + constants.HAND_DETECTOR_TFLITE_MODEL_FILE) + hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer( + constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE) + hand_landmarker_writer = metadata_writer.HandLandmarkerMetadataWriter( + hand_detector_model_buffer, hand_landmarks_detector_model_buffer) + hand_landmarker_options = _HandLandmarkerOptions( + base_options=base_options_module.BaseOptions( + model_asset_buffer=hand_landmarker_writer.populate()), + num_hands=1, + min_hand_detection_confidence=min_detection_confidence, + min_hand_presence_confidence=0.5, + min_tracking_confidence=1, + ) + with _HandLandmarker.create_from_options( + hand_landmarker_options) as hand_landmarker: for path in all_image_paths: tf.compat.v1.logging.info('Loading image %s', path) - image = data_util.load_image(path) - # Flip image around y-axis for correct handedness output - image = 
cv2.flip(image, 1) - data = hands.process(image) + image = _Image.create_from_file(path) + data = hand_landmarker.detect(image) if not _validate_data_sample(data): hand_data_result.append(None) continue - hand_landmarks = [[ - hand_landmark.x, hand_landmark.y, hand_landmark.z - ] for hand_landmark in data.multi_hand_landmarks[0].landmark] + hand_landmarks = [[hand_landmark.x, hand_landmark.y, hand_landmark.z] + for hand_landmark in data.hand_landmarks[0]] hand_world_landmarks = [[ hand_landmark.x, hand_landmark.y, hand_landmark.z - ] for hand_landmark in data.multi_hand_world_landmarks[0].landmark] + ] for hand_landmark in data.hand_world_landmarks[0]] handedness_scores = [ - handedness.score - for handedness in data.multi_handedness[0].classification + handedness.score for handedness in data.handedness[0] ] hand_data_result.append( HandData( diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py index 76e70a58d..528d02edd 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py @@ -12,21 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import collections import os import shutil from typing import NamedTuple import unittest -from absl import flags from absl.testing import parameterized import tensorflow as tf from mediapipe.model_maker.python.vision.gesture_recognizer import dataset -from mediapipe.python.solutions import hands as mp_hands from mediapipe.tasks.python.test import test_utils - -FLAGS = flags.FLAGS +from mediapipe.tasks.python.vision import hand_landmarker _TEST_DATA_DIRNAME = 'raw_data' @@ -39,14 +35,14 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase): dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams()) train_data, test_data = data.split(0.5) - self.assertLen(train_data, 17) + self.assertLen(train_data, 16) for _, elem in enumerate(train_data.gen_tf_dataset(is_training=True)): self.assertEqual(elem[0].shape, (1, 128)) self.assertEqual(elem[1].shape, ([1, 4])) self.assertEqual(train_data.num_classes, 4) self.assertEqual(train_data.label_names, ['none', 'call', 'four', 'rock']) - self.assertLen(test_data, 18) + self.assertLen(test_data, 16) for _, elem in enumerate(test_data.gen_tf_dataset(is_training=True)): self.assertEqual(elem[0].shape, (1, 128)) self.assertEqual(elem[1].shape, ([1, 4])) @@ -60,7 +56,7 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase): for _, elem in enumerate(data.gen_tf_dataset(is_training=True)): self.assertEqual(elem[0].shape, (1, 128)) self.assertEqual(elem[1].shape, ([1, 4])) - self.assertLen(data, 35) + self.assertLen(data, 32) self.assertEqual(data.num_classes, 4) self.assertEqual(data.label_names, ['none', 'call', 'four', 'rock']) @@ -105,51 +101,42 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase): for _, elem in enumerate(data.gen_tf_dataset(is_training=True)): self.assertEqual(elem[0].shape, (1, 128)) self.assertEqual(elem[1].shape, ([1, 4])) - self.assertLen(data, 35) + self.assertLen(data, 32) self.assertEqual(data.num_classes, 4) self.assertEqual(data.label_names, ['NONE', 'CALL', 'FOUR', 
'ROCK']) @parameterized.named_parameters( dict( - testcase_name='invalid_field_name_multi_hand_landmark', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmark', 'multi_hand_world_landmarks', - 'multi_handedness' - ])(1, 2, 3)), + testcase_name='none_handedness', + hand=hand_landmarker.HandLandmarkerResult( + handedness=None, hand_landmarks=[[2]], + hand_world_landmarks=[[3]])), dict( - testcase_name='invalid_field_name_multi_hand_world_landmarks', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmark', - 'multi_handedness' - ])(1, 2, 3)), + testcase_name='none_hand_landmarks', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[[1]], hand_landmarks=None, + hand_world_landmarks=[[3]])), dict( - testcase_name='invalid_field_name_multi_handed', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmarks', - 'multi_handed' - ])(1, 2, 3)), + testcase_name='none_hand_world_landmarks', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[[1]], hand_landmarks=[[2]], + hand_world_landmarks=None)), dict( - testcase_name='multi_hand_landmarks_is_none', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmarks', - 'multi_handedness' - ])(None, 2, 3)), + testcase_name='empty_handedness', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[], hand_landmarks=[[2]], hand_world_landmarks=[[3]])), dict( - testcase_name='multi_hand_world_landmarks_is_none', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmarks', - 'multi_handedness' - ])(1, None, 3)), + testcase_name='empty_hand_landmarks', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[[1]], hand_landmarks=[], hand_world_landmarks=[[3]])), dict( - testcase_name='multi_handedness_is_none', - hand=collections.namedtuple('Hand', [ - 'multi_hand_landmarks', 'multi_hand_world_landmarks', - 'multi_handedness' - ])(1, 2, None)), + 
testcase_name='empty_hand_world_landmarks', + hand=hand_landmarker.HandLandmarkerResult( + handedness=[[1]], hand_landmarks=[[2]], hand_world_landmarks=[])), ) def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple): with unittest.mock.patch.object( - mp_hands.Hands, 'process', return_value=hand): + hand_landmarker.HandLandmarker, 'detect', return_value=hand): input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME) with self.assertRaisesRegex(ValueError, 'No valid hand is detected'): dataset.Dataset.from_folder( diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py index 58b67e072..b2e851afe 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py @@ -62,6 +62,50 @@ def read_file(file_path: str, mode: str = "rb") -> Union[str, bytes]: return f.read() +class HandLandmarkerMetadataWriter: + """MetadataWriter to write the model asset bundle for HandLandmarker.""" + + def __init__( + self, + hand_detector_model_buffer: bytearray, + hand_landmarks_detector_model_buffer: bytearray, + ) -> None: + """Initializes HandLandmarkerMetadataWriter to write model asset bundle. + + Args: + hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from + the TFLite hand detector model file. + hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata + loaded from the TFLite hand landmarks detector model file. + """ + self._hand_detector_model_buffer = hand_detector_model_buffer + self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer + self._temp_folder = tempfile.TemporaryDirectory() + + def __del__(self): + if os.path.exists(self._temp_folder.name): + self._temp_folder.cleanup() + + def populate(self): + """Creates the model asset bundle for hand landmarker task. 
+ + Returns: + Model asset bundle in bytes + """ + landmark_models = { + _HAND_DETECTOR_TFLITE_NAME: + self._hand_detector_model_buffer, + _HAND_LANDMARKS_DETECTOR_TFLITE_NAME: + self._hand_landmarks_detector_model_buffer + } + output_hand_landmarker_path = os.path.join(self._temp_folder.name, + _HAND_LANDMARKER_BUNDLE_NAME) + writer_utils.create_model_asset_bundle(landmark_models, + output_hand_landmarker_path) + hand_landmarker_model_buffer = read_file(output_hand_landmarker_path) + return hand_landmarker_model_buffer + + class MetadataWriter: """MetadataWriter to write the metadata and the model asset bundle.""" @@ -86,8 +130,8 @@ class MetadataWriter: custom_gesture_classifier_metadata_writer: Metadata writer to write custom gesture classifier metadata into the TFLite file. """ - self._hand_detector_model_buffer = hand_detector_model_buffer - self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer + self._hand_landmarker_metadata_writer = HandLandmarkerMetadataWriter( + hand_detector_model_buffer, hand_landmarks_detector_model_buffer) self._gesture_embedder_model_buffer = gesture_embedder_model_buffer self._canned_gesture_classifier_model_buffer = canned_gesture_classifier_model_buffer self._custom_gesture_classifier_metadata_writer = custom_gesture_classifier_metadata_writer @@ -147,16 +191,8 @@ class MetadataWriter: A tuple of (model_asset_bundle_in_bytes, metadata_json_content) """ # Creates the model asset bundle for hand landmarker task. 
- landmark_models = { - _HAND_DETECTOR_TFLITE_NAME: - self._hand_detector_model_buffer, - _HAND_LANDMARKS_DETECTOR_TFLITE_NAME: - self._hand_landmarks_detector_model_buffer - } - output_hand_landmarker_path = os.path.join(self._temp_folder.name, - _HAND_LANDMARKER_BUNDLE_NAME) - writer_utils.create_model_asset_bundle(landmark_models, - output_hand_landmarker_path) + hand_landmarker_model_buffer = self._hand_landmarker_metadata_writer.populate( + ) # Write metadata into custom gesture classifier model. self._custom_gesture_classifier_model_buffer, custom_gesture_classifier_metadata_json = self._custom_gesture_classifier_metadata_writer.populate( @@ -179,7 +215,7 @@ class MetadataWriter: # graph. gesture_recognizer_models = { _HAND_LANDMARKER_BUNDLE_NAME: - read_file(output_hand_landmarker_path), + hand_landmarker_model_buffer, _HAND_GESTURE_RECOGNIZER_BUNDLE_NAME: read_file(output_hand_gesture_recognizer_path), } diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py index 83998141d..fd26b274d 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py @@ -33,6 +33,23 @@ _CUSTOM_GESTURE_CLASSIFIER_PATH = test_utils.get_test_data_path( class MetadataWriterTest(tf.test.TestCase): + def test_hand_landmarker_metadata_writer(self): + # Use dummy model buffer for unit test only. 
+ hand_detector_model_buffer = b"\x11\x12" + hand_landmarks_detector_model_buffer = b"\x22" + writer = metadata_writer.HandLandmarkerMetadataWriter( + hand_detector_model_buffer, hand_landmarks_detector_model_buffer) + model_bundle_content = writer.populate() + model_bundle_filepath = os.path.join(self.get_temp_dir(), + "hand_landmarker.task") + with open(model_bundle_filepath, "wb") as f: + f.write(model_bundle_content) + + with zipfile.ZipFile(model_bundle_filepath) as zf: + self.assertEqual( + set(zf.namelist()), + set(["hand_landmarks_detector.tflite", "hand_detector.tflite"])) + def test_write_metadata_and_create_model_asset_bundle_successful(self): # Use dummy model buffer for unit test only. hand_detector_model_buffer = b"\x11\x12" diff --git a/mediapipe/tasks/python/core/BUILD b/mediapipe/tasks/python/core/BUILD index 447189d6f..f14d59b99 100644 --- a/mediapipe/tasks/python/core/BUILD +++ b/mediapipe/tasks/python/core/BUILD @@ -23,15 +23,15 @@ py_library( srcs = [ "optional_dependencies.py", ], - deps = [ - "@org_tensorflow//tensorflow/tools/docs:doc_controls", - ], ) py_library( name = "base_options", srcs = ["base_options.py"], - visibility = ["//mediapipe/tasks:users"], + visibility = [ + "//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__", + "//mediapipe/tasks:users", + ], deps = [ ":optional_dependencies", "//mediapipe/tasks/cc/core/proto:base_options_py_pb2", diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 5f4aa38ff..eda8e290d 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -131,6 +131,10 @@ py_library( srcs = [ "hand_landmarker.py", ], + visibility = [ + "//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__", + "//mediapipe/tasks:internal", + ], deps = [ "//mediapipe/framework/formats:classification_py_pb2", "//mediapipe/framework/formats:landmark_py_pb2", From 0fbaa8dc8a0220d682081d70fd01bef71709a316 Mon Sep 17 
00:00:00 2001 From: Nikolay Chirkov Date: Thu, 8 Dec 2022 12:59:46 -0800 Subject: [PATCH 44/64] Internal change. PiperOrigin-RevId: 493973435 --- .../framework/formats/tensor_ahwb_gpu_test.cc | 143 ++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 mediapipe/framework/formats/tensor_ahwb_gpu_test.cc diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc new file mode 100644 index 000000000..dd865a367 --- /dev/null +++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc @@ -0,0 +1,143 @@ + +#if !defined(MEDIAPIPE_NO_JNI) && \ + (__ANDROID_API__ >= 26 || \ + defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) +#include + +#include + +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/formats/tensor_data_types.h" +#include "mediapipe/gpu/gpu_test_base.h" +#include "mediapipe/gpu/shader_util.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_call.h" +#include "testing/base/public/gunit.h" + +// The test creates OpenGL ES buffer, fills the buffer with incrementing values +// 0.0, 0.1, 0.2 etc. with the compute shader on GPU. +// Then the test requests the CPU view and compares the values. +// Float32 and Float16 tests are there. + +namespace { + +using mediapipe::Float16; +using mediapipe::Tensor; + +MATCHER_P(NearWithPrecision, precision, "") { + return std::abs(std::get<0>(arg) - std::get<1>(arg)) < precision; +} + +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + +// Utility function to fill the GPU buffer. 
+void FillGpuBuffer(GLuint name, std::size_t size, + const Tensor::ElementType fmt) { + std::string shader_source; + if (fmt == Tensor::ElementType::kFloat32) { + shader_source = R"( #version 310 es + precision highp float; + layout(local_size_x = 1, local_size_y = 1) in; + layout(std430, binding = 0) buffer Output {float elements[];} output_data; + void main() { + uint v = gl_GlobalInvocationID.x * 2u; + output_data.elements[v] = float(v) / 10.0; + output_data.elements[v + 1u] = float(v + 1u) / 10.0; + })"; + } else { + shader_source = R"( #version 310 es + precision highp float; + layout(local_size_x = 1, local_size_y = 1) in; + layout(std430, binding = 0) buffer Output {float elements[];} output_data; + void main() { + uint v = gl_GlobalInvocationID.x; + uint tmp = packHalf2x16(vec2((float(v)* 2.0 + 0.0) / 10.0, + (float(v) * 2.0 + 1.0) / 10.0)); + output_data.elements[v] = uintBitsToFloat(tmp); + })"; + } + GLuint shader; + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glCreateShader, &shader, GL_COMPUTE_SHADER)); + const GLchar* sources[] = {shader_source.c_str()}; + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glShaderSource, shader, 1, sources, nullptr)); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glCompileShader, shader)); + GLint is_compiled = 0; + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glGetShaderiv, shader, GL_COMPILE_STATUS, + &is_compiled)); + if (is_compiled == GL_FALSE) { + GLint max_length = 0; + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glGetShaderiv, shader, GL_INFO_LOG_LENGTH, + &max_length)); + std::vector error_log(max_length); + glGetShaderInfoLog(shader, max_length, &max_length, error_log.data()); + glDeleteShader(shader); + FAIL() << error_log.data(); + return; + } + GLuint to_buffer_program; + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glCreateProgram, &to_buffer_program)); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glAttachShader, to_buffer_program, shader)); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDeleteShader, shader)); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glLinkProgram, to_buffer_program)); + + MP_ASSERT_OK( + 
TFLITE_GPU_CALL_GL(glBindBufferBase, GL_SHADER_STORAGE_BUFFER, 0, name)); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glUseProgram, to_buffer_program)); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDispatchCompute, size / 2, 1, 1)); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glBindBuffer, GL_SHADER_STORAGE_BUFFER, 0)); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDeleteProgram, to_buffer_program)); +} + +class TensorAhwbGpuTest : public mediapipe::GpuTestBase { + public: +}; + +TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) { + Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb); + constexpr size_t num_elements = 20; + Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; + RunInGlContext([&tensor] { + auto ssbo_view = tensor.GetOpenGlBufferWriteView(); + auto ssbo_name = ssbo_view.name(); + EXPECT_GT(ssbo_name, 0); + FillGpuBuffer(ssbo_name, tensor.shape().num_elements(), + tensor.element_type()); + }); + auto ptr = tensor.GetCpuReadView().buffer(); + EXPECT_NE(ptr, nullptr); + std::vector reference; + reference.resize(num_elements); + for (int i = 0; i < num_elements; i++) { + reference[i] = static_cast(i) / 10.0f; + } + EXPECT_THAT(absl::Span(ptr, num_elements), + testing::Pointwise(testing::FloatEq(), reference)); +} + +TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) { + Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb); + constexpr size_t num_elements = 20; + Tensor tensor{Tensor::ElementType::kFloat16, Tensor::Shape({num_elements})}; + RunInGlContext([&tensor] { + auto ssbo_view = tensor.GetOpenGlBufferWriteView(); + auto ssbo_name = ssbo_view.name(); + EXPECT_GT(ssbo_name, 0); + FillGpuBuffer(ssbo_name, tensor.shape().num_elements(), + tensor.element_type()); + }); + auto ptr = tensor.GetCpuReadView().buffer(); + EXPECT_NE(ptr, nullptr); + std::vector reference; + reference.resize(num_elements); + for (int i = 0; i < num_elements; i++) { + reference[i] = static_cast(i) / 10.0f; + } + // Precision is set to a reasonable value for Float16. 
+ EXPECT_THAT(absl::Span(ptr, num_elements), + testing::Pointwise(NearWithPrecision(0.001), reference)); +} + +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 +} // namespace + +#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 || + // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) From b4e1969e4381053038322275bfb8f15d855da9f2 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 8 Dec 2022 14:01:17 -0800 Subject: [PATCH 45/64] Add pip package builder for model_maker PiperOrigin-RevId: 493989013 --- mediapipe/model_maker/MANIFEST.in | 1 + mediapipe/model_maker/__init__.py | 6 + .../python/core/utils/file_util.py | 19 ++- mediapipe/model_maker/python/text/__init__.py | 13 ++ mediapipe/model_maker/requirements.txt | 6 +- mediapipe/model_maker/setup.py | 147 ++++++++++++++++++ 6 files changed, 184 insertions(+), 8 deletions(-) create mode 100644 mediapipe/model_maker/MANIFEST.in create mode 100644 mediapipe/model_maker/python/text/__init__.py create mode 100644 mediapipe/model_maker/setup.py diff --git a/mediapipe/model_maker/MANIFEST.in b/mediapipe/model_maker/MANIFEST.in new file mode 100644 index 000000000..54ce01aff --- /dev/null +++ b/mediapipe/model_maker/MANIFEST.in @@ -0,0 +1 @@ +recursive-include pip_src/mediapipe_model_maker/models * diff --git a/mediapipe/model_maker/__init__.py b/mediapipe/model_maker/__init__.py index 7ca2f9216..9899a145b 100644 --- a/mediapipe/model_maker/__init__.py +++ b/mediapipe/model_maker/__init__.py @@ -11,3 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ + +from mediapipe.model_maker.python.core.utils import quantization +from mediapipe.model_maker.python.vision import image_classifier +from mediapipe.model_maker.python.vision import gesture_recognizer +from mediapipe.model_maker.python.text import text_classifier diff --git a/mediapipe/model_maker/python/core/utils/file_util.py b/mediapipe/model_maker/python/core/utils/file_util.py index bccf928e2..66addad54 100644 --- a/mediapipe/model_maker/python/core/utils/file_util.py +++ b/mediapipe/model_maker/python/core/utils/file_util.py @@ -19,7 +19,7 @@ import os def get_absolute_path(file_path: str) -> str: - """Gets the absolute path of a file. + """Gets the absolute path of a file in the model_maker directory. Args: file_path: The path to a file relative to the `mediapipe` dir @@ -27,10 +27,17 @@ def get_absolute_path(file_path: str) -> str: Returns: The full path of the file """ - # Extract the file path before mediapipe/ as the `base_dir`. By joining it - # with the `path` which defines the relative path under mediapipe/, it - # yields to the absolute path of the model files directory. + # Extract the file path before and including 'model_maker' as the + # `mm_base_dir`. By joining it with the `path` after 'model_maker/', it + # yields to the absolute path of the model files directory. We must join + # on 'model_maker' because in the pypi package, the 'model_maker' directory + # is renamed to 'mediapipe_model_maker'. So we have to join on model_maker + # to ensure that the `mm_base_dir` path includes the renamed + # 'mediapipe_model_maker' directory. 
cwd = os.path.dirname(__file__) - base_dir = cwd[:cwd.rfind('mediapipe')] - absolute_path = os.path.join(base_dir, file_path) + cwd_stop_idx = cwd.rfind('model_maker') + len('model_maker') + mm_base_dir = cwd[:cwd_stop_idx] + file_path_start_idx = file_path.find('model_maker') + len('model_maker') + 1 + mm_relative_path = file_path[file_path_start_idx:] + absolute_path = os.path.join(mm_base_dir, mm_relative_path) return absolute_path diff --git a/mediapipe/model_maker/python/text/__init__.py b/mediapipe/model_maker/python/text/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/model_maker/python/text/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/model_maker/requirements.txt b/mediapipe/model_maker/requirements.txt index 389ee484a..9b3c9f906 100644 --- a/mediapipe/model_maker/requirements.txt +++ b/mediapipe/model_maker/requirements.txt @@ -1,6 +1,8 @@ absl-py +mediapipe==0.9.1 numpy -opencv-contrib-python -tensorflow +opencv-python +tensorflow>=2.10 tensorflow-datasets tensorflow-hub +tf-models-official>=2.10.1 diff --git a/mediapipe/model_maker/setup.py b/mediapipe/model_maker/setup.py new file mode 100644 index 000000000..ea193db94 --- /dev/null +++ b/mediapipe/model_maker/setup.py @@ -0,0 +1,147 @@ +"""Copyright 2020-2022 The MediaPipe Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Setup for Mediapipe-Model-Maker package with setuptools. +""" + +import glob +import os +import shutil +import subprocess +import sys +import setuptools + + +__version__ = 'dev' +MM_ROOT_PATH = os.path.dirname(os.path.abspath(__file__)) +# Build dir to copy all necessary files and build package +SRC_NAME = 'pip_src' +BUILD_DIR = os.path.join(MM_ROOT_PATH, SRC_NAME) +BUILD_MM_DIR = os.path.join(BUILD_DIR, 'mediapipe_model_maker') + + +def _parse_requirements(path): + with open(os.path.join(MM_ROOT_PATH, path)) as f: + return [ + line.rstrip() + for line in f + if not (line.isspace() or line.startswith('#')) + ] + + +def _copy_to_pip_src_dir(file): + """Copy a file from bazel-bin to the pip_src dir.""" + dst = file + dst_dir = os.path.dirname(dst) + if not os.path.exists(dst_dir): + os.makedirs(dst_dir) + src_file = os.path.join('../../bazel-bin/mediapipe/model_maker', file) + shutil.copyfile(src_file, file) + + +def _setup_build_dir(): + """Setup the BUILD_DIR directory to build the mediapipe_model_maker package. + + We need to create a new BUILD_DIR directory because any references to the path + `mediapipe/model_maker` needs to be renamed to `mediapipe_model_maker` to + avoid conflicting with the mediapipe package name. + This setup function performs the following actions: + 1. Copy python source code into BUILD_DIR and rename imports to + mediapipe_model_maker + 2. 
Download models from GCS into BUILD_DIR + """ + # Copy python source code into BUILD_DIR + if os.path.exists(BUILD_DIR): + shutil.rmtree(BUILD_DIR) + python_files = glob.glob('python/**/*.py', recursive=True) + python_files.append('__init__.py') + for python_file in python_files: + # Exclude test files from pip package + if '_test.py' in python_file: + continue + build_target_file = os.path.join(BUILD_MM_DIR, python_file) + with open(python_file, 'r') as file: + filedata = file.read() + # Rename all mediapipe.model_maker imports to mediapipe_model_maker + filedata = filedata.replace('from mediapipe.model_maker', + 'from mediapipe_model_maker') + os.makedirs(os.path.dirname(build_target_file), exist_ok=True) + with open(build_target_file, 'w') as file: + file.write(filedata) + + # Use bazel to download GCS model files + model_build_files = ['models/gesture_recognizer/BUILD'] + for model_build_file in model_build_files: + build_target_file = os.path.join(BUILD_MM_DIR, model_build_file) + os.makedirs(os.path.dirname(build_target_file), exist_ok=True) + shutil.copy(model_build_file, build_target_file) + external_files = [ + 'models/gesture_recognizer/canned_gesture_classifier.tflite', + 'models/gesture_recognizer/gesture_embedder.tflite', + 'models/gesture_recognizer/hand_landmark_full.tflite', + 'models/gesture_recognizer/palm_detection_full.tflite', + 'models/gesture_recognizer/gesture_embedder/keras_metadata.pb', + 'models/gesture_recognizer/gesture_embedder/saved_model.pb', + 'models/gesture_recognizer/gesture_embedder/variables/variables.data-00000-of-00001', + 'models/gesture_recognizer/gesture_embedder/variables/variables.index', + ] + for elem in external_files: + external_file = os.path.join(f'{SRC_NAME}/mediapipe_model_maker', elem) + sys.stderr.write('downloading file: %s\n' % external_file) + fetch_model_command = [ + 'bazel', + 'build', + external_file, + ] + if subprocess.call(fetch_model_command) != 0: + sys.exit(-1) + _copy_to_pip_src_dir(external_file) 
+ +_setup_build_dir() + +setuptools.setup( + name='mediapipe-model-maker', + version=__version__, + url='https://github.com/google/mediapipe/tree/master/mediapipe/model_maker', + description='MediaPipe Model Maker is a simple, low-code solution for customizing on-device ML models', + author='The MediaPipe Authors', + author_email='mediapipe@google.com', + long_description='', + long_description_content_type='text/markdown', + packages=setuptools.find_packages(where=SRC_NAME), + package_dir={'': SRC_NAME}, + install_requires=_parse_requirements('requirements.txt'), + include_package_data=True, + classifiers=[ + 'Development Status :: 3 - Alpha', + 'Intended Audience :: Developers', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: MacOS :: MacOS X', + 'Operating System :: Microsoft :: Windows', + 'Operating System :: POSIX :: Linux', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3 :: Only', + 'Topic :: Scientific/Engineering', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: Software Development', + 'Topic :: Software Development :: Libraries', + 'Topic :: Software Development :: Libraries :: Python Modules', + ], + license='Apache 2.0', + keywords=['mediapipe', 'model', 'maker'], +) From 05535db5f77bdc9c46df36855fe4064ded89d7cb Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 8 Dec 2022 15:01:34 -0800 Subject: [PATCH 46/64] Fix assertion failure in Hair Segmentation demo PiperOrigin-RevId: 494004801 --- mediapipe/graphs/hair_segmentation/BUILD | 1 + .../hair_segmentation_desktop_live.pbtxt | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/mediapipe/graphs/hair_segmentation/BUILD b/mediapipe/graphs/hair_segmentation/BUILD index b177726bf..945f02c62 100644 --- 
a/mediapipe/graphs/hair_segmentation/BUILD +++ b/mediapipe/graphs/hair_segmentation/BUILD @@ -43,6 +43,7 @@ cc_library( deps = [ "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/core:previous_loopback_calculator", + "//mediapipe/calculators/image:color_convert_calculator", "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/image:recolor_calculator", "//mediapipe/calculators/image:set_alpha_calculator", diff --git a/mediapipe/graphs/hair_segmentation/hair_segmentation_desktop_live.pbtxt b/mediapipe/graphs/hair_segmentation/hair_segmentation_desktop_live.pbtxt index 36c6970e1..f48b26be0 100644 --- a/mediapipe/graphs/hair_segmentation/hair_segmentation_desktop_live.pbtxt +++ b/mediapipe/graphs/hair_segmentation/hair_segmentation_desktop_live.pbtxt @@ -60,7 +60,14 @@ node { tag_index: "LOOP" back_edge: true } - output_stream: "PREV_LOOP:previous_hair_mask" + output_stream: "PREV_LOOP:previous_hair_mask_rgb" +} + +# Converts the 4 channel hair mask to a single channel mask +node { + calculator: "ColorConvertCalculator" + input_stream: "RGB_IN:previous_hair_mask_rgb" + output_stream: "GRAY_OUT:previous_hair_mask" } # Embeds the hair mask generated from the previous round of hair segmentation From bea0caae6586343eb91e986b84aa00ef75fa67b1 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Thu, 8 Dec 2022 17:05:06 -0800 Subject: [PATCH 47/64] Tensor: Cpu -> Ahwb storage transfer PiperOrigin-RevId: 494033280 --- mediapipe/framework/formats/tensor.cc | 4 +-- mediapipe/framework/formats/tensor.h | 2 +- mediapipe/framework/formats/tensor_ahwb.cc | 3 +- .../framework/formats/tensor_ahwb_gpu_test.cc | 28 +++++++++++++++++++ 4 files changed, 32 insertions(+), 5 deletions(-) diff --git a/mediapipe/framework/formats/tensor.cc b/mediapipe/framework/formats/tensor.cc index 9e1406dbb..fdafbff5c 100644 --- a/mediapipe/framework/formats/tensor.cc +++ b/mediapipe/framework/formats/tensor.cc @@ -361,7 +361,7 @@ void 
Tensor::AllocateOpenGlBuffer() const { LOG_IF(FATAL, !gl_context_) << "GlContext is not bound to the thread."; glGenBuffers(1, &opengl_buffer_); glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); - if (!AllocateAhwbMapToSsbo()) { + if (!use_ahwb_ || !AllocateAhwbMapToSsbo()) { glBufferData(GL_SHADER_STORAGE_BUFFER, bytes(), NULL, GL_STREAM_COPY); } glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); @@ -610,7 +610,7 @@ Tensor::CpuWriteView Tensor::GetCpuWriteView() const { void Tensor::AllocateCpuBuffer() const { if (!cpu_buffer_) { #ifdef MEDIAPIPE_TENSOR_USE_AHWB - if (AllocateAHardwareBuffer()) return; + if (use_ahwb_ && AllocateAHardwareBuffer()) return; #endif // MEDIAPIPE_TENSOR_USE_AHWB #if MEDIAPIPE_METAL_ENABLED cpu_buffer_ = AllocateVirtualMemory(bytes()); diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index 151aa299d..9d3e90b6a 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -409,8 +409,8 @@ class Tensor { bool AllocateAHardwareBuffer(int size_alignment = 0) const; void CreateEglSyncAndFd() const; // Use Ahwb for other views: OpenGL / CPU buffer. - static inline bool use_ahwb_ = false; #endif // MEDIAPIPE_TENSOR_USE_AHWB + static inline bool use_ahwb_ = false; // Expects the target SSBO to be already bound. 
bool AllocateAhwbMapToSsbo() const; bool InsertAhwbToSsboFence() const; diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index 21bae9593..3c3ec8b17 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -214,7 +214,7 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferReadView() const { "supported."; CHECK(ahwb_ || !(valid_ & kValidOpenGlBuffer)) << "Interoperability bettween OpenGL buffer and AHardwareBuffer is not " - "supported on targe system."; + "supported on target system."; bool transfer = !ahwb_; CHECK(AllocateAHardwareBuffer()) << "AHardwareBuffer is not supported on the target system."; @@ -268,7 +268,6 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView( } bool Tensor::AllocateAHardwareBuffer(int size_alignment) const { - if (!use_ahwb_) return false; if (__builtin_available(android 26, *)) { if (ahwb_ == nullptr) { AHardwareBuffer_Desc desc = {}; diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc index dd865a367..7ccd9c7f5 100644 --- a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc +++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc @@ -136,6 +136,34 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) { testing::Pointwise(NearWithPrecision(0.001), reference)); } +TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) { + // Request the CPU view to get the memory to be allocated. + // Request Ahwb view then to transform the storage into Ahwb. 
+ Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault); + constexpr size_t num_elements = 20; + Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; + { + auto ptr = tensor.GetCpuWriteView().buffer(); + EXPECT_NE(ptr, nullptr); + for (int i = 0; i < num_elements; i++) { + ptr[i] = static_cast(i) / 10.0f; + } + } + { + auto view = tensor.GetAHardwareBufferReadView(); + EXPECT_NE(view.handle(), nullptr); + } + auto ptr = tensor.GetCpuReadView().buffer(); + EXPECT_NE(ptr, nullptr); + std::vector reference; + reference.resize(num_elements); + for (int i = 0; i < num_elements; i++) { + reference[i] = static_cast(i) / 10.0f; + } + EXPECT_THAT(absl::Span(ptr, num_elements), + testing::Pointwise(testing::FloatEq(), reference)); +} + #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 } // namespace From 3aeec84ac016d9899a7829ad5651753942dcf275 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 9 Dec 2022 03:19:45 -0800 Subject: [PATCH 48/64] Internal change for profiling PiperOrigin-RevId: 494126771 --- mediapipe/framework/validated_graph_config.cc | 10 ++++++++++ mediapipe/framework/validated_graph_config.h | 12 ++++++++++++ 2 files changed, 22 insertions(+) diff --git a/mediapipe/framework/validated_graph_config.cc b/mediapipe/framework/validated_graph_config.cc index 01e3da83e..15eac3209 100644 --- a/mediapipe/framework/validated_graph_config.cc +++ b/mediapipe/framework/validated_graph_config.cc @@ -369,6 +369,7 @@ absl::Status ValidatedGraphConfig::Initialize( input_side_packets_.clear(); output_side_packets_.clear(); stream_to_producer_.clear(); + output_streams_to_consumer_nodes_.clear(); input_streams_.clear(); output_streams_.clear(); owned_packet_types_.clear(); @@ -719,6 +720,15 @@ absl::Status ValidatedGraphConfig::AddInputStreamsForNode( << " does not have a corresponding output stream."; } } + // Add this node as a consumer of this edge's output stream. 
+ if (edge_info.upstream > -1) { + auto parent_node = output_streams_[edge_info.upstream].parent_node; + if (parent_node.type == NodeTypeInfo::NodeType::CALCULATOR) { + int this_idx = node_type_info->Node().index; + output_streams_to_consumer_nodes_[edge_info.upstream].push_back( + this_idx); + } + } edge_info.parent_node = node_type_info->Node(); edge_info.name = name; diff --git a/mediapipe/framework/validated_graph_config.h b/mediapipe/framework/validated_graph_config.h index 11f9553cd..95ecccbb4 100644 --- a/mediapipe/framework/validated_graph_config.h +++ b/mediapipe/framework/validated_graph_config.h @@ -282,6 +282,14 @@ class ValidatedGraphConfig { return output_streams_[iter->second].parent_node.index; } + std::vector OutputStreamToConsumers(int idx) const { + auto iter = output_streams_to_consumer_nodes_.find(idx); + if (iter == output_streams_to_consumer_nodes_.end()) { + return {}; + } + return iter->second; + } + // Returns the registered type name of the specified side packet if // it can be determined, otherwise an appropriate error is returned. absl::StatusOr RegisteredSidePacketTypeName( @@ -418,6 +426,10 @@ class ValidatedGraphConfig { // Mapping from stream name to the output_streams_ index which produces it. std::map stream_to_producer_; + + // Mapping from output streams to consumer node ids. Used for profiling. + std::map> output_streams_to_consumer_nodes_; + // Mapping from side packet name to the output_side_packets_ index // which produces it. 
std::map side_packet_to_producer_; From 4c4df2cf18955b2cc76f432b5f1321083d5bfb11 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 9 Dec 2022 04:11:05 -0800 Subject: [PATCH 49/64] Internal change for profiling PiperOrigin-RevId: 494135244 --- mediapipe/framework/profiler/graph_profiler.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mediapipe/framework/profiler/graph_profiler.h b/mediapipe/framework/profiler/graph_profiler.h index 29969af2e..23caed4ec 100644 --- a/mediapipe/framework/profiler/graph_profiler.h +++ b/mediapipe/framework/profiler/graph_profiler.h @@ -232,6 +232,11 @@ class GraphProfiler : public std::enable_shared_from_this { const ProfilerConfig& profiler_config() { return profiler_config_; } + // Helper method to expose the config to other profilers. + const ValidatedGraphConfig* GetValidatedGraphConfig() { + return validated_graph_; + } + private: // This can be used to add packet info for the input streams to the graph. // It treats the stream defined by |stream_name| as a stream produced by a From 5bc1baf96acab858942d151d46b988ebe0577c00 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 9 Dec 2022 05:55:20 -0800 Subject: [PATCH 50/64] Internal change PiperOrigin-RevId: 494150467 --- mediapipe/framework/output_stream_shard.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mediapipe/framework/output_stream_shard.h b/mediapipe/framework/output_stream_shard.h index fdc5fe077..718174c45 100644 --- a/mediapipe/framework/output_stream_shard.h +++ b/mediapipe/framework/output_stream_shard.h @@ -127,6 +127,8 @@ class OutputStreamShard : public OutputStream { friend class GraphProfiler; // Accesses OutputStreamShard for profiling. friend class GraphTracer; + // Accesses OutputStreamShard for profiling. + friend class PerfettoTraceScope; // Accesses OutputStreamShard for post processing. 
friend class OutputStreamManager; }; From db3cb68d919693adb729437d1223d29c30736f27 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Fri, 9 Dec 2022 07:27:11 -0800 Subject: [PATCH 51/64] Internal change. PiperOrigin-RevId: 494166776 --- .../formats/tensor_hardware_buffer.h | 71 ++++++ .../tensor_hardware_buffer_cpu_storage.cc | 216 ++++++++++++++++++ ...tensor_hardware_buffer_cpu_storage_test.cc | 76 ++++++ 3 files changed, 363 insertions(+) create mode 100644 mediapipe/framework/formats/tensor_hardware_buffer.h create mode 100644 mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc create mode 100644 mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc diff --git a/mediapipe/framework/formats/tensor_hardware_buffer.h b/mediapipe/framework/formats/tensor_hardware_buffer.h new file mode 100644 index 000000000..fa0241bde --- /dev/null +++ b/mediapipe/framework/formats/tensor_hardware_buffer.h @@ -0,0 +1,71 @@ +#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_ +#define MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_ + +#if !defined(MEDIAPIPE_NO_JNI) && \ + (__ANDROID_API__ >= 26 || \ + defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) + +#include + +#include + +#include "mediapipe/framework/formats/tensor_buffer.h" +#include "mediapipe/framework/formats/tensor_internal.h" +#include "mediapipe/framework/formats/tensor_v2.h" + +namespace mediapipe { + +// Supports: +// - float 16 and 32 bits +// - signed / unsigned integers 8,16,32 bits +class TensorHardwareBufferView; +struct TensorHardwareBufferViewDescriptor : public Tensor::ViewDescriptor { + using ViewT = TensorHardwareBufferView; + TensorBufferDescriptor buffer; +}; + +class TensorHardwareBufferView : public Tensor::View { + public: + TENSOR_UNIQUE_VIEW_TYPE_ID(); + ~TensorHardwareBufferView() = default; + + const TensorHardwareBufferViewDescriptor& descriptor() const override { + return descriptor_; + } + AHardwareBuffer* handle() const { return ahwb_handle_; } 
+ + protected: + TensorHardwareBufferView(int access_capability, Tensor::View::Access access, + Tensor::View::State state, + const TensorHardwareBufferViewDescriptor& desc, + AHardwareBuffer* ahwb_handle) + : Tensor::View(kId, access_capability, access, state), + descriptor_(desc), + ahwb_handle_(ahwb_handle) {} + + private: + bool MatchDescriptor( + uint64_t view_type_id, + const Tensor::ViewDescriptor& base_descriptor) const override { + if (!Tensor::View::MatchDescriptor(view_type_id, base_descriptor)) + return false; + auto descriptor = + static_cast(base_descriptor); + return descriptor.buffer.format == descriptor_.buffer.format && + descriptor.buffer.size_alignment <= + descriptor_.buffer.size_alignment && + descriptor_.buffer.size_alignment % + descriptor.buffer.size_alignment == + 0; + } + const TensorHardwareBufferViewDescriptor& descriptor_; + AHardwareBuffer* ahwb_handle_ = nullptr; +}; + +} // namespace mediapipe + +#endif // !defined(MEDIAPIPE_NO_JNI) && \ + (__ANDROID_API__ >= 26 || \ + defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) + +#endif // MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_ diff --git a/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc b/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc new file mode 100644 index 000000000..9c223ce2c --- /dev/null +++ b/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc @@ -0,0 +1,216 @@ +#if !defined(MEDIAPIPE_NO_JNI) && \ + (__ANDROID_API__ >= 26 || \ + defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) + +#include + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "mediapipe/framework/formats/tensor_backend.h" +#include "mediapipe/framework/formats/tensor_cpu_buffer.h" +#include "mediapipe/framework/formats/tensor_hardware_buffer.h" +#include "mediapipe/framework/formats/tensor_v2.h" +#include "util/task/status_macros.h" + +namespace mediapipe { +namespace { + +class TensorCpuViewImpl : public TensorCpuView { + 
public: + TensorCpuViewImpl(int access_capabilities, Tensor::View::Access access, + Tensor::View::State state, + const TensorCpuViewDescriptor& descriptor, void* pointer, + AHardwareBuffer* ahwb_handle) + : TensorCpuView(access_capabilities, access, state, descriptor, pointer), + ahwb_handle_(ahwb_handle) {} + ~TensorCpuViewImpl() { + // If handle_ is null then this view is constructed in GetViews with no + // access. + if (ahwb_handle_) { + if (__builtin_available(android 26, *)) { + AHardwareBuffer_unlock(ahwb_handle_, nullptr); + } + } + } + + private: + AHardwareBuffer* ahwb_handle_; +}; + +class TensorHardwareBufferViewImpl : public TensorHardwareBufferView { + public: + TensorHardwareBufferViewImpl( + int access_capability, Tensor::View::Access access, + Tensor::View::State state, + const TensorHardwareBufferViewDescriptor& descriptor, + AHardwareBuffer* handle) + : TensorHardwareBufferView(access_capability, access, state, descriptor, + handle) {} + ~TensorHardwareBufferViewImpl() = default; +}; + +class HardwareBufferCpuStorage : public TensorStorage { + public: + ~HardwareBufferCpuStorage() { + if (!ahwb_handle_) return; + if (__builtin_available(android 26, *)) { + AHardwareBuffer_release(ahwb_handle_); + } + } + + static absl::Status CanProvide( + int access_capability, const Tensor::Shape& shape, uint64_t view_type_id, + const Tensor::ViewDescriptor& base_descriptor) { + // TODO: use AHardwareBuffer_isSupported for API >= 29. + static const bool is_ahwb_supported = [] { + if (__builtin_available(android 26, *)) { + AHardwareBuffer_Desc desc = {}; + // Aligned to the largest possible virtual memory page size. 
+ constexpr uint32_t kPageSize = 16384; + desc.width = kPageSize; + desc.height = 1; + desc.layers = 1; + desc.format = AHARDWAREBUFFER_FORMAT_BLOB; + desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | + AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN; + AHardwareBuffer* handle; + if (AHardwareBuffer_allocate(&desc, &handle) != 0) return false; + AHardwareBuffer_release(handle); + return true; + } + return false; + }(); + if (!is_ahwb_supported) { + return absl::UnavailableError( + "AHardwareBuffer is not supported on the platform."); + } + + if (view_type_id != TensorCpuView::kId && + view_type_id != TensorHardwareBufferView::kId) { + return absl::InvalidArgumentError( + "A view type is not supported by this storage."); + } + return absl::OkStatus(); + } + + std::vector> GetViews(uint64_t latest_version) { + std::vector> result; + auto update_state = latest_version == version_ + ? Tensor::View::State::kUpToDate + : Tensor::View::State::kOutdated; + if (ahwb_handle_) { + result.push_back( + std::unique_ptr(new TensorHardwareBufferViewImpl( + kAccessCapability, Tensor::View::Access::kNoAccess, update_state, + hw_descriptor_, ahwb_handle_))); + + result.push_back(std::unique_ptr(new TensorCpuViewImpl( + kAccessCapability, Tensor::View::Access::kNoAccess, update_state, + cpu_descriptor_, nullptr, nullptr))); + } + return result; + } + + absl::StatusOr> GetView( + Tensor::View::Access access, const Tensor::Shape& shape, + uint64_t latest_version, uint64_t view_type_id, + const Tensor::ViewDescriptor& base_descriptor, int access_capability) { + MP_RETURN_IF_ERROR( + CanProvide(access_capability, shape, view_type_id, base_descriptor)); + const auto& buffer_descriptor = + view_type_id == TensorHardwareBufferView::kId + ? 
static_cast( + base_descriptor) + .buffer + : static_cast(base_descriptor) + .buffer; + if (!ahwb_handle_) { + if (__builtin_available(android 26, *)) { + AHardwareBuffer_Desc desc = {}; + desc.width = TensorBufferSize(buffer_descriptor, shape); + desc.height = 1; + desc.layers = 1; + desc.format = AHARDWAREBUFFER_FORMAT_BLOB; + // TODO: Use access capabilities to set hints. + desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | + AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN; + auto error = AHardwareBuffer_allocate(&desc, &ahwb_handle_); + if (error != 0) { + return absl::UnknownError( + absl::StrCat("Error allocating hardware buffer: ", error)); + } + // Fill all possible views to provide it as proto views. + hw_descriptor_.buffer = buffer_descriptor; + cpu_descriptor_.buffer = buffer_descriptor; + } + } + if (buffer_descriptor.format != hw_descriptor_.buffer.format || + buffer_descriptor.size_alignment > + hw_descriptor_.buffer.size_alignment || + hw_descriptor_.buffer.size_alignment % + buffer_descriptor.size_alignment > + 0) { + return absl::AlreadyExistsError( + "A view with different params is already allocated with this " + "storage"); + } + + absl::StatusOr> result; + if (view_type_id == TensorHardwareBufferView::kId) { + result = GetAhwbView(access, shape, base_descriptor); + } else { + result = GetCpuView(access, shape, base_descriptor); + } + if (result.ok()) version_ = latest_version; + return result; + } + + private: + absl::StatusOr> GetAhwbView( + Tensor::View::Access access, const Tensor::Shape& shape, + const Tensor::ViewDescriptor& base_descriptor) { + return std::unique_ptr(new TensorHardwareBufferViewImpl( + kAccessCapability, access, Tensor::View::State::kUpToDate, + hw_descriptor_, ahwb_handle_)); + } + + absl::StatusOr> GetCpuView( + Tensor::View::Access access, const Tensor::Shape& shape, + const Tensor::ViewDescriptor& base_descriptor) { + void* pointer = nullptr; + if (__builtin_available(android 26, *)) { + int error = + 
AHardwareBuffer_lock(ahwb_handle_, + access == Tensor::View::Access::kWriteOnly + ? AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN + : AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN, + -1, nullptr, &pointer); + if (error != 0) { + return absl::UnknownError( + absl::StrCat("Error locking hardware buffer: ", error)); + } + } + return std::unique_ptr( + new TensorCpuViewImpl(access == Tensor::View::Access::kWriteOnly + ? Tensor::View::AccessCapability::kWrite + : Tensor::View::AccessCapability::kRead, + access, Tensor::View::State::kUpToDate, + cpu_descriptor_, pointer, ahwb_handle_)); + } + + static constexpr int kAccessCapability = + Tensor::View::AccessCapability::kRead | + Tensor::View::AccessCapability::kWrite; + TensorHardwareBufferViewDescriptor hw_descriptor_; + AHardwareBuffer* ahwb_handle_ = nullptr; + + TensorCpuViewDescriptor cpu_descriptor_; + uint64_t version_ = 0; +}; +TENSOR_REGISTER_STORAGE(HardwareBufferCpuStorage); + +} // namespace +} // namespace mediapipe + +#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 || + // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) diff --git a/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc b/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc new file mode 100644 index 000000000..0afa9899f --- /dev/null +++ b/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc @@ -0,0 +1,76 @@ + +#if !defined(MEDIAPIPE_NO_JNI) && \ + (__ANDROID_API__ >= 26 || \ + defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) +#include + +#include + +#include "mediapipe/framework/formats/tensor_cpu_buffer.h" +#include "mediapipe/framework/formats/tensor_hardware_buffer.h" +#include "mediapipe/framework/formats/tensor_v2.h" +#include "testing/base/public/gmock.h" +#include "testing/base/public/gunit.h" + +namespace mediapipe { + +namespace { + +class TensorHardwareBufferTest : public ::testing::Test { + public: + TensorHardwareBufferTest() {} + ~TensorHardwareBufferTest() override {} +}; 
+ +TEST_F(TensorHardwareBufferTest, TestFloat32) { + Tensor tensor{Tensor::Shape({1})}; + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorHardwareBufferViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat32}})); + EXPECT_NE(view->handle(), nullptr); + } + { + const auto& const_tensor = tensor; + MP_ASSERT_OK_AND_ASSIGN( + auto view, + const_tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = + TensorBufferDescriptor::Format::kFloat32}})); + EXPECT_NE(view->data(), nullptr); + } +} + +TEST_F(TensorHardwareBufferTest, TestInt8Padding) { + Tensor tensor{Tensor::Shape({1})}; + + { + MP_ASSERT_OK_AND_ASSIGN( + auto view, + tensor.GetView( + TensorHardwareBufferViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kInt8, + .size_alignment = 4}})); + EXPECT_NE(view->handle(), nullptr); + } + { + const auto& const_tensor = tensor; + MP_ASSERT_OK_AND_ASSIGN( + auto view, + const_tensor.GetView( + TensorCpuViewDescriptor{ + .buffer = {.format = TensorBufferDescriptor::Format::kInt8}})); + EXPECT_NE(view->data(), nullptr); + } +} + +} // namespace + +} // namespace mediapipe + +#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 || + // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) From 453d67de92d19abf2488a4400532708d734b20bb Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 9 Dec 2022 13:10:25 -0800 Subject: [PATCH 52/64] Add MergeDetectionsToVectorCalculator. 
PiperOrigin-RevId: 494246359 --- mediapipe/calculators/core/BUILD | 1 + mediapipe/calculators/core/merge_to_vector_calculator.cc | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 29bca5fa6..2c143a609 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -1323,6 +1323,7 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:image", "@com_google_absl//absl/status", ], diff --git a/mediapipe/calculators/core/merge_to_vector_calculator.cc b/mediapipe/calculators/core/merge_to_vector_calculator.cc index 5f05ad725..fd053ed2b 100644 --- a/mediapipe/calculators/core/merge_to_vector_calculator.cc +++ b/mediapipe/calculators/core/merge_to_vector_calculator.cc @@ -15,6 +15,7 @@ limitations under the License. #include "mediapipe/calculators/core/merge_to_vector_calculator.h" +#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/image.h" namespace mediapipe { @@ -27,5 +28,9 @@ typedef MergeToVectorCalculator MergeGpuBuffersToVectorCalculator; MEDIAPIPE_REGISTER_NODE(MergeGpuBuffersToVectorCalculator); +typedef MergeToVectorCalculator + MergeDetectionsToVectorCalculator; +MEDIAPIPE_REGISTER_NODE(MergeDetectionsToVectorCalculator); + } // namespace api2 } // namespace mediapipe From 69c3c4c181766e4e94bca9f1db6ce49315d8ac45 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 9 Dec 2022 18:08:26 -0800 Subject: [PATCH 53/64] Internal change PiperOrigin-RevId: 494305195 --- mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts | 1 + mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts | 1 + mediapipe/tasks/web/core/task_runner.ts | 1 + mediapipe/tasks/web/text/text_classifier/text_classifier.ts | 1 + 
mediapipe/tasks/web/text/text_embedder/text_embedder.ts | 1 + .../tasks/web/vision/gesture_recognizer/gesture_recognizer.ts | 1 + mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts | 1 + mediapipe/tasks/web/vision/image_classifier/image_classifier.ts | 1 + mediapipe/tasks/web/vision/image_embedder/image_embedder.ts | 1 + mediapipe/tasks/web/vision/object_detector/object_detector.ts | 1 + 10 files changed, 10 insertions(+) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 265ba2b33..7bfca680a 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -94,6 +94,7 @@ export class AudioClassifier extends AudioTaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index 445dd5172..246cba883 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -96,6 +96,7 @@ export class AudioEmbedder extends AudioTaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index 6712c4d89..2011fadef 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -76,6 +76,7 @@ export abstract class TaskRunner { return createTaskRunner(type, initializeCanvas, fileset, options); } + /** @hideconstructor protected */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null, graphRunner?: GraphRunnerImageLib) { 
diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index 4a8588836..62708700a 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -92,6 +92,7 @@ export class TextClassifier extends TaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index cd5bc644e..611233e02 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -96,6 +96,7 @@ export class TextEmbedder extends TaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 69a8118a6..b6b795076 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -127,6 +127,7 @@ export class GestureRecognizer extends {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 9a0823f23..2a0e8286c 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -115,6 +115,7 @@ export class HandLandmarker extends VisionTaskRunner { {baseOptions: {modelAssetPath}}); } + /** 
@hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 40e8b5099..36e7311fb 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -93,6 +93,7 @@ export class ImageClassifier extends VisionTaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index f8b0204ee..0c45ba5e7 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -95,6 +95,7 @@ export class ImageEmbedder extends VisionTaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index e2cfe0575..fbfaced12 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -92,6 +92,7 @@ export class ObjectDetector extends VisionTaskRunner { {baseOptions: {modelAssetPath}}); } + /** @hideconstructor */ constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { From edafef9fd8bfb34e91d8578f5ad68919b8cff702 Mon Sep 17 00:00:00 2001 From: Khanh LeViet Date: Fri, 9 Dec 2022 18:08:41 -0800 Subject: [PATCH 54/64] Updated issue templates. 
PiperOrigin-RevId: 494305235 --- .github/ISSUE_TEMPLATE/11-tasks-issue.md | 2 +- .github/ISSUE_TEMPLATE/12-model-maker-issue.md | 2 +- .../{10-solution-issue.md => 13-solution-issue.md} | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) rename .github/ISSUE_TEMPLATE/{10-solution-issue.md => 13-solution-issue.md} (81%) diff --git a/.github/ISSUE_TEMPLATE/11-tasks-issue.md b/.github/ISSUE_TEMPLATE/11-tasks-issue.md index 264371120..4e9ae721d 100644 --- a/.github/ISSUE_TEMPLATE/11-tasks-issue.md +++ b/.github/ISSUE_TEMPLATE/11-tasks-issue.md @@ -1,6 +1,6 @@ --- name: "Tasks Issue" -about: Use this template for assistance with using MediaPipe Tasks to deploy on-device ML solutions (e.g. gesture recognition etc.) on supported platforms. +about: Use this template for assistance with using MediaPipe Tasks (developers.google.com/mediapipe/solutions) to deploy on-device ML solutions (e.g. gesture recognition etc.) on supported platforms. labels: type:support --- diff --git a/.github/ISSUE_TEMPLATE/12-model-maker-issue.md b/.github/ISSUE_TEMPLATE/12-model-maker-issue.md index 258390d5e..31e8d7f1b 100644 --- a/.github/ISSUE_TEMPLATE/12-model-maker-issue.md +++ b/.github/ISSUE_TEMPLATE/12-model-maker-issue.md @@ -1,6 +1,6 @@ --- name: "Model Maker Issue" -about: Use this template for assistance with using MediaPipe Model Maker to create custom on-device ML solutions. +about: Use this template for assistance with using MediaPipe Model Maker (developers.google.com/mediapipe/solutions) to create custom on-device ML solutions. 
labels: type:support --- diff --git a/.github/ISSUE_TEMPLATE/10-solution-issue.md b/.github/ISSUE_TEMPLATE/13-solution-issue.md similarity index 81% rename from .github/ISSUE_TEMPLATE/10-solution-issue.md rename to .github/ISSUE_TEMPLATE/13-solution-issue.md index a5332cb36..9297edf6b 100644 --- a/.github/ISSUE_TEMPLATE/10-solution-issue.md +++ b/.github/ISSUE_TEMPLATE/13-solution-issue.md @@ -1,6 +1,6 @@ --- -name: "Solution Issue" -about: Use this template for assistance with a specific mediapipe solution, such as "Pose" or "Iris", including inference model usage/training, solution-specific calculators, etc. +name: "Solution (legacy) Issue" +about: Use this template for assistance with a specific Mediapipe solution (google.github.io/mediapipe/solutions), such as "Pose" or "Iris", including inference model usage/training, solution-specific calculators, etc. labels: type:support --- From e9bb51a524bc3c9e38aa7e689020172bea678069 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 9 Dec 2022 19:19:49 -0800 Subject: [PATCH 55/64] Internal change PiperOrigin-RevId: 494314595 --- .../mediapipe/apps/instantmotiontracking/GIFEditText.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java index 10e6422ba..1b733ed82 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java @@ -18,7 +18,7 @@ import android.content.ClipDescription; import android.content.Context; import android.net.Uri; import android.os.Bundle; -import android.support.v7.widget.AppCompatEditText; +import androidx.appcompat.widget.AppCompatEditText; import android.util.AttributeSet; import android.util.Log; import 
android.view.inputmethod.EditorInfo; From 421f789edea501d5fbfd7078d2d9534a628dd886 Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Sat, 10 Dec 2022 12:32:04 -0800 Subject: [PATCH 56/64] Internal change PiperOrigin-RevId: 494420725 --- mediapipe/framework/tool/BUILD | 2 + .../tool/calculator_graph_template.proto | 3 + mediapipe/framework/tool/proto_util_lite.cc | 103 +++++++++---- mediapipe/framework/tool/proto_util_lite.h | 28 +++- mediapipe/framework/tool/template_expander.cc | 136 ++++++++++++------ mediapipe/framework/tool/template_parser.cc | 128 ++++++++++++++++- mediapipe/framework/tool/testdata/BUILD | 10 ++ .../tool/testdata/frozen_generator.proto | 20 +++ 8 files changed, 348 insertions(+), 82 deletions(-) create mode 100644 mediapipe/framework/tool/testdata/frozen_generator.proto diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 453b5a0e8..89cb802da 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -346,6 +346,7 @@ cc_library( "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", "@com_google_absl//absl/strings", ], ) @@ -506,6 +507,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":proto_util_lite", + "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/deps:proto_descriptor_cc_proto", "//mediapipe/framework/port:advanced_proto", "//mediapipe/framework/port:integral_types", diff --git a/mediapipe/framework/tool/calculator_graph_template.proto b/mediapipe/framework/tool/calculator_graph_template.proto index 27153f3f7..31c233812 100644 --- a/mediapipe/framework/tool/calculator_graph_template.proto +++ b/mediapipe/framework/tool/calculator_graph_template.proto @@ -27,6 +27,9 @@ message TemplateExpression { // The FieldDescriptor::Type of the modified field. 
optional mediapipe.FieldDescriptorProto.Type field_type = 5; + // The FieldDescriptor::Type of each map key in the path. + repeated mediapipe.FieldDescriptorProto.Type key_type = 6; + // Alternative value for the modified field, in protobuf binary format. optional string field_value = 7; } diff --git a/mediapipe/framework/tool/proto_util_lite.cc b/mediapipe/framework/tool/proto_util_lite.cc index 4628815ea..a810ce129 100644 --- a/mediapipe/framework/tool/proto_util_lite.cc +++ b/mediapipe/framework/tool/proto_util_lite.cc @@ -22,6 +22,7 @@ #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/statusor.h" #include "mediapipe/framework/tool/field_data.pb.h" #include "mediapipe/framework/type_map.h" @@ -87,12 +88,13 @@ absl::Status ReadPackedValues(WireFormatLite::WireType wire_type, // Extracts the data value(s) for one field from a serialized message. // The message with these field values removed is written to |out|. 
-absl::Status GetFieldValues(uint32 field_id, WireFormatLite::WireType wire_type, - CodedInputStream* in, CodedOutputStream* out, +absl::Status GetFieldValues(uint32 field_id, CodedInputStream* in, + CodedOutputStream* out, std::vector* field_values) { uint32 tag; while ((tag = in->ReadTag()) != 0) { int field_number = WireFormatLite::GetTagFieldNumber(tag); + WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag); if (field_number == field_id) { if (!IsLengthDelimited(wire_type) && IsLengthDelimited(WireFormatLite::GetTagWireType(tag))) { @@ -131,9 +133,7 @@ absl::Status FieldAccess::SetMessage(const std::string& message) { CodedInputStream in(&ais); StringOutputStream sos(&message_); CodedOutputStream out(&sos); - WireFormatLite::WireType wire_type = - WireFormatLite::WireTypeForFieldType(field_type_); - return GetFieldValues(field_id_, wire_type, &in, &out, &field_values_); + return GetFieldValues(field_id_, &in, &out, &field_values_); } void FieldAccess::GetMessage(std::string* result) { @@ -149,18 +149,56 @@ std::vector* FieldAccess::mutable_field_values() { return &field_values_; } +namespace { +using ProtoPathEntry = ProtoUtilLite::ProtoPathEntry; + +// Returns the FieldAccess and index for a field-id or a map-id. +// Returns access to the field-id if the field index is found, +// to the map-id if the map entry is found, and to the field-id otherwise. 
+absl::StatusOr> AccessField( + const ProtoPathEntry& entry, FieldType field_type, + const FieldValue& message) { + FieldAccess result(entry.field_id, field_type); + if (entry.field_id >= 0) { + MP_RETURN_IF_ERROR(result.SetMessage(message)); + if (entry.index < result.mutable_field_values()->size()) { + return std::pair(result, entry.index); + } + } + if (entry.map_id >= 0) { + FieldAccess access(entry.map_id, field_type); + MP_RETURN_IF_ERROR(access.SetMessage(message)); + auto& field_values = *access.mutable_field_values(); + for (int index = 0; index < field_values.size(); ++index) { + FieldAccess key(entry.key_id, entry.key_type); + MP_RETURN_IF_ERROR(key.SetMessage(field_values[index])); + if (key.mutable_field_values()->at(0) == entry.key_value) { + return std::pair(std::move(access), index); + } + } + } + if (entry.field_id >= 0) { + return std::pair(result, entry.index); + } + return absl::InvalidArgumentError(absl::StrCat( + "ProtoPath field missing, field-id: ", entry.field_id, ", map-id: ", + entry.map_id, ", key: ", entry.key_value, " key_type: ", entry.key_type)); +} + +} // namespace + // Replaces a range of field values for one field nested within a protobuf. absl::Status ProtoUtilLite::ReplaceFieldRange( FieldValue* message, ProtoPath proto_path, int length, FieldType field_type, const std::vector& field_values) { - int field_id, index; - std::tie(field_id, index) = proto_path.front(); + ProtoPathEntry entry = proto_path.front(); proto_path.erase(proto_path.begin()); - FieldAccess access(field_id, !proto_path.empty() - ? WireFormatLite::TYPE_MESSAGE - : field_type); - MP_RETURN_IF_ERROR(access.SetMessage(*message)); - std::vector& v = *access.mutable_field_values(); + FieldType type = + !proto_path.empty() ? 
WireFormatLite::TYPE_MESSAGE : field_type; + ASSIGN_OR_RETURN(auto r, AccessField(entry, type, *message)); + FieldAccess& access = r.first; + int index = r.second; + std::vector& v = *access.mutable_field_values(); if (!proto_path.empty()) { RET_CHECK_NO_LOG(index >= 0 && index < v.size()); MP_RETURN_IF_ERROR(ReplaceFieldRange(&v[index], proto_path, length, @@ -180,19 +218,22 @@ absl::Status ProtoUtilLite::ReplaceFieldRange( absl::Status ProtoUtilLite::GetFieldRange( const FieldValue& message, ProtoPath proto_path, int length, FieldType field_type, std::vector* field_values) { - int field_id, index; - std::tie(field_id, index) = proto_path.front(); + ProtoPathEntry entry = proto_path.front(); proto_path.erase(proto_path.begin()); - FieldAccess access(field_id, !proto_path.empty() - ? WireFormatLite::TYPE_MESSAGE - : field_type); - MP_RETURN_IF_ERROR(access.SetMessage(message)); - std::vector& v = *access.mutable_field_values(); + FieldType type = + !proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type; + ASSIGN_OR_RETURN(auto r, AccessField(entry, type, message)); + FieldAccess& access = r.first; + int index = r.second; + std::vector& v = *access.mutable_field_values(); if (!proto_path.empty()) { RET_CHECK_NO_LOG(index >= 0 && index < v.size()); MP_RETURN_IF_ERROR( GetFieldRange(v[index], proto_path, length, field_type, field_values)); } else { + if (length == -1) { + length = v.size() - index; + } RET_CHECK_NO_LOG(index >= 0 && index <= v.size()); RET_CHECK_NO_LOG(index + length >= 0 && index + length <= v.size()); field_values->insert(field_values->begin(), v.begin() + index, @@ -206,19 +247,21 @@ absl::Status ProtoUtilLite::GetFieldCount(const FieldValue& message, ProtoPath proto_path, FieldType field_type, int* field_count) { - int field_id, index; - std::tie(field_id, index) = proto_path.back(); - proto_path.pop_back(); - std::vector parent; - if (proto_path.empty()) { - parent.push_back(std::string(message)); + ProtoPathEntry entry = 
proto_path.front(); + proto_path.erase(proto_path.begin()); + FieldType type = + !proto_path.empty() ? WireFormatLite::TYPE_MESSAGE : field_type; + ASSIGN_OR_RETURN(auto r, AccessField(entry, type, message)); + FieldAccess& access = r.first; + int index = r.second; + std::vector& v = *access.mutable_field_values(); + if (!proto_path.empty()) { + RET_CHECK_NO_LOG(index >= 0 && index < v.size()); + MP_RETURN_IF_ERROR( + GetFieldCount(v[index], proto_path, field_type, field_count)); } else { - MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldRange( - message, proto_path, 1, WireFormatLite::TYPE_MESSAGE, &parent)); + *field_count = v.size(); } - FieldAccess access(field_id, field_type); - MP_RETURN_IF_ERROR(access.SetMessage(parent[0])); - *field_count = access.mutable_field_values()->size(); return absl::OkStatus(); } diff --git a/mediapipe/framework/tool/proto_util_lite.h b/mediapipe/framework/tool/proto_util_lite.h index 7d3a263f3..d71ceac83 100644 --- a/mediapipe/framework/tool/proto_util_lite.h +++ b/mediapipe/framework/tool/proto_util_lite.h @@ -34,15 +34,31 @@ class ProtoUtilLite { // Defines field types and tag formats. using WireFormatLite = proto_ns::internal::WireFormatLite; - // Defines a sequence of nested field-number field-index pairs. - using ProtoPath = std::vector>; - // The serialized value for a protobuf field. using FieldValue = std::string; // The serialized data type for a protobuf field. using FieldType = WireFormatLite::FieldType; + // A field-id and index, or a map-id and key, or both. + struct ProtoPathEntry { + ProtoPathEntry(int id, int index) : field_id(id), index(index) {} + ProtoPathEntry(int id, int key_id, FieldType key_type, FieldValue key_value) + : map_id(id), + key_id(key_id), + key_type(key_type), + key_value(std::move(key_value)) {} + int field_id = -1; + int index = -1; + int map_id = -1; + int key_id = -1; + FieldType key_type; + FieldValue key_value; + }; + + // Defines a sequence of nested field-number field-index pairs. 
+ using ProtoPath = std::vector; + class FieldAccess { public: // Provides access to a certain protobuf field. @@ -57,9 +73,11 @@ class ProtoUtilLite { // Returns the serialized values of the protobuf field. std::vector* mutable_field_values(); + uint32 field_id() const { return field_id_; } + private: - const uint32 field_id_; - const FieldType field_type_; + uint32 field_id_; + FieldType field_type_; std::string message_; std::vector field_values_; }; diff --git a/mediapipe/framework/tool/template_expander.cc b/mediapipe/framework/tool/template_expander.cc index 034e1a026..a91ea5adc 100644 --- a/mediapipe/framework/tool/template_expander.cc +++ b/mediapipe/framework/tool/template_expander.cc @@ -22,6 +22,7 @@ #include #include "absl/strings/ascii.h" +#include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" @@ -44,6 +45,7 @@ using WireFormatLite = ProtoUtilLite::WireFormatLite; using FieldValue = ProtoUtilLite::FieldValue; using FieldType = ProtoUtilLite::FieldType; using ProtoPath = ProtoUtilLite::ProtoPath; +using ProtoPathEntry = ProtoUtilLite::ProtoPathEntry; namespace { @@ -84,26 +86,87 @@ std::unique_ptr CloneMessage(const MessageLite& message) { return result; } -// Returns the (tag, index) pairs in a field path. -// For example, returns {{1, 1}, {2, 1}, {3, 1}} for path "/1[1]/2[1]/3[1]". -absl::Status ProtoPathSplit(const std::string& path, ProtoPath* result) { - absl::Status status; - std::vector ids = absl::StrSplit(path, '/'); - for (const std::string& id : ids) { - if (id.length() > 0) { - std::pair id_pair = - absl::StrSplit(id, absl::ByAnyChar("[]")); - int tag = 0; - int index = 0; - bool ok = absl::SimpleAtoi(id_pair.first, &tag) && - absl::SimpleAtoi(id_pair.second, &index); - if (!ok) { - status.Update(absl::InvalidArgumentError(path)); - } - result->push_back(std::make_pair(tag, index)); +// Parses one ProtoPathEntry. 
+// The parsed entry is appended to `result` and removed from `path`. +// ProtoPathEntry::key_value stores map key text. Use SetMapKeyTypes +// to serialize the key text to protobuf wire format. +absl::Status ParseEntry(absl::string_view& path, ProtoPath* result) { + bool ok = true; + int sb = path.find('['); + int eb = path.find(']'); + int field_id = -1; + ok &= absl::SimpleAtoi(path.substr(0, sb), &field_id); + auto selector = path.substr(sb + 1, eb - 1 - sb); + if (absl::StartsWith(selector, "@")) { + int eq = selector.find('='); + int key_id = -1; + ok &= absl::SimpleAtoi(selector.substr(1, eq - 1), &key_id); + auto key_text = selector.substr(eq + 1); + FieldType key_type = FieldType::TYPE_STRING; + result->push_back({field_id, key_id, key_type, std::string(key_text)}); + } else { + int index = 0; + ok &= absl::SimpleAtoi(selector, &index); + result->push_back({field_id, index}); + } + int end = path.find('/', eb); + if (end == std::string::npos) { + path = ""; + } else { + path = path.substr(end + 1); + } + return ok ? absl::OkStatus() + : absl::InvalidArgumentError( + absl::StrCat("Failed to parse ProtoPath entry: ", path)); +} + +// Specifies the FieldTypes for protobuf map keys in a ProtoPath. +// Each ProtoPathEntry::key_value is converted from text to the protobuf +// wire format for its key type. +absl::Status SetMapKeyTypes(const std::vector& key_types, + ProtoPath* result) { + int i = 0; + for (ProtoPathEntry& entry : *result) { + if (entry.map_id >= 0) { + FieldType key_type = key_types[i++]; + std::vector key_value; + MP_RETURN_IF_ERROR( + ProtoUtilLite::Serialize({entry.key_value}, key_type, &key_value)); + entry.key_type = key_type; + entry.key_value = key_value.front(); } } - return status; + return absl::OkStatus(); +} + +// Returns the (tag, index) pairs in a field path. +// For example, returns {{1, 1}, {2, 1}, {3, 1}} for "/1[1]/2[1]/3[1]", +// returns {{1, 1}, {2, 1, "INPUT_FRAMES"}} for "/1[1]/2[@1=INPUT_FRAMES]". 
+absl::Status ProtoPathSplit(const std::string& path, ProtoPath* result) { + result->clear(); + absl::string_view rest = path; + if (absl::StartsWith(rest, "/")) { + rest = rest.substr(1); + } + while (!rest.empty()) { + MP_RETURN_IF_ERROR(ParseEntry(rest, result)); + } + return absl::OkStatus(); +} + +// Parse the TemplateExpression.path field into a ProtoPath struct. +absl::Status ParseProtoPath(const TemplateExpression& rule, + std::string base_path, ProtoPath* result) { + ProtoPath base_entries; + MP_RETURN_IF_ERROR(ProtoPathSplit(base_path, &base_entries)); + MP_RETURN_IF_ERROR(ProtoPathSplit(rule.path(), result)); + std::vector key_types; + for (int type : rule.key_type()) { + key_types.push_back(static_cast(type)); + } + MP_RETURN_IF_ERROR(SetMapKeyTypes(key_types, result)); + result->erase(result->begin(), result->begin() + base_entries.size()); + return absl::OkStatus(); } // Returns true if one proto path is prefix by another. @@ -111,13 +174,6 @@ bool ProtoPathStartsWith(const std::string& path, const std::string& prefix) { return absl::StartsWith(path, prefix); } -// Returns the part of one proto path after a prefix proto path. -std::string ProtoPathRelative(const std::string& field_path, - const std::string& base_path) { - CHECK(ProtoPathStartsWith(field_path, base_path)); - return field_path.substr(base_path.length()); -} - // Returns the target ProtoUtilLite::FieldType of a rule. FieldType GetFieldType(const TemplateExpression& rule) { return static_cast(rule.field_type()); @@ -126,19 +182,10 @@ FieldType GetFieldType(const TemplateExpression& rule) { // Returns the count of field values at a ProtoPath. 
int FieldCount(const FieldValue& base, ProtoPath field_path, FieldType field_type) { - int field_id, index; - std::tie(field_id, index) = field_path.back(); - field_path.pop_back(); - std::vector parent; - if (field_path.empty()) { - parent.push_back(base); - } else { - MEDIAPIPE_CHECK_OK(ProtoUtilLite::GetFieldRange( - base, field_path, 1, WireFormatLite::TYPE_MESSAGE, &parent)); - } - ProtoUtilLite::FieldAccess access(field_id, field_type); - MEDIAPIPE_CHECK_OK(access.SetMessage(parent[0])); - return access.mutable_field_values()->size(); + int result = 0; + CHECK( + ProtoUtilLite::GetFieldCount(base, field_path, field_type, &result).ok()); + return result; } } // namespace @@ -229,9 +276,7 @@ class TemplateExpanderImpl { return absl::OkStatus(); } ProtoPath field_path; - absl::Status status = - ProtoPathSplit(ProtoPathRelative(rule.path(), base_path), &field_path); - if (!status.ok()) return status; + MP_RETURN_IF_ERROR(ParseProtoPath(rule, base_path, &field_path)); return ProtoUtilLite::GetFieldRange(output, field_path, 1, GetFieldType(rule), base); } @@ -242,12 +287,13 @@ class TemplateExpanderImpl { const std::vector& field_values, FieldValue* output) { if (!rule.has_path()) { - *output = field_values[0]; + if (!field_values.empty()) { + *output = field_values[0]; + } return absl::OkStatus(); } ProtoPath field_path; - RET_CHECK_OK( - ProtoPathSplit(ProtoPathRelative(rule.path(), base_path), &field_path)); + MP_RETURN_IF_ERROR(ParseProtoPath(rule, base_path, &field_path)); int field_count = 1; if (rule.has_field_value()) { // For a non-repeated field, only one value can be specified. @@ -257,7 +303,7 @@ class TemplateExpanderImpl { "Multiple values specified for non-repeated field: ", rule.path())); } // For a non-repeated field, the field value is stored only in the rule. 
- field_path[field_path.size() - 1].second = 0; + field_path[field_path.size() - 1].index = 0; field_count = 0; } return ProtoUtilLite::ReplaceFieldRange(output, field_path, field_count, diff --git a/mediapipe/framework/tool/template_parser.cc b/mediapipe/framework/tool/template_parser.cc index 1d81e7a78..5a0ceccd3 100644 --- a/mediapipe/framework/tool/template_parser.cc +++ b/mediapipe/framework/tool/template_parser.cc @@ -26,6 +26,7 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/deps/proto_descriptor.pb.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/integral_types.h" @@ -45,6 +46,9 @@ using mediapipe::proto_ns::Message; using mediapipe::proto_ns::OneofDescriptor; using mediapipe::proto_ns::Reflection; using mediapipe::proto_ns::TextFormat; +using ProtoPath = mediapipe::tool::ProtoUtilLite::ProtoPath; +using FieldType = mediapipe::tool::ProtoUtilLite::FieldType; +using FieldValue = mediapipe::tool::ProtoUtilLite::FieldValue; namespace mediapipe { @@ -1357,7 +1361,7 @@ absl::Status ProtoPathSplit(const std::string& path, if (!ok) { status.Update(absl::InvalidArgumentError(path)); } - result->push_back(std::make_pair(tag, index)); + result->push_back({tag, index}); } } return status; @@ -1381,7 +1385,7 @@ void StowFieldValue(Message* message, TemplateExpression* expression) { const Descriptor* descriptor = message->GetDescriptor(); ProtoUtilLite::ProtoPath path; MEDIAPIPE_CHECK_OK(ProtoPathSplit(expression->path(), &path)); - int field_number = path[path.size() - 1].first; + int field_number = path[path.size() - 1].field_id; const FieldDescriptor* field = descriptor->FindFieldByNumber(field_number); if (!field->is_repeated()) { std::vector field_values; @@ -1402,6 +1406,124 @@ static void StripQuotes(std::string* str) { } } +// Returns the field or extension for field number. 
+const FieldDescriptor* FindFieldByNumber(const Message* message, + int field_num) { + const FieldDescriptor* result = + message->GetDescriptor()->FindFieldByNumber(field_num); + if (result == nullptr) { + result = message->GetReflection()->FindKnownExtensionByNumber(field_num); + } + return result; +} + +// Returns the message value from a field at an index. +const Message* GetFieldMessage(const Message& message, + const FieldDescriptor* field, int index) { + if (field->type() != FieldDescriptor::TYPE_MESSAGE) { + return nullptr; + } + if (!field->is_repeated()) { + return &message.GetReflection()->GetMessage(message, field); + } + if (index < message.GetReflection()->FieldSize(message, field)) { + return &message.GetReflection()->GetRepeatedMessage(message, field, index); + } + return nullptr; +} + +// Serialize a ProtoPath as a readable string. +// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]", +// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]". +std::string ProtoPathJoin(ProtoPath path) { + std::string result; + for (ProtoUtilLite::ProtoPathEntry& e : path) { + if (e.field_id >= 0) { + absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]"); + } else if (e.map_id >= 0) { + absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value, + "]"); + } + } + return result; +} + +// Returns the protobuf map key types from a ProtoPath. +std::vector ProtoPathKeyTypes(ProtoPath path) { + std::vector result; + for (auto& entry : path) { + if (entry.map_id >= 0) { + result.push_back(entry.key_type); + } + } + return result; +} + +// Returns the text value for a string or numeric protobuf map key. 
+std::string GetMapKey(const Message& map_entry) { + auto key_field = map_entry.GetDescriptor()->FindFieldByName("key"); + auto reflection = map_entry.GetReflection(); + if (key_field->type() == FieldDescriptor::TYPE_STRING) { + return reflection->GetString(map_entry, key_field); + } else if (key_field->type() == FieldDescriptor::TYPE_INT32) { + return absl::StrCat(reflection->GetInt32(map_entry, key_field)); + } else if (key_field->type() == FieldDescriptor::TYPE_INT64) { + return absl::StrCat(reflection->GetInt64(map_entry, key_field)); + } + return ""; +} + +// Adjusts map-entries from indexes to keys. +// Protobuf map-entry order is intentionally not preserved. +mediapipe::Status KeyProtoMapEntries(Message* source) { + // Copy the rules from the source CalculatorGraphTemplate. + mediapipe::CalculatorGraphTemplate rules; + rules.ParsePartialFromString(source->SerializePartialAsString()); + // Only the "source" Message knows all extension types. + Message* config_0 = source->GetReflection()->MutableMessage( + source, source->GetDescriptor()->FindFieldByName("config"), nullptr); + for (int i = 0; i < rules.rule().size(); ++i) { + TemplateExpression* rule = rules.mutable_rule()->Mutable(i); + const Message* message = config_0; + ProtoPath path; + MP_RETURN_IF_ERROR(ProtoPathSplit(rule->path(), &path)); + for (int j = 0; j < path.size(); ++j) { + int field_id = path[j].field_id; + int field_index = path[j].index; + const FieldDescriptor* field = FindFieldByNumber(message, field_id); + if (field->is_map()) { + const Message* map_entry = + GetFieldMessage(*message, field, path[j].index); + int key_id = + map_entry->GetDescriptor()->FindFieldByName("key")->number(); + FieldType key_type = static_cast( + map_entry->GetDescriptor()->FindFieldByName("key")->type()); + std::string key_value = GetMapKey(*map_entry); + path[j] = {field_id, key_id, key_type, key_value}; + } + message = GetFieldMessage(*message, field, field_index); + if (!message) { + break; + } + } + if 
(!rule->path().empty()) { + *rule->mutable_path() = ProtoPathJoin(path); + for (FieldType key_type : ProtoPathKeyTypes(path)) { + *rule->mutable_key_type()->Add() = key_type; + } + } + } + // Copy the rules back into the source CalculatorGraphTemplate. + auto source_rules = + source->GetReflection()->GetMutableRepeatedFieldRef( + source, source->GetDescriptor()->FindFieldByName("rule")); + source_rules.Clear(); + for (auto& rule : rules.rule()) { + source_rules.Add(rule); + } + return absl::OkStatus(); +} + } // namespace class TemplateParser::Parser::MediaPipeParserImpl @@ -1416,6 +1538,8 @@ class TemplateParser::Parser::MediaPipeParserImpl // Copy the template rules into the output template "rule" field. success &= MergeFields(template_rules_, output).ok(); + // Replace map-entry indexes with map keys. + success &= KeyProtoMapEntries(output).ok(); return success; } diff --git a/mediapipe/framework/tool/testdata/BUILD b/mediapipe/framework/tool/testdata/BUILD index f9aab7b72..8300181b5 100644 --- a/mediapipe/framework/tool/testdata/BUILD +++ b/mediapipe/framework/tool/testdata/BUILD @@ -17,6 +17,7 @@ load( "//mediapipe/framework/tool:mediapipe_graph.bzl", "mediapipe_simple_subgraph", ) +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) @@ -58,3 +59,12 @@ mediapipe_simple_subgraph( "//mediapipe/framework:test_calculators", ], ) + +mediapipe_proto_library( + name = "frozen_generator_proto", + srcs = ["frozen_generator.proto"], + visibility = ["//mediapipe/framework:__subpackages__"], + deps = [ + "//mediapipe/framework:packet_generator_proto", + ], +) diff --git a/mediapipe/framework/tool/testdata/frozen_generator.proto b/mediapipe/framework/tool/testdata/frozen_generator.proto new file mode 100644 index 000000000..5f133f461 --- /dev/null +++ b/mediapipe/framework/tool/testdata/frozen_generator.proto @@ -0,0 +1,20 @@ +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/packet_generator.proto"; + 
+message FrozenGeneratorOptions { + extend mediapipe.PacketGeneratorOptions { + optional FrozenGeneratorOptions ext = 225748738; + } + + // Path to file containing serialized proto of type tensorflow::GraphDef. + optional string graph_proto_path = 1; + + // This map defines which streams are fed to which tensors in the model. + map tag_to_tensor_names = 2; + + // Graph nodes to run to initialize the model. + repeated string initialization_op_names = 4; +} From 37d2e369605e87cf741220db1d3c1b4afb403def Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 12 Dec 2022 12:08:45 -0800 Subject: [PATCH 57/64] Internal change PiperOrigin-RevId: 494791433 --- .github/ISSUE_TEMPLATE/14-studio-issue.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/14-studio-issue.md diff --git a/.github/ISSUE_TEMPLATE/14-studio-issue.md b/.github/ISSUE_TEMPLATE/14-studio-issue.md new file mode 100644 index 000000000..5942b1eb1 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/14-studio-issue.md @@ -0,0 +1,19 @@ +--- +name: "Studio Issue" +about: Use this template for assistance with the MediaPipe Studio application. +labels: type:support + +--- +Please make sure that this is a MediaPipe Studio issue. + +**System information** (Please provide as much relevant information as possible) +- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, Android 11, iOS 14.4): +- Browser and Version +- Any microphone or camera hardware +- URL that shows the problem + +**Describe the expected behavior:** + +**Other info / Complete Logs :** +Include any js console logs that would be helpful to diagnose the problem. +Large logs and files should be attached: From 3f66dde8fdb459be8552b837e83fb2a79c44566c Mon Sep 17 00:00:00 2001 From: Mark McDonald Date: Mon, 12 Dec 2022 17:33:08 -0800 Subject: [PATCH 58/64] Change `--site_path` default value to match the actual path.
This did not match the URL we ended up using for MediaPipe, so needs to be set correctly in order to generate docs that match the real site. This change sets the default to be correct. PiperOrigin-RevId: 494874789 --- docs/build_py_api_docs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/build_py_api_docs.py b/docs/build_py_api_docs.py index 46546012d..02eb04074 100644 --- a/docs/build_py_api_docs.py +++ b/docs/build_py_api_docs.py @@ -44,14 +44,14 @@ _OUTPUT_DIR = flags.DEFINE_string( _URL_PREFIX = flags.DEFINE_string( 'code_url_prefix', - 'https://github.com/google/mediapipe/tree/master/mediapipe', + 'https://github.com/google/mediapipe/blob/master/mediapipe', 'The url prefix for links to code.') _SEARCH_HINTS = flags.DEFINE_bool( 'search_hints', True, 'Include metadata search hints in the generated files') -_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python', +_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api/solutions/python', 'Path prefix in the _toc.yaml') From fb2179761187f5a0c73c973d94690685170d9a21 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 12 Dec 2022 21:28:35 -0800 Subject: [PATCH 59/64] Internal change PiperOrigin-RevId: 494914168 --- mediapipe/calculators/image/image_cropping_calculator.cc | 3 ++- mediapipe/calculators/image/image_cropping_calculator_test.cc | 4 ++-- mediapipe/calculators/util/detections_to_rects_calculator.cc | 3 +++ .../calculators/util/detections_to_rects_calculator_test.cc | 3 +++ mediapipe/calculators/util/landmark_projection_calculator.cc | 2 ++ mediapipe/calculators/util/landmarks_smoothing_calculator.cc | 2 ++ mediapipe/calculators/util/rect_projection_calculator.cc | 2 ++ mediapipe/calculators/util/rect_to_render_data_calculator.cc | 3 +++ mediapipe/calculators/util/rect_to_render_scale_calculator.cc | 2 ++ mediapipe/calculators/util/rect_transformation_calculator.cc | 3 +++ .../calculators/util/world_landmark_projection_calculator.cc | 2 ++ 
.../calculators/video/tracked_detection_manager_calculator.cc | 2 ++ .../calculators/hand_landmarks_to_rect_calculator.cc | 2 ++ .../holistic_landmark/calculators/roi_tracking_calculator.cc | 2 ++ .../calculators/frame_annotation_to_rect_calculator.cc | 2 ++ .../cc/components/processors/image_preprocessing_graph.cc | 1 + .../calculators/landmarks_to_matrix_calculator.cc | 2 ++ .../calculators/landmarks_to_matrix_calculator_test.cc | 2 ++ .../tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc | 2 ++ .../cc/vision/gesture_recognizer/gesture_recognizer_graph.cc | 1 + .../gesture_recognizer/hand_gesture_recognizer_graph.cc | 1 + .../tasks/cc/vision/hand_detector/hand_detector_graph.cc | 1 + .../tasks/cc/vision/hand_detector/hand_detector_graph_test.cc | 1 + .../calculators/hand_association_calculator.cc | 2 ++ .../calculators/hand_association_calculator_test.cc | 2 ++ .../calculators/hand_landmarks_deduplication_calculator.cc | 1 + mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc | 2 ++ .../tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc | 1 + .../cc/vision/hand_landmarker/hand_landmarker_graph_test.cc | 1 + .../vision/hand_landmarker/hand_landmarks_detector_graph.cc | 1 + .../hand_landmarker/hand_landmarks_detector_graph_test.cc | 1 + .../tasks/cc/vision/image_classifier/image_classifier.cc | 1 + .../cc/vision/image_classifier/image_classifier_graph.cc | 1 + mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc | 1 + .../tasks/cc/vision/image_embedder/image_embedder_graph.cc | 1 + mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc | 1 + .../tasks/cc/vision/image_segmenter/image_segmenter_graph.cc | 1 + mediapipe/tasks/cc/vision/object_detector/object_detector.cc | 1 + .../tasks/cc/vision/object_detector/object_detector_graph.cc | 1 + mediapipe/util/rectangle_util_test.cc | 1 + mediapipe/util/tracking/tracked_detection.cc | 2 ++ mediapipe/util/tracking/tracked_detection_manager.cc | 1 + 
mediapipe/util/tracking/tracked_detection_test.cc | 2 ++ 43 files changed, 70 insertions(+), 3 deletions(-) diff --git a/mediapipe/calculators/image/image_cropping_calculator.cc b/mediapipe/calculators/image/image_cropping_calculator.cc index 8c9305ffb..1a2b2e5b0 100644 --- a/mediapipe/calculators/image/image_cropping_calculator.cc +++ b/mediapipe/calculators/image/image_cropping_calculator.cc @@ -37,7 +37,8 @@ enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; namespace mediapipe { namespace { - +using ::mediapipe::NormalizedRect; +using ::mediapipe::Rect; #if !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU diff --git a/mediapipe/calculators/image/image_cropping_calculator_test.cc b/mediapipe/calculators/image/image_cropping_calculator_test.cc index b3f692889..3c565282b 100644 --- a/mediapipe/calculators/image/image_cropping_calculator_test.cc +++ b/mediapipe/calculators/image/image_cropping_calculator_test.cc @@ -195,11 +195,11 @@ TEST(ImageCroppingCalculatorTest, RedundantSpecWithInputStream) { auto cc = absl::make_unique( calculator_state.get(), inputTags, tool::CreateTagMap({}).value()); auto& inputs = cc->Inputs(); - mediapipe::Rect rect = ParseTextProtoOrDie( + Rect rect = ParseTextProtoOrDie( R"pb( width: 1 height: 1 x_center: 40 y_center: 40 rotation: 0.5 )pb"); - inputs.Tag(kRectTag).Value() = MakePacket(rect); + inputs.Tag(kRectTag).Value() = MakePacket(rect); RectSpec expectRect = { .width = 1, .height = 1, diff --git a/mediapipe/calculators/util/detections_to_rects_calculator.cc b/mediapipe/calculators/util/detections_to_rects_calculator.cc index 73a67d322..3e566836c 100644 --- a/mediapipe/calculators/util/detections_to_rects_calculator.cc +++ b/mediapipe/calculators/util/detections_to_rects_calculator.cc @@ -37,6 +37,9 @@ constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kRectsTag[] = "RECTS"; constexpr char kNormRectsTag[] = "NORM_RECTS"; +using ::mediapipe::NormalizedRect; +using ::mediapipe::Rect; + constexpr float 
kMinFloat = std::numeric_limits::lowest(); constexpr float kMaxFloat = std::numeric_limits::max(); diff --git a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc index 6caf792a7..63de60a60 100644 --- a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc +++ b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc @@ -39,6 +39,9 @@ constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kRectTag[] = "RECT"; constexpr char kDetectionTag[] = "DETECTION"; +using ::mediapipe::NormalizedRect; +using ::mediapipe::Rect; + MATCHER_P4(RectEq, x_center, y_center, width, height, "") { return testing::Value(arg.x_center(), testing::Eq(x_center)) && testing::Value(arg.y_center(), testing::Eq(y_center)) && diff --git a/mediapipe/calculators/util/landmark_projection_calculator.cc b/mediapipe/calculators/util/landmark_projection_calculator.cc index e27edea66..9f276da56 100644 --- a/mediapipe/calculators/util/landmark_projection_calculator.cc +++ b/mediapipe/calculators/util/landmark_projection_calculator.cc @@ -24,6 +24,8 @@ namespace mediapipe { +using ::mediapipe::NormalizedRect; + namespace { constexpr char kLandmarksTag[] = "NORM_LANDMARKS"; diff --git a/mediapipe/calculators/util/landmarks_smoothing_calculator.cc b/mediapipe/calculators/util/landmarks_smoothing_calculator.cc index 6673816e7..7a92cfb7e 100644 --- a/mediapipe/calculators/util/landmarks_smoothing_calculator.cc +++ b/mediapipe/calculators/util/landmarks_smoothing_calculator.cc @@ -35,7 +35,9 @@ constexpr char kObjectScaleRoiTag[] = "OBJECT_SCALE_ROI"; constexpr char kNormalizedFilteredLandmarksTag[] = "NORM_FILTERED_LANDMARKS"; constexpr char kFilteredLandmarksTag[] = "FILTERED_LANDMARKS"; +using ::mediapipe::NormalizedRect; using mediapipe::OneEuroFilter; +using ::mediapipe::Rect; using mediapipe::RelativeVelocityFilter; void NormalizedLandmarksToLandmarks( diff --git 
a/mediapipe/calculators/util/rect_projection_calculator.cc b/mediapipe/calculators/util/rect_projection_calculator.cc index dcc6e7391..69b28af87 100644 --- a/mediapipe/calculators/util/rect_projection_calculator.cc +++ b/mediapipe/calculators/util/rect_projection_calculator.cc @@ -23,6 +23,8 @@ namespace { constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kNormReferenceRectTag[] = "NORM_REFERENCE_RECT"; +using ::mediapipe::NormalizedRect; + } // namespace // Projects rectangle from reference coordinate system (defined by reference diff --git a/mediapipe/calculators/util/rect_to_render_data_calculator.cc b/mediapipe/calculators/util/rect_to_render_data_calculator.cc index 400be277d..bbc08255e 100644 --- a/mediapipe/calculators/util/rect_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/rect_to_render_data_calculator.cc @@ -29,6 +29,9 @@ constexpr char kNormRectsTag[] = "NORM_RECTS"; constexpr char kRectsTag[] = "RECTS"; constexpr char kRenderDataTag[] = "RENDER_DATA"; +using ::mediapipe::NormalizedRect; +using ::mediapipe::Rect; + RenderAnnotation::Rectangle* NewRect( const RectToRenderDataCalculatorOptions& options, RenderData* render_data) { auto* annotation = render_data->add_render_annotations(); diff --git a/mediapipe/calculators/util/rect_to_render_scale_calculator.cc b/mediapipe/calculators/util/rect_to_render_scale_calculator.cc index d94615228..6ff6b3d51 100644 --- a/mediapipe/calculators/util/rect_to_render_scale_calculator.cc +++ b/mediapipe/calculators/util/rect_to_render_scale_calculator.cc @@ -24,6 +24,8 @@ constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kRenderScaleTag[] = "RENDER_SCALE"; +using ::mediapipe::NormalizedRect; + } // namespace // A calculator to get scale for RenderData primitives. 
diff --git a/mediapipe/calculators/util/rect_transformation_calculator.cc b/mediapipe/calculators/util/rect_transformation_calculator.cc index 15bb26826..4783cb919 100644 --- a/mediapipe/calculators/util/rect_transformation_calculator.cc +++ b/mediapipe/calculators/util/rect_transformation_calculator.cc @@ -28,6 +28,9 @@ constexpr char kRectTag[] = "RECT"; constexpr char kRectsTag[] = "RECTS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +using ::mediapipe::NormalizedRect; +using ::mediapipe::Rect; + // Wraps around an angle in radians to within -M_PI and M_PI. inline float NormalizeRadians(float angle) { return angle - 2 * M_PI * std::floor((angle - (-M_PI)) / (2 * M_PI)); diff --git a/mediapipe/calculators/util/world_landmark_projection_calculator.cc b/mediapipe/calculators/util/world_landmark_projection_calculator.cc index bcd7352a2..e843d63bf 100644 --- a/mediapipe/calculators/util/world_landmark_projection_calculator.cc +++ b/mediapipe/calculators/util/world_landmark_projection_calculator.cc @@ -22,6 +22,8 @@ namespace mediapipe { +using ::mediapipe::NormalizedRect; + namespace { constexpr char kLandmarksTag[] = "LANDMARKS"; diff --git a/mediapipe/calculators/video/tracked_detection_manager_calculator.cc b/mediapipe/calculators/video/tracked_detection_manager_calculator.cc index c416fa9b0..48664fead 100644 --- a/mediapipe/calculators/video/tracked_detection_manager_calculator.cc +++ b/mediapipe/calculators/video/tracked_detection_manager_calculator.cc @@ -32,6 +32,8 @@ namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; + constexpr int kDetectionUpdateTimeOutMS = 5000; constexpr char kDetectionsTag[] = "DETECTIONS"; constexpr char kDetectionBoxesTag[] = "DETECTION_BOXES"; diff --git a/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc b/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc index 6f2c49d64..638678ff5 100644 --- 
a/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc +++ b/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc @@ -22,6 +22,8 @@ namespace mediapipe { +using ::mediapipe::NormalizedRect; + namespace { // NORM_LANDMARKS is either the full set of landmarks for the hand, or diff --git a/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc b/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc index 0da6cd7f7..49c7b93fb 100644 --- a/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc +++ b/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc @@ -34,6 +34,8 @@ constexpr char kRecropRectTag[] = "RECROP_RECT"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kTrackingRectTag[] = "TRACKING_RECT"; +using ::mediapipe::NormalizedRect; + // TODO: Use rect rotation. // Verifies that Intersection over Union of previous frame rect and current // frame re-crop rect is less than threshold. diff --git a/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc b/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc index 476f8cb54..1fe919c54 100644 --- a/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc +++ b/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc @@ -34,6 +34,8 @@ namespace { constexpr char kInputFrameAnnotationTag[] = "FRAME_ANNOTATION"; constexpr char kOutputNormRectsTag[] = "NORM_RECTS"; +using ::mediapipe::NormalizedRect; + } // namespace // A calculator that converts FrameAnnotation proto to NormalizedRect. 
diff --git a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc index b24b7f0cb..fefc1ec52 100644 --- a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc @@ -45,6 +45,7 @@ namespace components { namespace processors { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::Tensor; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc index 277bb170a..088f97c29 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc @@ -35,6 +35,8 @@ limitations under the License. 
namespace mediapipe { namespace api2 { +using ::mediapipe::NormalizedRect; + namespace { constexpr char kLandmarksTag[] = "LANDMARKS"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc index fe6f1162b..a1a44c8d1 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc @@ -33,6 +33,8 @@ namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; + constexpr char kLandmarksTag[] = "LANDMARKS"; constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc index e7fcf6fd9..01f444742 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc @@ -57,6 +57,8 @@ namespace { using GestureRecognizerGraphOptionsProto = ::mediapipe::tasks::vision:: gesture_recognizer::proto::GestureRecognizerGraphOptions; +using ::mediapipe::NormalizedRect; + constexpr char kHandGestureSubgraphTypeName[] = "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph"; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc index 47d95100b..2d949c410 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc @@ -46,6 +46,7 @@ namespace gesture_recognizer { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using 
::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc index d7e983d81..4db57e85b 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc @@ -52,6 +52,7 @@ namespace gesture_recognizer { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc index c24548c9b..49958e36b 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc @@ -50,6 +50,7 @@ namespace hand_detector { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc index cbbc0e193..f4e5f8c7d 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc @@ -53,6 +53,7 @@ namespace { using ::file::Defaults; using ::file::GetTextProto; +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc index b6df80588..dffdbdd38 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc +++ 
b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc @@ -27,6 +27,8 @@ limitations under the License. namespace mediapipe::api2 { +using ::mediapipe::NormalizedRect; + // HandAssociationCalculator accepts multiple inputs of vectors of // NormalizedRect. The output is a vector of NormalizedRect that contains // rects from the input vectors that don't overlap with each other. When two diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc index cb3130854..138164209 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc @@ -26,6 +26,8 @@ limitations under the License. namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; + class HandAssociationCalculatorTest : public testing::Test { protected: HandAssociationCalculatorTest() { diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc index 266ce223f..d875de98f 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc @@ -41,6 +41,7 @@ limitations under the License. 
namespace mediapipe::api2 { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Source; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc index 2b818b2e5..3bb1ee8d8 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc @@ -46,6 +46,8 @@ namespace { using HandLandmarkerGraphOptionsProto = ::mediapipe::tasks::vision:: hand_landmarker::proto::HandLandmarkerGraphOptions; +using ::mediapipe::NormalizedRect; + constexpr char kHandLandmarkerGraphTypeName[] = "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph"; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc index 2c4133eb1..05ad97efe 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc @@ -49,6 +49,7 @@ namespace hand_landmarker { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc index f275486f5..c28df2c05 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc @@ -54,6 +54,7 @@ namespace { using ::file::Defaults; using ::file::GetTextProto; +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc 
b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc index 014830ba2..4ea066aab 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc @@ -53,6 +53,7 @@ namespace hand_landmarker { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc index d1e928ce7..f28907d2f 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc @@ -50,6 +50,7 @@ namespace { using ::file::Defaults; using ::file::GetTextProto; +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc index 60f8f7ed4..763e0a320 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc @@ -58,6 +58,7 @@ constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using ::mediapipe::NormalizedRect; using ::mediapipe::tasks::components::containers::ConvertToClassificationResult; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::PacketMap; diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 2d0379c66..0adcf842d 100644 --- 
a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -38,6 +38,7 @@ namespace image_classifier { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::GenericNode; diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc index e3198090f..494b075a7 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc @@ -54,6 +54,7 @@ constexpr char kGraphTypeName[] = "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using ::mediapipe::NormalizedRect; using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::core::PacketMap; diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc index 81ccb5361..95c4ff379 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc @@ -34,6 +34,7 @@ namespace image_embedder { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::GenericNode; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index bbee714c6..7130c72e2 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -44,6 +44,7 @@ constexpr int kMicroSecondsPerMilliSecond = 1000; using ::mediapipe::CalculatorGraphConfig; using 
::mediapipe::Image; +using ::mediapipe::NormalizedRect; using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: image_segmenter::proto::ImageSegmenterGraphOptions; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index 5531968c1..923cf2937 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -49,6 +49,7 @@ namespace image_segmenter { namespace { using ::mediapipe::Image; +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc index e0222dd70..2477f8a44 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc @@ -57,6 +57,7 @@ constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.ObjectDetectorGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using ::mediapipe::NormalizedRect; using ::mediapipe::tasks::components::containers::ConvertToDetectionResult; using ObjectDetectorOptionsProto = object_detector::proto::ObjectDetectorOptions; diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc index fd95bb1ac..e5af7544d 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc @@ -52,6 +52,7 @@ namespace vision { namespace { +using ::mediapipe::NormalizedRect; using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; diff --git 
a/mediapipe/util/rectangle_util_test.cc b/mediapipe/util/rectangle_util_test.cc index cd1946d45..3bc323f9f 100644 --- a/mediapipe/util/rectangle_util_test.cc +++ b/mediapipe/util/rectangle_util_test.cc @@ -20,6 +20,7 @@ namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; using ::testing::FloatNear; class RectangleUtilTest : public testing::Test { diff --git a/mediapipe/util/tracking/tracked_detection.cc b/mediapipe/util/tracking/tracked_detection.cc index 130a87640..80a6981a8 100644 --- a/mediapipe/util/tracking/tracked_detection.cc +++ b/mediapipe/util/tracking/tracked_detection.cc @@ -20,6 +20,8 @@ namespace mediapipe { namespace { +using ::mediapipe::NormalizedRect; + // Struct for carrying boundary information. struct NormalizedRectBounds { float left, right, top, bottom; diff --git a/mediapipe/util/tracking/tracked_detection_manager.cc b/mediapipe/util/tracking/tracked_detection_manager.cc index 597827f3c..a9e348ceb 100644 --- a/mediapipe/util/tracking/tracked_detection_manager.cc +++ b/mediapipe/util/tracking/tracked_detection_manager.cc @@ -21,6 +21,7 @@ namespace { +using ::mediapipe::NormalizedRect; using mediapipe::TrackedDetection; // Checks if a point is out of view. diff --git a/mediapipe/util/tracking/tracked_detection_test.cc b/mediapipe/util/tracking/tracked_detection_test.cc index 60b9df1b1..13efaab92 100644 --- a/mediapipe/util/tracking/tracked_detection_test.cc +++ b/mediapipe/util/tracking/tracked_detection_test.cc @@ -18,6 +18,8 @@ namespace mediapipe { +using ::mediapipe::NormalizedRect; + const float kErrorMargin = 1e-4f; TEST(TrackedDetectionTest, ConstructorWithoutBox) { From 78597c5b37a2ef8f3f005ed55f0a01676a08fb0b Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 13 Dec 2022 09:05:19 -0800 Subject: [PATCH 60/64] Internal changes. 
PiperOrigin-RevId: 495038477 --- mediapipe/framework/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 3cc72b4f1..265ae9c6f 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -21,6 +21,7 @@ licenses(["notice"]) package(default_visibility = ["//visibility:private"]) +# The MediaPipe internal package group. No mediapipe users should be added to this group. package_group( name = "mediapipe_internal", packages = [ From db404b1a8593a8b316cc4930dc1bcc845fc3df62 Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Tue, 13 Dec 2022 10:21:07 -0800 Subject: [PATCH 61/64] Internal change PiperOrigin-RevId: 495058817 --- mediapipe/framework/tool/proto_util_lite.h | 7 +- mediapipe/framework/tool/template_parser.cc | 181 +++++++++++++++----- 2 files changed, 144 insertions(+), 44 deletions(-) diff --git a/mediapipe/framework/tool/proto_util_lite.h b/mediapipe/framework/tool/proto_util_lite.h index d71ceac83..15e321eeb 100644 --- a/mediapipe/framework/tool/proto_util_lite.h +++ b/mediapipe/framework/tool/proto_util_lite.h @@ -48,11 +48,16 @@ class ProtoUtilLite { key_id(key_id), key_type(key_type), key_value(std::move(key_value)) {} + bool operator==(const ProtoPathEntry& o) const { + return field_id == o.field_id && index == o.index && map_id == o.map_id && + key_id == o.key_id && key_type == o.key_type && + key_value == o.key_value; + } int field_id = -1; int index = -1; int map_id = -1; int key_id = -1; - FieldType key_type; + FieldType key_type = FieldType::MAX_FIELD_TYPE; FieldValue key_value; }; diff --git a/mediapipe/framework/tool/template_parser.cc b/mediapipe/framework/tool/template_parser.cc index 5a0ceccd3..cf23f3443 100644 --- a/mediapipe/framework/tool/template_parser.cc +++ b/mediapipe/framework/tool/template_parser.cc @@ -1367,26 +1367,132 @@ absl::Status ProtoPathSplit(const std::string& path, return status; } +// Returns a message serialized deterministically. 
+bool DeterministicallySerialize(const Message& proto, std::string* result) { + proto_ns::io::StringOutputStream stream(result); + proto_ns::io::CodedOutputStream output(&stream); + output.SetSerializationDeterministic(true); + return proto.SerializeToCodedStream(&output); +} + // Serialize one field of a message. void SerializeField(const Message* message, const FieldDescriptor* field, std::vector* result) { ProtoUtilLite::FieldValue message_bytes; - CHECK(message->SerializePartialToString(&message_bytes)); + CHECK(DeterministicallySerialize(*message, &message_bytes)); ProtoUtilLite::FieldAccess access( field->number(), static_cast(field->type())); MEDIAPIPE_CHECK_OK(access.SetMessage(message_bytes)); *result = *access.mutable_field_values(); } +// Serialize a ProtoPath as a readable string. +// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]", +// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]". +std::string ProtoPathJoin(ProtoPath path) { + std::string result; + for (ProtoUtilLite::ProtoPathEntry& e : path) { + if (e.field_id >= 0) { + absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]"); + } else if (e.map_id >= 0) { + absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value, + "]"); + } + } + return result; +} + +// Returns the message value from a field at an index. +const Message* GetFieldMessage(const Message& message, + const FieldDescriptor* field, int index) { + if (field->type() != FieldDescriptor::TYPE_MESSAGE) { + return nullptr; + } + if (!field->is_repeated()) { + return &message.GetReflection()->GetMessage(message, field); + } + if (index < message.GetReflection()->FieldSize(message, field)) { + return &message.GetReflection()->GetRepeatedMessage(message, field, index); + } + return nullptr; +} + +// Returns all FieldDescriptors including extensions. 
+std::vector GetFields(const Message* src) { + std::vector result; + src->GetDescriptor()->file()->pool()->FindAllExtensions(src->GetDescriptor(), + &result); + for (int i = 0; i < src->GetDescriptor()->field_count(); ++i) { + result.push_back(src->GetDescriptor()->field(i)); + } + return result; +} + +// Orders map entries in dst to match src. +void OrderMapEntries(const Message* src, Message* dst, + std::set* seen = nullptr) { + std::unique_ptr> seen_owner; + if (!seen) { + seen_owner = std::make_unique>(); + seen = seen_owner.get(); + } + if (seen->count(src) > 0) { + return; + } else { + seen->insert(src); + } + for (auto field : GetFields(src)) { + if (field->is_map()) { + dst->GetReflection()->ClearField(dst, field); + for (int j = 0; j < src->GetReflection()->FieldSize(*src, field); ++j) { + const Message& entry = + src->GetReflection()->GetRepeatedMessage(*src, field, j); + dst->GetReflection()->AddMessage(dst, field)->CopyFrom(entry); + } + } + if (field->type() == FieldDescriptor::TYPE_MESSAGE) { + if (field->is_repeated()) { + for (int j = 0; j < src->GetReflection()->FieldSize(*src, field); ++j) { + OrderMapEntries( + &src->GetReflection()->GetRepeatedMessage(*src, field, j), + dst->GetReflection()->MutableRepeatedMessage(dst, field, j), + seen); + } + } else { + OrderMapEntries(&src->GetReflection()->GetMessage(*src, field), + dst->GetReflection()->MutableMessage(dst, field), seen); + } + } + } +} + +// Copies a Message, keeping map entries in order. +std::unique_ptr CloneMessage(const Message* message) { + std::unique_ptr result(message->New()); + result->CopyFrom(*message); + OrderMapEntries(message, result.get()); + return result; +} + +using MessageMap = std::map>; + // For a non-repeated field, move the most recently parsed field value // into the most recently parsed template expression. 
-void StowFieldValue(Message* message, TemplateExpression* expression) { +void StowFieldValue(Message* message, TemplateExpression* expression, + MessageMap* stowed_messages) { const Reflection* reflection = message->GetReflection(); const Descriptor* descriptor = message->GetDescriptor(); ProtoUtilLite::ProtoPath path; MEDIAPIPE_CHECK_OK(ProtoPathSplit(expression->path(), &path)); int field_number = path[path.size() - 1].field_id; const FieldDescriptor* field = descriptor->FindFieldByNumber(field_number); + + // Save each stowed message unserialized preserving map entry order. + if (!field->is_repeated() && field->type() == FieldDescriptor::TYPE_MESSAGE) { + (*stowed_messages)[ProtoPathJoin(path)] = + CloneMessage(GetFieldMessage(*message, field, 0)); + } + if (!field->is_repeated()) { std::vector field_values; SerializeField(message, field, &field_values); @@ -1417,37 +1523,6 @@ const FieldDescriptor* FindFieldByNumber(const Message* message, return result; } -// Returns the message value from a field at an index. -const Message* GetFieldMessage(const Message& message, - const FieldDescriptor* field, int index) { - if (field->type() != FieldDescriptor::TYPE_MESSAGE) { - return nullptr; - } - if (!field->is_repeated()) { - return &message.GetReflection()->GetMessage(message, field); - } - if (index < message.GetReflection()->FieldSize(message, field)) { - return &message.GetReflection()->GetRepeatedMessage(message, field, index); - } - return nullptr; -} - -// Serialize a ProtoPath as a readable string. -// For example, {{1, 1}, {2, 1}, {3, 1}} returns "/1[1]/2[1]/3[1]", -// and {{1, 1}, {2, 1, "INPUT_FRAMES"}} returns "/1[1]/2[@1=INPUT_FRAMES]". 
-std::string ProtoPathJoin(ProtoPath path) { - std::string result; - for (ProtoUtilLite::ProtoPathEntry& e : path) { - if (e.field_id >= 0) { - absl::StrAppend(&result, "/", e.field_id, "[", e.index, "]"); - } else if (e.map_id >= 0) { - absl::StrAppend(&result, "/", e.map_id, "[@", e.key_id, "=", e.key_value, - "]"); - } - } - return result; -} - // Returns the protobuf map key types from a ProtoPath. std::vector ProtoPathKeyTypes(ProtoPath path) { std::vector result; @@ -1473,9 +1548,29 @@ std::string GetMapKey(const Message& map_entry) { return ""; } +// Returns a Message store in CalculatorGraphTemplate::field_value. +Message* FindStowedMessage(MessageMap* stowed_messages, ProtoPath proto_path) { + auto it = stowed_messages->find(ProtoPathJoin(proto_path)); + return (it != stowed_messages->end()) ? it->second.get() : nullptr; +} + +const Message* GetNestedMessage(const Message& message, + const FieldDescriptor* field, + ProtoPath proto_path, + MessageMap* stowed_messages) { + if (field->type() != FieldDescriptor::TYPE_MESSAGE) { + return nullptr; + } + const Message* result = FindStowedMessage(stowed_messages, proto_path); + if (!result) { + result = GetFieldMessage(message, field, proto_path.back().index); + } + return result; +} + // Adjusts map-entries from indexes to keys. // Protobuf map-entry order is intentionally not preserved. -mediapipe::Status KeyProtoMapEntries(Message* source) { +absl::Status KeyProtoMapEntries(Message* source, MessageMap* stowed_messages) { // Copy the rules from the source CalculatorGraphTemplate. 
mediapipe::CalculatorGraphTemplate rules; rules.ParsePartialFromString(source->SerializePartialAsString()); @@ -1489,11 +1584,14 @@ mediapipe::Status KeyProtoMapEntries(Message* source) { MP_RETURN_IF_ERROR(ProtoPathSplit(rule->path(), &path)); for (int j = 0; j < path.size(); ++j) { int field_id = path[j].field_id; - int field_index = path[j].index; const FieldDescriptor* field = FindFieldByNumber(message, field_id); + ProtoPath prefix = {path.begin(), path.begin() + j + 1}; + message = GetNestedMessage(*message, field, prefix, stowed_messages); + if (!message) { + break; + } if (field->is_map()) { - const Message* map_entry = - GetFieldMessage(*message, field, path[j].index); + const Message* map_entry = message; int key_id = map_entry->GetDescriptor()->FindFieldByName("key")->number(); FieldType key_type = static_cast( @@ -1501,10 +1599,6 @@ mediapipe::Status KeyProtoMapEntries(Message* source) { std::string key_value = GetMapKey(*map_entry); path[j] = {field_id, key_id, key_type, key_value}; } - message = GetFieldMessage(*message, field, field_index); - if (!message) { - break; - } } if (!rule->path().empty()) { *rule->mutable_path() = ProtoPathJoin(path); @@ -1539,7 +1633,7 @@ class TemplateParser::Parser::MediaPipeParserImpl // Copy the template rules into the output template "rule" field. success &= MergeFields(template_rules_, output).ok(); // Replace map-entry indexes with map keys. 
- success &= KeyProtoMapEntries(output).ok(); + success &= KeyProtoMapEntries(output, &stowed_messages_).ok(); return success; } @@ -1565,7 +1659,7 @@ class TemplateParser::Parser::MediaPipeParserImpl DO(ConsumeFieldTemplate(message)); } else { DO(ConsumeField(message)); - StowFieldValue(message, expression); + StowFieldValue(message, expression, &stowed_messages_); } DO(ConsumeEndTemplate()); return true; @@ -1776,6 +1870,7 @@ class TemplateParser::Parser::MediaPipeParserImpl } mediapipe::CalculatorGraphTemplate template_rules_; + std::map> stowed_messages_; }; #undef DO From d5ff060bfa6930b9b6b1826b43ca0434b69050a9 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 13 Dec 2022 16:01:06 -0800 Subject: [PATCH 62/64] Internal change PiperOrigin-RevId: 495149484 --- mediapipe/graphs/object_detection_3d/calculators/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/graphs/object_detection_3d/calculators/BUILD b/mediapipe/graphs/object_detection_3d/calculators/BUILD index 783fff187..d4c5c496b 100644 --- a/mediapipe/graphs/object_detection_3d/calculators/BUILD +++ b/mediapipe/graphs/object_detection_3d/calculators/BUILD @@ -22,6 +22,7 @@ package(default_visibility = ["//visibility:public"]) mediapipe_proto_library( name = "gl_animation_overlay_calculator_proto", srcs = ["gl_animation_overlay_calculator.proto"], + def_options_lib = False, visibility = ["//visibility:public"], exports = [ "//mediapipe/gpu:gl_animation_overlay_calculator_proto", From 904a537b027a98c69c744cd4944f06e63c3e882d Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 13 Dec 2022 16:08:54 -0800 Subject: [PATCH 63/64] Internal change PiperOrigin-RevId: 495151410 --- mediapipe/framework/BUILD | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 265ae9c6f..872944acd 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -57,12 +57,12 @@ 
mediapipe_proto_library( srcs = ["calculator.proto"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:mediapipe_options_proto", - "//mediapipe/framework:packet_factory_proto", - "//mediapipe/framework:packet_generator_proto", - "//mediapipe/framework:status_handler_proto", - "//mediapipe/framework:stream_handler_proto", + ":calculator_options_proto", + ":mediapipe_options_proto", + ":packet_factory_proto", + ":packet_generator_proto", + ":status_handler_proto", + ":stream_handler_proto", "@com_google_protobuf//:any_proto", ], ) @@ -79,8 +79,8 @@ mediapipe_proto_library( srcs = ["calculator_contract_test.proto"], visibility = ["//mediapipe/framework:__subpackages__"], deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", + ":calculator_options_proto", + ":calculator_proto", ], ) @@ -89,8 +89,8 @@ mediapipe_proto_library( srcs = ["calculator_profile.proto"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", + ":calculator_options_proto", + ":calculator_proto", ], ) @@ -126,14 +126,14 @@ mediapipe_proto_library( name = "status_handler_proto", srcs = ["status_handler.proto"], visibility = [":mediapipe_internal"], - deps = ["//mediapipe/framework:mediapipe_options_proto"], + deps = [":mediapipe_options_proto"], ) mediapipe_proto_library( name = "stream_handler_proto", srcs = ["stream_handler.proto"], visibility = [":mediapipe_internal"], - deps = ["//mediapipe/framework:mediapipe_options_proto"], + deps = [":mediapipe_options_proto"], ) mediapipe_proto_library( @@ -142,8 +142,8 @@ mediapipe_proto_library( srcs = ["test_calculators.proto"], visibility = [":mediapipe_internal"], deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", + ":calculator_options_proto", + ":calculator_proto", ], ) @@ -151,7 +151,7 
@@ mediapipe_proto_library( name = "thread_pool_executor_proto", srcs = ["thread_pool_executor.proto"], visibility = [":mediapipe_internal"], - deps = ["//mediapipe/framework:mediapipe_options_proto"], + deps = [":mediapipe_options_proto"], ) # It is for pure-native Android builds where the library can't have any dependency on libandroid.so From b9d020cb7d32e936943b963c401cc3aeb9f88407 Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Tue, 13 Dec 2022 16:58:12 -0800 Subject: [PATCH 64/64] Internal change PiperOrigin-RevId: 495163109 --- mediapipe/framework/scheduler.cc | 11 ++++++++--- mediapipe/framework/scheduler.h | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/mediapipe/framework/scheduler.cc b/mediapipe/framework/scheduler.cc index afef4f383..854c10fd5 100644 --- a/mediapipe/framework/scheduler.cc +++ b/mediapipe/framework/scheduler.cc @@ -117,7 +117,7 @@ void Scheduler::SubmitWaitingTasksOnQueues() { // Note: state_mutex_ is held when this function is entered or // exited. void Scheduler::HandleIdle() { - if (handling_idle_) { + if (++handling_idle_ > 1) { // Someone is already inside this method. // Note: This can happen in the sections below where we unlock the mutex // and make more nodes runnable: the nodes can run and become idle again @@ -127,7 +127,6 @@ void Scheduler::HandleIdle() { VLOG(2) << "HandleIdle: already in progress"; return; } - handling_idle_ = true; while (IsIdle() && (state_ == STATE_RUNNING || state_ == STATE_CANCELLING)) { // Remove active sources that are closed. @@ -165,11 +164,17 @@ void Scheduler::HandleIdle() { } } + // If HandleIdle has been called again, then continue scheduling. + if (handling_idle_ > 1) { + handling_idle_ = 1; + continue; + } + // Nothing left to do. break; } - handling_idle_ = false; + handling_idle_ = 0; } // Note: state_mutex_ is held when this function is entered or exited. 
diff --git a/mediapipe/framework/scheduler.h b/mediapipe/framework/scheduler.h index dd1572d99..b59467b9f 100644 --- a/mediapipe/framework/scheduler.h +++ b/mediapipe/framework/scheduler.h @@ -302,7 +302,7 @@ class Scheduler { // - We need it to be reentrant, which Mutex does not support. // - We want simultaneous calls to return immediately instead of waiting, // and Mutex's TryLock is not guaranteed to work. - bool handling_idle_ ABSL_GUARDED_BY(state_mutex_) = false; + int handling_idle_ ABSL_GUARDED_BY(state_mutex_) = 0; // Mutex for the scheduler state and related things. // Note: state_ is declared as atomic so that its getter methods don't need