From 7ce21038bb124a8bab79b47c716458089292b405 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 4 Jan 2023 13:40:17 +0530 Subject: [PATCH] Merge branch 'master' into ios-task --- .github/bot_config.yml | 3 +- ...low_session_from_saved_model_calculator.cc | 7 +- ..._session_from_saved_model_calculator.proto | 4 +- ...flow_session_from_saved_model_generator.cc | 7 +- ...w_session_from_saved_model_generator.proto | 4 +- mediapipe/examples/desktop/autoflip/BUILD | 4 + mediapipe/framework/api2/builder.h | 2 +- mediapipe/framework/api2/packet.h | 2 +- mediapipe/framework/formats/BUILD | 2 +- mediapipe/framework/formats/tensor.h | 7 +- mediapipe/framework/formats/tensor_ahwb.cc | 10 +- .../framework/formats/tensor_ahwb_gpu_test.cc | 16 +- .../framework/formats/tensor_ahwb_test.cc | 94 ++++++++--- .../framework/profiler/graph_profiler.cc | 1 + mediapipe/framework/profiler/graph_profiler.h | 9 ++ .../framework/profiler/graph_profiler_test.cc | 26 +++ .../gpu/gpu_buffer_storage_cv_pixel_buffer.cc | 75 +++++---- .../text_classifier/text_classifier_test.py | 7 +- .../image_classifier/image_classifier_test.py | 5 +- mediapipe/objc/MPPGraph.mm | 9 +- .../gesture_recognizer/gesture_recognizer.cc | 4 +- mediapipe/tasks/ios/common/utils/BUILD | 1 - .../ios/common/utils/sources/MPPCommonUtils.h | 2 + .../common/utils/sources/MPPCommonUtils.mm | 5 +- .../common/utils/sources/NSString+Helpers.h | 3 +- .../tasks/ios/components/processors/BUILD | 1 - .../processors/sources/MPPClassifierOptions.h | 21 ++- .../processors/sources/MPPClassifierOptions.m | 4 +- .../ios/components/processors/utils/BUILD | 9 +- .../sources/MPPClassifierOptions+Helpers.h | 1 + .../sources/MPPClassifierOptions+Helpers.mm | 3 +- mediapipe/tasks/ios/core/BUILD | 18 ++- .../tasks/ios/core/sources/MPPTaskInfo.h | 4 +- .../tasks/ios/core/sources/MPPTaskInfo.mm | 1 - .../ios/core/sources/MPPTaskOptionsProtocol.h | 3 +- .../tasks/ios/core/sources/MPPTaskRunner.h | 8 +- .../tasks/ios/core/sources/MPPTaskRunner.mm | 5 +- .../audio_classifier/audio_classifier.ts | 7 +- .../audio_classifier/audio_classifier_test.ts | 3 +- .../audio/audio_embedder/audio_embedder.ts | 7 +- .../audio_embedder/audio_embedder_test.ts | 3 +- .../tasks/web/components/processors/BUILD | 26 --- .../processors/base_options.test.ts | 127 --------------- .../web/components/processors/base_options.ts | 80 ---------- mediapipe/tasks/web/core/BUILD | 5 +- mediapipe/tasks/web/core/task_runner.ts | 75 ++++++++- mediapipe/tasks/web/core/task_runner_test.ts | 148 +++++++++++++++++- .../tasks/web/core/task_runner_test_utils.ts | 4 +- .../text/text_classifier/text_classifier.ts | 7 +- .../text_classifier/text_classifier_test.ts | 3 +- .../web/text/text_embedder/text_embedder.ts | 7 +- .../text/text_embedder/text_embedder_test.ts | 3 +- mediapipe/tasks/web/vision/core/BUILD | 1 + .../vision/core/vision_task_runner.test.ts | 32 ++-- .../web/vision/core/vision_task_runner.ts | 4 +- .../gesture_recognizer/gesture_recognizer.ts | 43 +++-- .../gesture_recognizer_result.d.ts | 8 +- .../gesture_recognizer_test.ts | 26 ++- .../vision/hand_landmarker/hand_landmarker.ts | 8 +- .../hand_landmarker_result.d.ts | 2 + .../hand_landmarker/hand_landmarker_test.ts | 3 +- .../image_classifier/image_classifier.ts | 7 +- .../image_classifier/image_classifier_test.ts | 3 +- .../vision/image_embedder/image_embedder.ts | 7 +- .../image_embedder/image_embedder_test.ts | 3 +- .../vision/object_detector/object_detector.ts | 8 +- .../object_detector/object_detector_test.ts | 3 +- 67 files changed, 593 
insertions(+), 457 deletions(-) delete mode 100644 mediapipe/tasks/web/components/processors/base_options.test.ts delete mode 100644 mediapipe/tasks/web/components/processors/base_options.ts diff --git a/.github/bot_config.yml b/.github/bot_config.yml index 8ad724168..74a60e4b9 100644 --- a/.github/bot_config.yml +++ b/.github/bot_config.yml @@ -15,4 +15,5 @@ # A list of assignees assignees: - - sureshdagooglecom + - kuaashish + - ayushgdev diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc index 922eb9d50..18bddbbe3 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc @@ -55,7 +55,7 @@ absl::Status GetLatestDirectory(std::string* path) { } // If options.convert_signature_to_tags() is set, will convert letters to -// uppercase and replace /'s and -'s with _'s. This enables the standard +// uppercase and replace /'s, -'s, .'s, and :'s with _'s. This enables the standard // SavedModel classification, regression, and prediction signatures to be used // as uppercase INPUTS and OUTPUTS tags for streams and supports other common // patterns. @@ -67,9 +67,8 @@ const std::string MaybeConvertSignatureToTag( output.resize(name.length()); std::transform(name.begin(), name.end(), output.begin(), [](unsigned char c) { return std::toupper(c); }); - output = absl::StrReplaceAll(output, {{"/", "_"}}); - output = absl::StrReplaceAll(output, {{"-", "_"}}); - output = absl::StrReplaceAll(output, {{".", "_"}}); + output = absl::StrReplaceAll( + output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}}); LOG(INFO) << "Renamed TAG from: " << name << " to " << output; return output; } else { diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto index 927d3b51f..515b46fa9 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto @@ -33,8 +33,8 @@ message TensorFlowSessionFromSavedModelCalculatorOptions { // The name of the generic signature to load into the mapping from tags to // tensor names. optional string signature_name = 2 [default = "serving_default"]; - // Whether to convert the signature keys to uppercase as well as switch /'s - // and -'s to _'s, which enables common signatures to be used as Tags. + // Whether to convert the signature keys to uppercase, as well as switch + // /'s, -'s, .'s, and :'s to _'s, which enables common signatures to be used as Tags.
optional bool convert_signature_to_tags = 3 [default = true]; // If true, saved_model_path can have multiple exported models in // subdirectories saved_model_path/%08d and the alphabetically last (i.e., diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc index d5236f1cc..ee69ec56a 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc @@ -61,7 +61,7 @@ absl::Status GetLatestDirectory(std::string* path) { } // If options.convert_signature_to_tags() is set, will convert letters to -// uppercase and replace /'s and -'s with _'s. This enables the standard +// uppercase and replace /'s, -'s, .'s, and :'s with _'s. This enables the standard // SavedModel classification, regression, and prediction signatures to be used // as uppercase INPUTS and OUTPUTS tags for streams and supports other common // patterns. @@ -73,9 +73,8 @@ const std::string MaybeConvertSignatureToTag( output.resize(name.length()); std::transform(name.begin(), name.end(), output.begin(), [](unsigned char c) { return std::toupper(c); }); - output = absl::StrReplaceAll(output, {{"/", "_"}}); - output = absl::StrReplaceAll(output, {{"-", "_"}}); - output = absl::StrReplaceAll(output, {{".", "_"}}); + output = absl::StrReplaceAll( + output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}}); LOG(INFO) << "Renamed TAG from: " << name << " to " << output; return output; } else { diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto index d24a1cd73..d45fcb662 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto @@ -33,8 +33,8 @@ message TensorFlowSessionFromSavedModelGeneratorOptions { // The name of the generic signature to load into the mapping from tags to // tensor names. optional string signature_name = 2 [default = "serving_default"]; - // Whether to convert the signature keys to uppercase as well as switch /'s - // and -'s to _'s, which enables common signatures to be used as Tags. + // Whether to convert the signature keys to uppercase, as well as switch /'s, + // -'s, .'s, and :'s to _'s, enabling common signatures to be used as Tags.
optional bool convert_signature_to_tags = 3 [default = true]; // If true, saved_model_path can have multiple exported models in // subdirectories saved_model_path/%08d and the alphabetically last (i.e., diff --git a/mediapipe/examples/desktop/autoflip/BUILD b/mediapipe/examples/desktop/autoflip/BUILD index 562f11c49..0e28746dc 100644 --- a/mediapipe/examples/desktop/autoflip/BUILD +++ b/mediapipe/examples/desktop/autoflip/BUILD @@ -30,6 +30,10 @@ proto_library( java_lite_proto_library( name = "autoflip_messages_java_proto_lite", + visibility = [ + "//java/com/google/android/apps/photos:__subpackages__", + "//javatests/com/google/android/apps/photos:__subpackages__", + ], deps = [ ":autoflip_messages_proto", ], diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index 19273bf44..2a98c4166 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -398,7 +398,7 @@ template class Node; #if __cplusplus >= 201703L // Deduction guide to silence -Wctad-maybe-unsupported. -explicit Node()->Node; +explicit Node() -> Node; #endif // C++17 template <> diff --git a/mediapipe/framework/api2/packet.h b/mediapipe/framework/api2/packet.h index 7933575d3..b1ebb0410 100644 --- a/mediapipe/framework/api2/packet.h +++ b/mediapipe/framework/api2/packet.h @@ -181,7 +181,7 @@ template class Packet; #if __cplusplus >= 201703L // Deduction guide to silence -Wctad-maybe-unsupported. -explicit Packet()->Packet; +explicit Packet() -> Packet; #endif // C++17 template <> diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index f5a043f10..cce7e5bd0 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -455,7 +455,7 @@ cc_library( ], }), deps = [ - "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", "//mediapipe/framework:port", diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index 8a6f02e9d..0f19bb5ee 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -24,7 +24,7 @@ #include #include -#include "absl/container/flat_hash_set.h" +#include "absl/container/flat_hash_map.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/formats/tensor_internal.h" #include "mediapipe/framework/port.h" @@ -434,8 +434,9 @@ class Tensor { mutable bool use_ahwb_ = false; mutable uint64_t ahwb_tracking_key_ = 0; // TODO: Tracks all unique tensors. Can grow to a large number. LRU - // can be more predicted. - static inline absl::flat_hash_set<uint64_t> ahwb_usage_track_; + // (Least Recently Used) would be more predictable. + // The value contains the size alignment parameter. + static inline absl::flat_hash_map<uint64_t, int> ahwb_usage_track_; // Expects the target SSBO to be already bound. bool AllocateAhwbMapToSsbo() const; bool InsertAhwbToSsboFence() const; diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index 466811be7..525f05f31 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -266,7 +266,12 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView( bool Tensor::AllocateAHardwareBuffer(int size_alignment) const { // Mark current tracking key as Ahwb-use.
- ahwb_usage_track_.insert(ahwb_tracking_key_); + if (auto it = ahwb_usage_track_.find(ahwb_tracking_key_); + it != ahwb_usage_track_.end()) { + size_alignment = it->second; + } else if (ahwb_tracking_key_ != 0) { + ahwb_usage_track_.insert({ahwb_tracking_key_, size_alignment}); + } use_ahwb_ = true; if (__builtin_available(android 26, *)) { @@ -458,7 +463,8 @@ void Tensor::TrackAhwbUsage(uint64_t source_location_hash) const { ahwb_tracking_key_ = tensor_internal::FnvHash64(ahwb_tracking_key_, dim); } } - use_ahwb_ = ahwb_usage_track_.contains(ahwb_tracking_key_); + // Keep flag value if it was set previously. + use_ahwb_ = use_ahwb_ || ahwb_usage_track_.contains(ahwb_tracking_key_); } #else // MEDIAPIPE_TENSOR_USE_AHWB diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc index a6ca00949..e2ad869f9 100644 --- a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc +++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc @@ -92,9 +92,14 @@ class TensorAhwbGpuTest : public mediapipe::GpuTestBase { }; TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) { - Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb); constexpr size_t num_elements = 20; Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; + { + // Request Ahwb first to get Ahwb storage allocated internally. + auto view = tensor.GetAHardwareBufferWriteView(); + EXPECT_NE(view.handle(), nullptr); + view.SetWritingFinishedFD(-1, [](bool) { return true; }); + } RunInGlContext([&tensor] { auto ssbo_view = tensor.GetOpenGlBufferWriteView(); auto ssbo_name = ssbo_view.name(); @@ -114,9 +119,14 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) { } TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) { - Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb); constexpr size_t num_elements = 20; Tensor tensor{Tensor::ElementType::kFloat16, Tensor::Shape({num_elements})}; + { + // Request Ahwb first to get Ahwb storage allocated internally. + auto view = tensor.GetAHardwareBufferWriteView(); + EXPECT_NE(view.handle(), nullptr); + view.SetReadingFinishedFunc([](bool) { return true; }); + } RunInGlContext([&tensor] { auto ssbo_view = tensor.GetOpenGlBufferWriteView(); auto ssbo_name = ssbo_view.name(); @@ -139,7 +149,6 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) { TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) { // Request the CPU view to get the memory to be allocated. // Request Ahwb view then to transform the storage into Ahwb. - Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault); constexpr size_t num_elements = 20; Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; { @@ -168,7 +177,6 @@ TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) { TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) { // Request the GPU view to get the ssbo allocated internally. // Request Ahwb view then to transform the storage into Ahwb. 
- Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault); constexpr size_t num_elements = 20; Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; RunInGlContext([&tensor] { diff --git a/mediapipe/framework/formats/tensor_ahwb_test.cc b/mediapipe/framework/formats/tensor_ahwb_test.cc index 7ab5a4925..3da6ca8d3 100644 --- a/mediapipe/framework/formats/tensor_ahwb_test.cc +++ b/mediapipe/framework/formats/tensor_ahwb_test.cc @@ -1,34 +1,28 @@ #include "mediapipe/framework/formats/tensor.h" -#include "mediapipe/gpu/gpu_test_base.h" #include "testing/base/public/gmock.h" #include "testing/base/public/gunit.h" -#ifdef MEDIAPIPE_TENSOR_USE_AHWB -#if !MEDIAPIPE_DISABLE_GPU - namespace mediapipe { -class TensorAhwbTest : public mediapipe::GpuTestBase { - public: -}; - -TEST_F(TensorAhwbTest, TestCpuThenAHWB) { +TEST(TensorAhwbTest, TestCpuThenAHWB) { Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1}); { auto ptr = tensor.GetCpuWriteView().buffer(); EXPECT_NE(ptr, nullptr); } { - auto ahwb = tensor.GetAHardwareBufferReadView().handle(); - EXPECT_NE(ahwb, nullptr); + auto view = tensor.GetAHardwareBufferReadView(); + EXPECT_NE(view.handle(), nullptr); + view.SetReadingFinishedFunc([](bool) { return true; }); } } -TEST_F(TensorAhwbTest, TestAHWBThenCpu) { +TEST(TensorAhwbTest, TestAHWBThenCpu) { Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1}); { - auto ahwb = tensor.GetAHardwareBufferWriteView().handle(); - EXPECT_NE(ahwb, nullptr); + auto view = tensor.GetAHardwareBufferWriteView(); + EXPECT_NE(view.handle(), nullptr); + view.SetWritingFinishedFD(-1, [](bool) { return true; }); } { auto ptr = tensor.GetCpuReadView().buffer(); @@ -36,21 +30,71 @@ TEST_F(TensorAhwbTest, TestAHWBThenCpu) { } } -TEST_F(TensorAhwbTest, TestCpuThenGl) { - RunInGlContext([] { - Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1}); +TEST(TensorAhwbTest, TestAhwbAlignment) { + Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{5}); + { + auto view = tensor.GetAHardwareBufferWriteView(16); + EXPECT_NE(view.handle(), nullptr); + if (__builtin_available(android 26, *)) { + AHardwareBuffer_Desc desc; + AHardwareBuffer_describe(view.handle(), &desc); + // sizeof(float) * 5 = 20; the closest size aligned to 16 is 32. + EXPECT_EQ(desc.width, 32); + } + view.SetWritingFinishedFD(-1, [](bool) { return true; }); + } +} + +// Tensor::GetCpuView uses a source location mechanism that gives the source +// file name and line from which the method is called. This function exists +// just so that two calls provide the same source file name and line. +auto GetCpuView(const Tensor &tensor) { return tensor.GetCpuWriteView(); } + +// The test checks the tracking mechanism: when a tensor's Cpu view is retrieved +// for the first time, the source location is attached to the tensor. If the +// Ahwb view is then requested from the tensor, the previously recorded Cpu +// view request source location is marked as using Ahwb storage. +// When a Cpu view with the same source location (but for a newly allocated +// tensor) is requested and the location is marked to use Ahwb storage, then +// Ahwb storage is allocated for the Cpu view. +TEST(TensorAhwbTest, TestTrackingAhwb) { + // Create the first tensor and request the Cpu and then the Ahwb view to mark + // the source location for Ahwb storage.
+ { Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{9}); { - auto ptr = tensor.GetCpuWriteView().buffer(); - EXPECT_NE(ptr, nullptr); + auto view = GetCpuView(tensor); + EXPECT_NE(view.buffer(), nullptr); } { - auto ssbo = tensor.GetOpenGlBufferReadView().name(); - EXPECT_GT(ssbo, 0); + // Align the size of the Ahwb to a multiple of 16. + auto view = tensor.GetAHardwareBufferWriteView(16); + EXPECT_NE(view.handle(), nullptr); + view.SetReadingFinishedFunc([](bool) { return true; }); } - }); + } + { + Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{9}); + { + // The second tensor uses the same Cpu view source location so Ahwb + // storage is allocated internally. + auto view = GetCpuView(tensor); + EXPECT_NE(view.buffer(), nullptr); + } + { + // Check that the Ahwb size is aligned to a multiple of 16. The alignment + // was stored by the previous request of the Ahwb view. + auto view = tensor.GetAHardwareBufferReadView(); + EXPECT_NE(view.handle(), nullptr); + if (__builtin_available(android 26, *)) { + AHardwareBuffer_Desc desc; + AHardwareBuffer_describe(view.handle(), &desc); + // sizeof(float) * 9 = 36. The closest aligned size is 48. + EXPECT_EQ(desc.width, 48); + } + view.SetReadingFinishedFunc([](bool) { return true; }); + } + } } } // namespace mediapipe - -#endif // !MEDIAPIPE_DISABLE_GPU -#endif // MEDIAPIPE_TENSOR_USE_AHWB diff --git a/mediapipe/framework/profiler/graph_profiler.cc b/mediapipe/framework/profiler/graph_profiler.cc index f14acfc78..6aead5250 100644 --- a/mediapipe/framework/profiler/graph_profiler.cc +++ b/mediapipe/framework/profiler/graph_profiler.cc @@ -194,6 +194,7 @@ void GraphProfiler::Initialize( "Calculator \"$0\" has already been added.", node_name); } profile_builder_ = std::make_unique<GraphProfileBuilder>(this); + graph_id_ = ++next_instance_id_; is_initialized_ = true; } diff --git a/mediapipe/framework/profiler/graph_profiler.h b/mediapipe/framework/profiler/graph_profiler.h index 23caed4ec..6358cb057 100644 --- a/mediapipe/framework/profiler/graph_profiler.h +++ b/mediapipe/framework/profiler/graph_profiler.h @@ -237,6 +237,9 @@ class GraphProfiler : public std::enable_shared_from_this { return validated_graph_; } + // Gets a numerical identifier for this GraphProfiler object. + uint64_t GetGraphId() { return graph_id_; } + private: // This can be used to add packet info for the input streams to the graph. // It treats the stream defined by |stream_name| as a stream produced by a @@ -357,6 +360,12 @@ class GraphProfiler : public std::enable_shared_from_this { class GraphProfileBuilder; std::unique_ptr<GraphProfileBuilder> profile_builder_; + // The globally incrementing identifier for all graphs in a process. + static inline std::atomic_int next_instance_id_ = 0; + + // A unique identifier for this object. Only unique within a process. + uint64_t graph_id_; + // For testing. friend GraphProfilerTestPeer; }; diff --git a/mediapipe/framework/profiler/graph_profiler_test.cc b/mediapipe/framework/profiler/graph_profiler_test.cc index 81ba90cda..75d1c7ebd 100644 --- a/mediapipe/framework/profiler/graph_profiler_test.cc +++ b/mediapipe/framework/profiler/graph_profiler_test.cc @@ -442,6 +442,32 @@ TEST_F(GraphProfilerTestPeer, InitializeMultipleTimes) { "Cannot initialize .* multiple times."); } +// Tests that graph identifiers are not reused, even after destruction.
+TEST_F(GraphProfilerTestPeer, InitializeMultipleProfilers) { + auto raw_graph_config = R"( + profiler_config { + enable_profiler: true + } + input_stream: "input_stream" + node { + calculator: "DummyTestCalculator" + input_stream: "input_stream" + })"; + const int n_iterations = 100; + absl::flat_hash_set<int> seen_ids; + for (int i = 0; i < n_iterations; ++i) { + std::shared_ptr<GraphProfiler> profiler = + std::make_shared<GraphProfiler>(); + auto graph_config = CreateGraphConfig(raw_graph_config); + mediapipe::ValidatedGraphConfig validated_graph; + QCHECK_OK(validated_graph.Initialize(graph_config)); + profiler->Initialize(validated_graph); + + int id = profiler->GetGraphId(); + ASSERT_THAT(seen_ids, testing::Not(testing::Contains(id))); + seen_ids.insert(id); + } +} // Tests that Pause(), Resume(), and Reset() works. TEST_F(GraphProfilerTestPeer, PauseResumeReset) { InitializeProfilerWithGraphConfig(R"( diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc index 014cc1c69..7cac32b7f 100644 --- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc +++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc @@ -74,42 +74,51 @@ GlTextureView GpuBufferStorageCvPixelBuffer::GetReadView( static void ViewDoneWritingSimulatorWorkaround(CVPixelBufferRef pixel_buffer, const GlTextureView& view) { CHECK(pixel_buffer); - CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0); - CHECK(err == kCVReturnSuccess) - << "CVPixelBufferLockBaseAddress failed: " << err; - OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer); - size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer); - uint8_t* pixel_ptr = - static_cast<uint8_t*>(CVPixelBufferGetBaseAddress(pixel_buffer)); - if (pixel_format == kCVPixelFormatType_32BGRA) { - // TODO: restore previous framebuffer? Move this to helper so we - // can use BindFramebuffer?
- glViewport(0, 0, view.width(), view.height()); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), - view.name(), 0); + auto ctx = GlContext::GetCurrent().get(); + if (!ctx) ctx = view.gl_context(); + ctx->Run([pixel_buffer, &view, ctx] { + CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0); + CHECK(err == kCVReturnSuccess) + << "CVPixelBufferLockBaseAddress failed: " << err; + OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer); + size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer); + uint8_t* pixel_ptr = + static_cast<uint8_t*>(CVPixelBufferGetBaseAddress(pixel_buffer)); + if (pixel_format == kCVPixelFormatType_32BGRA) { + glBindFramebuffer(GL_FRAMEBUFFER, kUtilityFramebuffer.Get(*ctx)); + glViewport(0, 0, view.width(), view.height()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, + view.target(), view.name(), 0); - size_t contiguous_bytes_per_row = view.width() * 4; - if (bytes_per_row == contiguous_bytes_per_row) { - glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE, - pixel_ptr); - } else { - std::vector<uint8_t> contiguous_buffer(contiguous_bytes_per_row * - view.height()); - uint8_t* temp_ptr = contiguous_buffer.data(); - glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE, - temp_ptr); - for (int i = 0; i < view.height(); ++i) { - memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row); - temp_ptr += contiguous_bytes_per_row; - pixel_ptr += bytes_per_row; + size_t contiguous_bytes_per_row = view.width() * 4; + if (bytes_per_row == contiguous_bytes_per_row) { + glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, + GL_UNSIGNED_BYTE, pixel_ptr); + } else { + // TODO: use GL_PACK settings for row length. We can expect + // GLES 3.0 on iOS now. + std::vector<uint8_t> contiguous_buffer(contiguous_bytes_per_row * + view.height()); + uint8_t* temp_ptr = contiguous_buffer.data(); + glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, + GL_UNSIGNED_BYTE, temp_ptr); + for (int i = 0; i < view.height(); ++i) { + memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row); + temp_ptr += contiguous_bytes_per_row; + pixel_ptr += bytes_per_row; + } } + // TODO: restore previous framebuffer?
+ glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, + view.target(), 0, 0); + glBindFramebuffer(GL_FRAMEBUFFER, 0); + } else { + LOG(ERROR) << "unsupported pixel format: " << pixel_format; } - } else { - LOG(ERROR) << "unsupported pixel format: " << pixel_format; - } - err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0); - CHECK(err == kCVReturnSuccess) - << "CVPixelBufferUnlockBaseAddress failed: " << err; + err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0); + CHECK(err == kCVReturnSuccess) + << "CVPixelBufferUnlockBaseAddress failed: " << err; + }); } #endif // TARGET_IPHONE_SIMULATOR diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py index 7a30d19fd..eb4443b44 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py @@ -71,9 +71,12 @@ class TextClassifierTest(tf.test.TestCase): self.assertTrue(os.path.exists(output_metadata_file)) self.assertGreater(os.path.getsize(output_metadata_file), 0) + filecmp.clear_cache() self.assertTrue( - filecmp.cmp(output_metadata_file, - self._AVERAGE_WORD_EMBEDDING_JSON_FILE)) + filecmp.cmp( + output_metadata_file, + self._AVERAGE_WORD_EMBEDDING_JSON_FILE, + shallow=False)) def test_create_and_train_bert(self): train_data, validation_data = self._get_data() diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py index 6ca21d334..afda8643b 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py @@ -135,7 +135,10 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): self.assertTrue(os.path.exists(output_metadata_file)) self.assertGreater(os.path.getsize(output_metadata_file), 0) - self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file)) + filecmp.clear_cache() + self.assertTrue( + filecmp.cmp( + output_metadata_file, expected_metadata_file, shallow=False)) def test_continual_training_by_loading_checkpoint(self): mock_stdout = io.StringIO() diff --git a/mediapipe/objc/MPPGraph.mm b/mediapipe/objc/MPPGraph.mm index 1bd177e80..3123eb863 100644 --- a/mediapipe/objc/MPPGraph.mm +++ b/mediapipe/objc/MPPGraph.mm @@ -230,16 +230,17 @@ if ([wrapper.delegate } - (absl::Status)performStart { - absl::Status status = _graph->Initialize(_config); - if (!status.ok()) { - return status; - } + absl::Status status; for (const auto& service_packet : _servicePackets) { status = _graph->SetServicePacket(*service_packet.first, service_packet.second); if (!status.ok()) { return status; } } + status = _graph->Initialize(_config); + if (!status.ok()) { + return status; + } status = _graph->StartRun(_inputSidePackets, _streamHeaders); if (!status.ok()) { return status; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc index 01f444742..91a5ec213 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc @@ -151,11 +151,11 @@ ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) { auto custom_gestures_classifier_options_proto = std::make_unique<components::processors::proto::ClassifierOptions>(
components::processors::ConvertClassifierOptionsToProto( - &(options->canned_gestures_classifier_options))); + &(options->custom_gestures_classifier_options))); hand_gesture_recognizer_graph_options ->mutable_custom_gesture_classifier_graph_options() ->mutable_classifier_options() - ->Swap(canned_gestures_classifier_options_proto.get()); + ->Swap(custom_gestures_classifier_options_proto.get()); return options_proto; } diff --git a/mediapipe/tasks/ios/common/utils/BUILD b/mediapipe/tasks/ios/common/utils/BUILD index f2ffda39e..a29c700da 100644 --- a/mediapipe/tasks/ios/common/utils/BUILD +++ b/mediapipe/tasks/ios/common/utils/BUILD @@ -38,4 +38,3 @@ objc_library( "-std=c++17", ], ) - diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h index 407d87aba..5404a074d 100644 --- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h @@ -13,6 +13,7 @@ // limitations under the License. #import <Foundation/Foundation.h> + #include "mediapipe/tasks/cc/common.h" NS_ASSUME_NONNULL_BEGIN @@ -56,6 +57,7 @@ extern NSString *const MPPTasksErrorDomain; * @param status absl::Status. * @param error Pointer to the memory location where the created error should be saved. If `nil`, * no error will be saved. + * @return YES when there is no error, NO otherwise. */ + (BOOL)checkCppError:(const absl::Status &)status toError:(NSError **)error; diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm index 4d4880a87..8234ac6d3 100644 --- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm @@ -20,7 +20,6 @@ #include "absl/status/status.h" // from @com_google_absl #include "absl/strings/cord.h" // from @com_google_absl - #include "mediapipe/tasks/cc/common.h" /** Error domain of MediaPipe task library errors. */ @@ -96,8 +95,8 @@ NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks"; // appropriate MPPTasksErrorCode in default cases. Note: // The mapping to absl::Status::code() is done to generate a more specific error code than // MPPTasksErrorCodeError in cases when the payload can't be mapped to - // MPPTasksErrorCode. This can happen when absl::Status returned by TFLite library are in turn returned - // without modification by Mediapipe cc library methods. + // MPPTasksErrorCode. This can happen when an absl::Status returned by the TFLite library is in + // turn returned without modification by Mediapipe cc library methods. if (errorCode > MPPTasksErrorCodeLast || errorCode <= MPPTasksErrorCodeFirst) { switch (status.code()) { case absl::StatusCode::kInternal: diff --git a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h index aac7485da..66f9c5ccc 100644 --- a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h +++ b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h @@ -13,13 +13,14 @@ // limitations under the License.
#import <Foundation/Foundation.h> + #include <string> NS_ASSUME_NONNULL_BEGIN @interface NSString (Helpers) -@property(readonly) std::string cppString; +@property(readonly, nonatomic) std::string cppString; + (NSString *)stringWithCppString:(std::string)text; diff --git a/mediapipe/tasks/ios/components/processors/BUILD b/mediapipe/tasks/ios/components/processors/BUILD index 6d1cfdf59..165145076 100644 --- a/mediapipe/tasks/ios/components/processors/BUILD +++ b/mediapipe/tasks/ios/components/processors/BUILD @@ -21,4 +21,3 @@ objc_library( srcs = ["sources/MPPClassifierOptions.m"], hdrs = ["sources/MPPClassifierOptions.h"], ) - diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h index 7bf5744f7..13dca4030 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h @@ -22,29 +22,34 @@ NS_ASSUME_NONNULL_BEGIN NS_SWIFT_NAME(ClassifierOptions) @interface MPPClassifierOptions : NSObject -/** The locale to use for display names specified through the TFLite Model - * Metadata, if any. Defaults to English. +/** + * The locale to use for display names specified through the TFLite Model + * Metadata, if any. Defaults to English. */ @property(nonatomic, copy) NSString *displayNamesLocale; -/** The maximum number of top-scored classification results to return. If < 0, +/** + * The maximum number of top-scored classification results to return. If < 0, * all available results will be returned. If 0, an invalid argument error is - * returned. + * returned. */ @property(nonatomic) NSInteger maxResults; -/** Score threshold to override the one provided in the model metadata (if any). - * Results below this value are rejected. +/** + * Score threshold to override the one provided in the model metadata (if any). + * Results below this value are rejected. */ @property(nonatomic) float scoreThreshold; -/** The allowlist of category names. If non-empty, detection results whose +/** + * The allowlist of category names. If non-empty, detection results whose * category name is not in this set will be filtered out. Duplicate or unknown * category names are ignored. Mutually exclusive with categoryDenylist. */ @property(nonatomic, copy) NSArray<NSString *> *categoryAllowlist; -/** The denylist of category names. If non-empty, detection results whose +/** + * The denylist of category names. If non-empty, detection results whose * category name is in this set will be filtered out. Duplicate or unknown * category names are ignored. Mutually exclusive with categoryAllowlist.
*/ diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m index accb6c7dd..01f498184 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m @@ -19,8 +19,8 @@ - (instancetype)init { self = [super init]; if (self) { - self.maxResults = -1; - self.scoreThreshold = 0; + _maxResults = -1; + _scoreThreshold = 0; } return self; } diff --git a/mediapipe/tasks/ios/components/processors/utils/BUILD b/mediapipe/tasks/ios/components/processors/utils/BUILD index 820c6bb56..5344c5fdf 100644 --- a/mediapipe/tasks/ios/components/processors/utils/BUILD +++ b/mediapipe/tasks/ios/components/processors/utils/BUILD @@ -21,9 +21,8 @@ objc_library( srcs = ["sources/MPPClassifierOptions+Helpers.mm"], hdrs = ["sources/MPPClassifierOptions+Helpers.h"], deps = [ - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/ios/components/processors:MPPClassifierOptions", - "//mediapipe/tasks/ios/common/utils:NSStringHelpers", - ] + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/components/processors:MPPClassifierOptions", + ], ) - diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h index 6644a6255..e156020df 100644 --- a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h +++ b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h @@ -13,6 +13,7 @@ // limitations under the License. 
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" + #import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h" NS_ASSUME_NONNULL_BEGIN diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm index efe9572e1..24b54fd6a 100644 --- a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm +++ b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm @@ -23,13 +23,12 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto - (void)copyToProto:(ClassifierOptionsProto *)classifierOptionsProto { classifierOptionsProto->Clear(); - + if (self.displayNamesLocale) { classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString); } classifierOptionsProto->set_max_results((int)self.maxResults); - classifierOptionsProto->set_score_threshold(self.scoreThreshold); for (NSString *category in self.categoryAllowlist) { diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD index adc37d901..434d20085 100644 --- a/mediapipe/tasks/ios/core/BUILD +++ b/mediapipe/tasks/ios/core/BUILD @@ -54,14 +54,14 @@ objc_library( "-std=c++17", ], deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", - "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", ":MPPTaskOptions", ":MPPTaskOptionsProtocol", + "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/tasks/ios/common:MPPCommon", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", "//mediapipe/tasks/ios/common/utils:NSStringHelpers", - "//mediapipe/tasks/ios/common:MPPCommon", ], ) @@ -83,9 +83,13 @@ objc_library( name = "MPPTaskRunner", srcs = ["sources/MPPTaskRunner.mm"], hdrs = ["sources/MPPTaskRunner.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], deps = [ - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", ], ) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h index ae4c9eba1..b94e704d1 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h @@ -13,7 +13,9 @@ // limitations under the License. #import + #include "mediapipe/framework/calculator.pb.h" + #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" @@ -59,7 +61,7 @@ NS_ASSUME_NONNULL_BEGIN /** * Creates a MediaPipe Task protobuf message from the MPPTaskInfo instance. 
*/ -- (mediapipe::CalculatorGraphConfig)generateGraphConfig; +- (::mediapipe::CalculatorGraphConfig)generateGraphConfig; - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm index be3c8cbf7..5f2290497 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm @@ -24,7 +24,6 @@ namespace { using CalculatorGraphConfig = ::mediapipe::CalculatorGraphConfig; using Node = ::mediapipe::CalculatorGraphConfig::Node; -using ::mediapipe::CalculatorOptions; using ::mediapipe::FlowLimiterCalculatorOptions; using ::mediapipe::InputStreamInfo; } // namespace diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h b/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h index 44fba4c0b..c03165c1d 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h @@ -13,6 +13,7 @@ // limitations under the License. #import <Foundation/Foundation.h> + #include "mediapipe/framework/calculator_options.pb.h" NS_ASSUME_NONNULL_BEGIN @@ -25,7 +26,7 @@ NS_ASSUME_NONNULL_BEGIN /** * Copies the iOS MediaPipe task options to an object of mediapipe::CalculatorOptions proto. */ -- (void)copyToProto:(mediapipe::CalculatorOptions *)optionsProto; +- (void)copyToProto:(::mediapipe::CalculatorOptions *)optionsProto; @end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h index 6561e136d..2b9f2ecdb 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h @@ -22,7 +22,6 @@ NS_ASSUME_NONNULL_BEGIN /** * This class is used to create and call appropriate methods on the C++ Task Runner. */ - @interface MPPTaskRunner : NSObject /** @@ -35,11 +34,10 @@ NS_ASSUME_NONNULL_BEGIN - (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig error:(NSError **)error NS_DESIGNATED_INITIALIZER; -- (absl::StatusOr<mediapipe::tasks::core::PacketMap>) - process:(const mediapipe::tasks::core::PacketMap &)packetMap - error:(NSError **)error; +- (absl::StatusOr<mediapipe::tasks::core::PacketMap>)process: + (const mediapipe::tasks::core::PacketMap &)packetMap; -- (void)close; +- (absl::Status)close; - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm index e08d0bc1b..c5c307fd5 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm @@ -17,7 +17,6 @@ namespace { using ::mediapipe::CalculatorGraphConfig; -using ::mediapipe::Packet; using ::mediapipe::tasks::core::PacketMap; using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; } // namespace @@ -49,8 +48,8 @@ using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; return _cppTaskRunner->Process(packetMap); } -- (void)close { - _cppTaskRunner->Close(); +- (absl::Status)close { + return _cppTaskRunner->Close(); } @end diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 7bfca680a..51573f50a 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -119,11 +119,10 @@ export class AudioClassifier extends AudioTaskRunner { * * @param options The options for the audio classifier.
*/ - override async setOptions(options: AudioClassifierOptions): Promise<void> { - await super.setOptions(options); + override setOptions(options: AudioClassifierOptions): Promise<void> { this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -171,7 +170,7 @@ export class AudioClassifier extends AudioTaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(AUDIO_STREAM); graphConfig.addInputStream(SAMPLE_RATE_STREAM); diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts index d5c0a9429..2089f184f 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts @@ -79,7 +79,8 @@ describe('AudioClassifier', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); audioClassifier = new AudioClassifierFake(); - await audioClassifier.setOptions({}); // Initialize graph + await audioClassifier.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index 246cba883..6a4b8ce39 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -121,11 +121,10 @@ export class AudioEmbedder extends AudioTaskRunner { * * @param options The options for the audio embedder. */ - override async setOptions(options: AudioEmbedderOptions): Promise<void> { - await super.setOptions(options); + override setOptions(options: AudioEmbedderOptions): Promise<void> { this.options.setEmbedderOptions(convertEmbedderOptionsToProto( options, this.options.getEmbedderOptions())); - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -171,7 +170,7 @@ export class AudioEmbedder extends AudioTaskRunner { } /** Updates the MediaPipe graph configuration.
*/ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(AUDIO_STREAM); graphConfig.addInputStream(SAMPLE_RATE_STREAM); diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts index 2f605ff98..dde61a6e9 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts @@ -70,7 +70,8 @@ describe('AudioEmbedder', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); audioEmbedder = new AudioEmbedderFake(); - await audioEmbedder.setOptions({}); // Initialize graph + await audioEmbedder.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', () => { diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD index 148a08238..cab24293d 100644 --- a/mediapipe/tasks/web/components/processors/BUILD +++ b/mediapipe/tasks/web/components/processors/BUILD @@ -103,29 +103,3 @@ jasmine_node_test( name = "embedder_options_test", deps = [":embedder_options_test_lib"], ) - -mediapipe_ts_library( - name = "base_options", - srcs = [ - "base_options.ts", - ], - deps = [ - "//mediapipe/calculators/tensor:inference_calculator_jspb_proto", - "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto", - "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", - "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", - "//mediapipe/tasks/web/core", - ], -) - -mediapipe_ts_library( - name = "base_options_test_lib", - testonly = True, - srcs = ["base_options.test.ts"], - deps = [":base_options"], -) - -jasmine_node_test( - name = "base_options_test", - deps = [":base_options_test_lib"], -) diff --git a/mediapipe/tasks/web/components/processors/base_options.test.ts b/mediapipe/tasks/web/components/processors/base_options.test.ts deleted file mode 100644 index 6d58be68f..000000000 --- a/mediapipe/tasks/web/components/processors/base_options.test.ts +++ /dev/null @@ -1,127 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -import 'jasmine'; - -// Placeholder for internal dependency on encodeByteArray -// Placeholder for internal dependency on trusted resource URL builder - -import {convertBaseOptionsToProto} from './base_options'; - -describe('convertBaseOptionsToProto()', () => { - const mockBytes = new Uint8Array([0, 1, 2, 3]); - const mockBytesResult = { - modelAsset: { - fileContent: Buffer.from(mockBytes).toString('base64'), - fileName: undefined, - fileDescriptorMeta: undefined, - filePointerMeta: undefined, - }, - useStreamMode: false, - acceleration: { - xnnpack: undefined, - gpu: undefined, - tflite: {}, - }, - }; - - let fetchSpy: jasmine.Spy; - - beforeEach(() => { - fetchSpy = jasmine.createSpy().and.callFake(async url => { - expect(url).toEqual('foo'); - return { - arrayBuffer: () => mockBytes.buffer, - } as unknown as Response; - }); - global.fetch = fetchSpy; - }); - - it('verifies that at least one model asset option is provided', async () => { - await expectAsync(convertBaseOptionsToProto({})) - .toBeRejectedWithError( - /Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/); - }); - - it('verifies that no more than one model asset option is provided', async () => { - await expectAsync(convertBaseOptionsToProto({ - modelAssetPath: `foo`, - modelAssetBuffer: new Uint8Array([]) - })) - .toBeRejectedWithError( - /Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/); - }); - - it('downloads model', async () => { - const baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetPath: `foo`, - }); - - expect(fetchSpy).toHaveBeenCalled(); - expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); - }); - - it('does not download model when bytes are provided', async () => { - const baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetBuffer: new Uint8Array(mockBytes), - }); - - expect(fetchSpy).not.toHaveBeenCalled(); - expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); - }); - - it('can enable CPU delegate', async () => { - const baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'CPU', - }); - expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); - }); - - it('can enable GPU delegate', async () => { - const baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'GPU', - }); - expect(baseOptionsProto.toObject()).toEqual({ - ...mockBytesResult, - acceleration: { - xnnpack: undefined, - gpu: { - useAdvancedGpuApi: false, - api: 0, - allowPrecisionLoss: true, - cachedKernelPath: undefined, - serializedModelDir: undefined, - modelToken: undefined, - usage: 2, - }, - tflite: undefined, - }, - }); - }); - - it('can reset delegate', async () => { - let baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'GPU', - }); - // Clear backend - baseOptionsProto = - await convertBaseOptionsToProto({delegate: undefined}, baseOptionsProto); - expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); - }); -}); diff --git a/mediapipe/tasks/web/components/processors/base_options.ts b/mediapipe/tasks/web/components/processors/base_options.ts deleted file mode 100644 index 97b62b784..000000000 --- a/mediapipe/tasks/web/components/processors/base_options.ts +++ /dev/null @@ -1,80 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb'; -import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb'; -import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; -import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb'; -import {BaseOptions} from '../../../../tasks/web/core/task_runner_options'; - -// The OSS JS API does not support the builder pattern. -// tslint:disable:jspb-use-builder-pattern - -/** - * Converts a BaseOptions API object to its Protobuf representation. - * @throws If neither a model assset path or buffer is provided - */ -export async function convertBaseOptionsToProto( - updatedOptions: BaseOptions, - currentOptions?: BaseOptionsProto): Promise<BaseOptionsProto> { - const result = - currentOptions ? currentOptions.clone() : new BaseOptionsProto(); - - await configureExternalFile(updatedOptions, result); - configureAcceleration(updatedOptions, result); - - return result; -} - -/** - * Configues the `externalFile` option and validates that a single model is - * provided. - */ -async function configureExternalFile( - options: BaseOptions, proto: BaseOptionsProto) { - const externalFile = proto.getModelAsset() || new ExternalFile(); - proto.setModelAsset(externalFile); - - if (options.modelAssetPath || options.modelAssetBuffer) { - if (options.modelAssetPath && options.modelAssetBuffer) { - throw new Error( - 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); - } - - let modelAssetBuffer = options.modelAssetBuffer; - if (!modelAssetBuffer) { - const response = await fetch(options.modelAssetPath!.toString()); - modelAssetBuffer = new Uint8Array(await response.arrayBuffer()); - } - externalFile.setFileContent(modelAssetBuffer); - } - - if (!externalFile.hasFileContent()) { - throw new Error( - 'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set'); - } -} - -/** Configues the `acceleration` option. */ -function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) { - const acceleration = proto.getAcceleration() ??
new Acceleration(); - if (options.delegate === 'GPU') { - acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu()); - } else { - acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite()); - } - proto.setAcceleration(acceleration); -} diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index 1721661f5..c0d10d28b 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -18,8 +18,10 @@ mediapipe_ts_library( srcs = ["task_runner.ts"], deps = [ ":core", + "//mediapipe/calculators/tensor:inference_calculator_jspb_proto", + "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", - "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", @@ -53,6 +55,7 @@ mediapipe_ts_library( "task_runner_test.ts", ], deps = [ + ":core", ":task_runner", ":task_runner_test_utils", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index 2011fadef..ffb538b52 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -14,9 +14,11 @@ * limitations under the License. */ +import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb'; +import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb'; import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; -import {convertBaseOptionsToProto} from '../../../tasks/web/components/processors/base_options'; -import {TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options'; +import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb'; +import {BaseOptions, TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options'; import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner'; import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib'; import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; @@ -91,14 +93,52 @@ export abstract class TaskRunner { this.graphRunner.registerModelResourcesGraphService(); } - /** Configures the shared options of a MediaPipe Task. */ - async setOptions(options: TaskRunnerOptions): Promise<void> { - if (options.baseOptions) { - this.baseOptions = await convertBaseOptionsToProto( - options.baseOptions, this.baseOptions); + /** Configures the task with custom options. */ + abstract setOptions(options: TaskRunnerOptions): Promise<void>; + + /** + * Applies the current set of options, including any base options that have + * not been processed by the task implementation. The options are applied + * synchronously unless a `modelAssetPath` is provided. This ensures that + * for most use cases options are applied directly and immediately affect + * the next inference.
+   */
+  protected applyOptions(options: TaskRunnerOptions): Promise<void> {
+    const baseOptions: BaseOptions = options.baseOptions || {};
+
+    // Validate that exactly one model is configured
+    if (options.baseOptions?.modelAssetBuffer &&
+        options.baseOptions?.modelAssetPath) {
+      throw new Error(
+          'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
+    } else if (!(this.baseOptions.getModelAsset()?.hasFileContent() ||
+                 options.baseOptions?.modelAssetBuffer ||
+                 options.baseOptions?.modelAssetPath)) {
+      throw new Error(
+          'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
+    }
+
+    this.setAcceleration(baseOptions);
+    if (baseOptions.modelAssetPath) {
+      // We don't use `await` here since we want to apply most settings
+      // synchronously.
+      return fetch(baseOptions.modelAssetPath.toString())
+          .then(response => response.arrayBuffer())
+          .then(buffer => {
+            this.setExternalFile(new Uint8Array(buffer));
+            this.refreshGraph();
+          });
+    } else {
+      // Apply the setting synchronously.
+      this.setExternalFile(baseOptions.modelAssetBuffer);
+      this.refreshGraph();
+      return Promise.resolve();
+    }
   }
 
+  /** Applies the current options to the MediaPipe graph. */
+  protected abstract refreshGraph(): void;
+
   /**
    * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run
    * over the video stream. Will replace the previously running MediaPipe graph,
@@ -140,6 +180,27 @@ export abstract class TaskRunner {
     }
     this.processingErrors = [];
   }
+
+  /** Configures the `externalFile` option. */
+  private setExternalFile(modelAssetBuffer?: Uint8Array): void {
+    const externalFile = this.baseOptions.getModelAsset() || new ExternalFile();
+    if (modelAssetBuffer) {
+      externalFile.setFileContent(modelAssetBuffer);
+    }
+    this.baseOptions.setModelAsset(externalFile);
+  }
+
+  /** Configures the `acceleration` option. */
+  private setAcceleration(options: BaseOptions) {
+    const acceleration =
+        this.baseOptions.getAcceleration() ?? new Acceleration();
+    if (options.delegate === 'GPU') {
+      acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
+    } else {
+      acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());
+    }
+    this.baseOptions.setAcceleration(acceleration);
+  }
 }
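The refactored `TaskRunner` fixes one contract for every task: `setOptions` handles task-specific settings, then delegates to `applyOptions`, which validates and loads the model, configures acceleration, and finally calls the subclass's `refreshGraph`. A minimal conforming subclass might look like this (`SketchTask` and its comments are illustrative only, not part of the patch):

    class SketchTask extends TaskRunner {
      override setOptions(options: TaskRunnerOptions): Promise<void> {
        // Task-specific options would be converted into graph options here,
        // before the shared base options (model asset, acceleration) apply.
        return this.applyOptions(options);
      }

      protected override refreshGraph(): void {
        // Rebuild the CalculatorGraphConfig and upload it via setGraph().
      }
    }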
diff --git a/mediapipe/tasks/web/core/task_runner_test.ts b/mediapipe/tasks/web/core/task_runner_test.ts
index c9aad9d25..a55ac04d7 100644
--- a/mediapipe/tasks/web/core/task_runner_test.ts
+++ b/mediapipe/tasks/web/core/task_runner_test.ts
@@ -15,18 +15,22 @@
  */
 
 import 'jasmine';
 
+// Placeholder for internal dependency on encodeByteArray
 import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
 import {TaskRunner} from '../../../tasks/web/core/task_runner';
 import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils';
 import {ErrorListener} from '../../../web/graph_runner/graph_runner';
+// Placeholder for internal dependency on trusted resource URL builder
 
 import {GraphRunnerImageLib} from './task_runner';
+import {TaskRunnerOptions} from './task_runner_options.d';
 
 class TaskRunnerFake extends TaskRunner {
-  protected baseOptions = new BaseOptionsProto();
   private errorListener: ErrorListener|undefined;
   private errors: string[] = [];
 
+  baseOptions = new BaseOptionsProto();
+
   static createFake(): TaskRunnerFake {
     const wasmModule = createSpyWasmModule();
     return new TaskRunnerFake(wasmModule);
@@ -61,10 +65,16 @@ class TaskRunnerFake extends TaskRunner {
     super.finishProcessing();
   }
 
+  override refreshGraph(): void {}
+
   override setGraph(graphData: Uint8Array, isBinary: boolean): void {
     super.setGraph(graphData, isBinary);
   }
 
+  setOptions(options: TaskRunnerOptions): Promise<void> {
+    return this.applyOptions(options);
+  }
+
   private throwErrors(): void {
     expect(this.errorListener).toBeDefined();
     for (const error of this.errors) {
@@ -75,8 +85,38 @@ }
 
 describe('TaskRunner', () => {
+  const mockBytes = new Uint8Array([0, 1, 2, 3]);
+  const mockBytesResult = {
+    modelAsset: {
+      fileContent: Buffer.from(mockBytes).toString('base64'),
+      fileName: undefined,
+      fileDescriptorMeta: undefined,
+      filePointerMeta: undefined,
+    },
+    useStreamMode: false,
+    acceleration: {
+      xnnpack: undefined,
+      gpu: undefined,
+      tflite: {},
+    },
+  };
+
+  let fetchSpy: jasmine.Spy;
+  let taskRunner: TaskRunnerFake;
+
+  beforeEach(() => {
+    fetchSpy = jasmine.createSpy().and.callFake(async url => {
+      expect(url).toEqual('foo');
+      return {
+        arrayBuffer: () => mockBytes.buffer,
+      } as unknown as Response;
+    });
+    global.fetch = fetchSpy;
+
+    taskRunner = TaskRunnerFake.createFake();
+  });
+
   it('handles errors during graph update', () => {
-    const taskRunner = TaskRunnerFake.createFake();
     taskRunner.enqueueError('Test error');
 
     expect(() => {
@@ -85,7 +125,6 @@ });
 
   it('handles errors during graph execution', () => {
-    const taskRunner = TaskRunnerFake.createFake();
     taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true);
     taskRunner.enqueueError('Test error');
 
@@ -96,7 +135,6 @@ });
 
   it('can handle multiple errors', () => {
-    const taskRunner = TaskRunnerFake.createFake();
     taskRunner.enqueueError('Test error 1');
     taskRunner.enqueueError('Test error 2');
 
@@ -104,4 +142,106 @@ }
       taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true);
     }).toThrowError(/Test error 1, Test error 2/);
   });
+
+  it('verifies that at least one model asset option is provided', () => {
+    expect(() => {
+      taskRunner.setOptions({});
+    })
+        .toThrowError(
+            /Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/);
+  });
+
+  it('verifies that no more than one model asset option is provided', () => {
+    expect(() => {
+      taskRunner.setOptions({
+        baseOptions: {
+          modelAssetPath: `foo`,
+          modelAssetBuffer: new Uint8Array([])
+        }
+      });
+    })
+        .toThrowError(
+            /Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/);
+  });
+
+  it('doesn\'t require model once it is configured', async () => {
+    await taskRunner.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
+    expect(() => {
+      taskRunner.setOptions({});
+    }).not.toThrowError();
+  });
+
+  it('downloads model', async () => {
+    await taskRunner.setOptions({baseOptions: {modelAssetPath: `foo`}});
+
+    expect(fetchSpy).toHaveBeenCalled();
+    expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
+  });
+
+  it('does not download model when bytes are provided', async () => {
+    await taskRunner.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
+
+    expect(fetchSpy).not.toHaveBeenCalled();
+    expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
+  });
+
+  it('changes model synchronously when bytes are provided', () => {
+    const resolvedPromise = taskRunner.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
+
+    // Check that the change has been applied even though we do not await the
+    // above Promise
+    expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
+    return resolvedPromise;
+  });
+
+  it('can enable CPU delegate', async () => {
+    await taskRunner.setOptions({
+      baseOptions: {
+        modelAssetBuffer: new Uint8Array(mockBytes),
+        delegate: 'CPU',
+      }
+    });
+    expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
+  });
+
+  it('can enable GPU delegate', async () => {
+    await taskRunner.setOptions({
+      baseOptions: {
+        modelAssetBuffer: new Uint8Array(mockBytes),
+        delegate: 'GPU',
+      }
+    });
+    expect(taskRunner.baseOptions.toObject()).toEqual({
+      ...mockBytesResult,
+      acceleration: {
+        xnnpack: undefined,
+        gpu: {
+          useAdvancedGpuApi: false,
+          api: 0,
+          allowPrecisionLoss: true,
+          cachedKernelPath: undefined,
+          serializedModelDir: undefined,
+          modelToken: undefined,
+          usage: 2,
+        },
+        tflite: undefined,
+      },
+    });
+  });
+
+  it('can reset delegate', async () => {
+    await taskRunner.setOptions({
+      baseOptions: {
+        modelAssetBuffer: new Uint8Array(mockBytes),
+        delegate: 'GPU',
+      }
+    });
+    // Clear backend
+    await taskRunner.setOptions({baseOptions: {delegate: undefined}});
+    expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
+  });
 });
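These tests also pin down the calling convention for applications: with `modelAssetBuffer`, the options are applied before `setOptions` returns, so awaiting is optional; with `modelAssetPath`, the returned promise resolves only after the fetch completes and must be awaited. Illustrative usage (the `task` instance, `modelBytes`, and the URL are placeholders):

    // Bytes in memory: applied synchronously; the returned promise is already
    // resolved, so the next inference can run immediately.
    task.setOptions({baseOptions: {modelAssetBuffer: modelBytes}});

    // Remote model: await the promise, since the graph is only refreshed once
    // the download finishes.
    await task.setOptions(
        {baseOptions: {modelAssetPath: 'https://example.com/model.tflite'}});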
diff --git a/mediapipe/tasks/web/core/task_runner_test_utils.ts b/mediapipe/tasks/web/core/task_runner_test_utils.ts
index 2a1161a55..838b3f585 100644
--- a/mediapipe/tasks/web/core/task_runner_test_utils.ts
+++ b/mediapipe/tasks/web/core/task_runner_test_utils.ts
@@ -44,10 +44,10 @@ export function createSpyWasmModule(): SpyWasmModule {
  * Sets up our equality testing to use a custom float equality checking function
  * to avoid incorrect test results due to minor floating point inaccuracies.
  */
-export function addJasmineCustomFloatEqualityTester() {
+export function addJasmineCustomFloatEqualityTester(tolerance = 5e-8) {
   jasmine.addCustomEqualityTester((a, b) => {  // Custom float equality
     if (a === +a && b === +b && (a !== (a | 0) || b !== (b | 0))) {
-      return Math.abs(a - b) < 5e-8;
+      return Math.abs(a - b) < tolerance;
     }
     return;
   });
 }
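With the new `tolerance` parameter the default bound stays at 5e-8, while tests comparing lower-precision outputs can pass a looser one. For example (the 1e-3 value is illustrative, not from the patch):

    addJasmineCustomFloatEqualityTester();      // strict default of 5e-8
    addJasmineCustomFloatEqualityTester(1e-3);  // looser bound, e.g. for quantized outputs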
diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts
index 62708700a..981438625 100644
--- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts
+++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts
@@ -109,11 +109,10 @@ export class TextClassifier extends TaskRunner {
    *
    * @param options The options for the text classifier.
    */
-  override async setOptions(options: TextClassifierOptions): Promise<void> {
-    await super.setOptions(options);
+  override setOptions(options: TextClassifierOptions): Promise<void> {
     this.options.setClassifierOptions(convertClassifierOptionsToProto(
         options, this.options.getClassifierOptions()));
-    this.refreshGraph();
+    return this.applyOptions(options);
   }
 
   protected override get baseOptions(): BaseOptionsProto {
@@ -141,7 +140,7 @@ export class TextClassifier extends TaskRunner {
   }
 
   /** Updates the MediaPipe graph configuration. */
-  private refreshGraph(): void {
+  protected override refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(INPUT_STREAM);
     graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts
index 841bf8c48..5578362cb 100644
--- a/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts
+++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts
@@ -56,7 +56,8 @@ describe('TextClassifier', () => {
   beforeEach(async () => {
     addJasmineCustomFloatEqualityTester();
     textClassifier = new TextClassifierFake();
-    await textClassifier.setOptions({});  // Initialize graph
+    await textClassifier.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
   });
 
   it('initializes graph', async () => {
diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts
index 611233e02..7aa0aa6b9 100644
--- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts
+++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts
@@ -113,11 +113,10 @@ export class TextEmbedder extends TaskRunner {
    *
    * @param options The options for the text embedder.
    */
-  override async setOptions(options: TextEmbedderOptions): Promise<void> {
-    await super.setOptions(options);
+  override setOptions(options: TextEmbedderOptions): Promise<void> {
     this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
         options, this.options.getEmbedderOptions()));
-    this.refreshGraph();
+    return this.applyOptions(options);
   }
 
   protected override get baseOptions(): BaseOptionsProto {
@@ -157,7 +156,7 @@ export class TextEmbedder extends TaskRunner {
   }
 
   /** Updates the MediaPipe graph configuration. */
-  private refreshGraph(): void {
+  protected override refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(INPUT_STREAM);
     graphConfig.addOutputStream(EMBEDDINGS_STREAM);
diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts
index 04a9b371a..2804e4deb 100644
--- a/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts
+++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts
@@ -56,7 +56,8 @@ describe('TextEmbedder', () => {
   beforeEach(async () => {
     addJasmineCustomFloatEqualityTester();
     textEmbedder = new TextEmbedderFake();
-    await textEmbedder.setOptions({});  // Initialize graph
+    await textEmbedder.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
   });
 
   it('initializes graph', async () => {
diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD
index e4ea3036f..03958a819 100644
--- a/mediapipe/tasks/web/vision/core/BUILD
+++ b/mediapipe/tasks/web/vision/core/BUILD
@@ -29,6 +29,7 @@ mediapipe_ts_library(
     testonly = True,
    srcs = ["vision_task_runner.test.ts"],
     deps = [
+        ":vision_task_options",
         ":vision_task_runner",
         "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
         "//mediapipe/tasks/web/core:task_runner_test_utils",
diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts
index 6cc9ea328..d77cc4fed 100644
--- a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts
+++ b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts
@@ -20,6 +20,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b
 import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils';
 import {ImageSource} from '../../../../web/graph_runner/graph_runner';
 
+import {VisionTaskOptions} from './vision_task_options';
 import {VisionTaskRunner} from './vision_task_runner';
 
 class VisionTaskRunnerFake extends VisionTaskRunner {
@@ -31,6 +32,12 @@ class VisionTaskRunnerFake extends VisionTaskRunner {
 
   protected override process(): void {}
 
+  protected override refreshGraph(): void {}
+
+  override setOptions(options: VisionTaskOptions): Promise<void> {
+    return this.applyOptions(options);
+  }
+
   override processImageData(image: ImageSource): void {
     super.processImageData(image);
   }
@@ -41,32 +48,24 @@ class VisionTaskRunnerFake extends VisionTaskRunner {
 }
 
 describe('VisionTaskRunner', () => {
-  const streamMode = {
-    modelAsset: undefined,
-    useStreamMode: true,
-    acceleration: undefined,
-  };
-
-  const imageMode = {
-    modelAsset: undefined,
-    useStreamMode: false,
-    acceleration: undefined,
-  };
-
   let visionTaskRunner: VisionTaskRunnerFake;
 
-  beforeEach(() => {
+  beforeEach(async () => {
     visionTaskRunner = new VisionTaskRunnerFake();
+    await visionTaskRunner.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
   });
 
   it('can enable image mode', async () => {
     await visionTaskRunner.setOptions({runningMode: 'image'});
-    expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode);
+    expect(visionTaskRunner.baseOptions.toObject())
+        .toEqual(jasmine.objectContaining({useStreamMode: false}));
   });
 
   it('can enable video mode', async () => {
     await visionTaskRunner.setOptions({runningMode: 'video'});
-    expect(visionTaskRunner.baseOptions.toObject()).toEqual(streamMode);
+    expect(visionTaskRunner.baseOptions.toObject())
+        .toEqual(jasmine.objectContaining({useStreamMode: true}));
   });
 
   it('can clear running mode', async () => {
@@ -74,7 +73,8 @@ describe('VisionTaskRunner', () => {
 
     // Clear running mode
     await visionTaskRunner.setOptions({runningMode: undefined});
-    expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode);
+    expect(visionTaskRunner.baseOptions.toObject())
+        .toEqual(jasmine.objectContaining({useStreamMode: false}));
   });
 
   it('cannot process images with video mode', async () => {
diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts
index 3432b521b..952990326 100644
--- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts
+++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts
@@ -22,13 +22,13 @@ import {VisionTaskOptions} from './vision_task_options';
 /** Base class for all MediaPipe Vision Tasks. */
 export abstract class VisionTaskRunner extends TaskRunner {
   /** Configures the shared options of a vision task. */
-  override async setOptions(options: VisionTaskOptions): Promise<void> {
-    await super.setOptions(options);
+  override applyOptions(options: VisionTaskOptions): Promise<void> {
     if ('runningMode' in options) {
       const useStreamMode =
           !!options.runningMode && options.runningMode !== 'image';
       this.baseOptions.setUseStreamMode(useStreamMode);
     }
+    return super.applyOptions(options);
   }
 
   /** Sends an image packet to the graph and awaits results. */
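The running-mode handling above reduces to a single rule: any mode other than 'image' enables stream mode, and an unset mode falls back to image behavior. Restated standalone (a sketch, not code from the patch):

    function toUseStreamMode(runningMode?: 'image'|'video'): boolean {
      // 'video' (or any future stream-based mode) maps to true;
      // 'image' and undefined both map to false.
      return !!runningMode && runningMode !== 'image';
    }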
diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
index b6b795076..c77f2c67a 100644
--- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
+++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
@@ -169,9 +169,7 @@ export class GestureRecognizer extends
    *
    * @param options The options for the gesture recognizer.
    */
-  override async setOptions(options: GestureRecognizerOptions): Promise<void> {
-    await super.setOptions(options);
-
+  override setOptions(options: GestureRecognizerOptions): Promise<void> {
     if ('numHands' in options) {
       this.handDetectorGraphOptions.setNumHands(
           options.numHands ?? DEFAULT_NUM_HANDS);
@@ -221,7 +219,7 @@ export class GestureRecognizer extends
           ?.clearClassifierOptions();
     }
 
-    this.refreshGraph();
+    return this.applyOptions(options);
   }
 
   /**
@@ -265,12 +263,22 @@ export class GestureRecognizer extends
         NORM_RECT_STREAM, timestamp);
     this.finishProcessing();
 
-    return {
-      gestures: this.gestures,
-      landmarks: this.landmarks,
-      worldLandmarks: this.worldLandmarks,
-      handednesses: this.handednesses
-    };
+    if (this.gestures.length === 0) {
+      // If no gestures are detected in the image, just return an empty list
+      return {
+        gestures: [],
+        landmarks: [],
+        worldLandmarks: [],
+        handednesses: [],
+      };
+    } else {
+      return {
+        gestures: this.gestures,
+        landmarks: this.landmarks,
+        worldLandmarks: this.worldLandmarks,
+        handednesses: this.handednesses
+      };
+    }
   }
 
   /** Sets the default values for the graph. */
@@ -285,15 +293,19 @@ export class GestureRecognizer extends
   }
 
   /** Converts the proto data to a Category[][] structure. */
-  private toJsCategories(data: Uint8Array[]): Category[][] {
+  private toJsCategories(data: Uint8Array[], populateIndex = true):
+      Category[][] {
     const result: Category[][] = [];
     for (const binaryProto of data) {
       const inputList = ClassificationList.deserializeBinary(binaryProto);
       const outputList: Category[] = [];
       for (const classification of inputList.getClassificationList()) {
+        const index = populateIndex && classification.hasIndex() ?
+            classification.getIndex()! :
+            DEFAULT_CATEGORY_INDEX;
         outputList.push({
           score: classification.getScore() ?? 0,
-          index: classification.getIndex() ?? DEFAULT_CATEGORY_INDEX,
+          index,
           categoryName: classification.getLabel() ?? '',
           displayName: classification.getDisplayName() ?? '',
         });
@@ -342,7 +354,7 @@ export class GestureRecognizer extends
   }
 
   /** Updates the MediaPipe graph configuration. */
-  private refreshGraph(): void {
+  protected override refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(IMAGE_STREAM);
     graphConfig.addInputStream(NORM_RECT_STREAM);
@@ -377,7 +389,10 @@ export class GestureRecognizer extends
         });
     this.graphRunner.attachProtoVectorListener(
         HAND_GESTURES_STREAM, binaryProto => {
-          this.gestures.push(...this.toJsCategories(binaryProto));
+          // Gesture index is not used, because the final gesture result comes
+          // from multiple classifiers.
+          this.gestures.push(
+              ...this.toJsCategories(binaryProto, /* populateIndex= */ false));
         });
     this.graphRunner.attachProtoVectorListener(
         HANDEDNESS_STREAM, binaryProto => {
diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts
index e570270b2..323290008 100644
--- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts
+++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts
@@ -17,6 +17,8 @@
 import {Category} from '../../../../tasks/web/components/containers/category';
 import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
 
+export {Category, Landmark, NormalizedLandmark};
+
 /**
  * Represents the gesture recognition results generated by `GestureRecognizer`.
  */
@@ -30,6 +32,10 @@ export declare interface GestureRecognizerResult {
   /** Handedness of detected hands. */
   handednesses: Category[][];
 
-  /** Recognized hand gestures of detected hands */
+  /**
+   * Recognized hand gestures of detected hands. Note that the index of the
+   * gesture is always -1, because the raw indices from multiple gesture
+   * classifiers cannot be consolidated into a meaningful index.
+   */
   gestures: Category[][];
 }
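Given the documented behavior, callers should key gesture handling off `categoryName` rather than `index`. A consumption sketch (`result` is assumed to come from `recognizer.recognize(image)`; not code from the patch):

    for (const handGestures of result.gestures) {
      for (const gesture of handGestures) {
        // gesture.index is always -1 here; categoryName carries the label.
        console.log(`${gesture.categoryName}: score ${gesture.score}`);
      }
    }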
diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts
index c0f0d1554..ee51fd32a 100644
--- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts
+++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts
@@ -109,7 +109,8 @@ describe('GestureRecognizer', () => {
   beforeEach(async () => {
     addJasmineCustomFloatEqualityTester();
     gestureRecognizer = new GestureRecognizerFake();
-    await gestureRecognizer.setOptions({});  // Initialize graph
+    await gestureRecognizer.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
   });
 
   it('initializes graph', async () => {
@@ -271,7 +272,7 @@ describe('GestureRecognizer', () => {
     expect(gestures).toEqual({
       'gestures': [[{
         'score': 0.2,
-        'index': 2,
+        'index': -1,
         'categoryName': 'gesture_label',
         'displayName': 'gesture_display_name'
       }]],
@@ -304,4 +305,25 @@ describe('GestureRecognizer', () => {
     // gestures.
     expect(gestures2).toEqual(gestures1);
   });
+
+  it('returns empty results when no gestures are detected', async () => {
+    // Pass the test data to our listener
+    gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      verifyListenersRegistered(gestureRecognizer);
+      gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks());
+      gestureRecognizer.listeners.get('world_hand_landmarks')!
+          (createWorldLandmarks());
+      gestureRecognizer.listeners.get('handedness')!(createHandednesses());
+      gestureRecognizer.listeners.get('hand_gestures')!([]);
+    });
+
+    // Invoke the gesture recognizer
+    const gestures = gestureRecognizer.recognize({} as HTMLImageElement);
+    expect(gestures).toEqual({
+      'gestures': [],
+      'landmarks': [],
+      'worldLandmarks': [],
+      'handednesses': []
+    });
+  });
 });
diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts
index 2a0e8286c..24cf9a402 100644
--- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts
+++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts
@@ -150,9 +150,7 @@ export class HandLandmarker extends VisionTaskRunner {
    *
    * @param options The options for the hand landmarker.
    */
-  override async setOptions(options: HandLandmarkerOptions): Promise<void> {
-    await super.setOptions(options);
-
+  override setOptions(options: HandLandmarkerOptions): Promise<void> {
     // Configure hand detector options.
     if ('numHands' in options) {
       this.handDetectorGraphOptions.setNumHands(
@@ -173,7 +171,7 @@ export class HandLandmarker extends VisionTaskRunner {
           options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD);
     }
 
-    this.refreshGraph();
+    return this.applyOptions(options);
   }
 
   /**
@@ -291,7 +289,7 @@ export class HandLandmarker extends VisionTaskRunner {
   }
 
   /** Updates the MediaPipe graph configuration. */
-  private refreshGraph(): void {
+  protected override refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(IMAGE_STREAM);
     graphConfig.addInputStream(NORM_RECT_STREAM);
diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts
index 89f867d69..8a6d9bfa6 100644
--- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts
+++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts
@@ -17,6 +17,8 @@
 import {Category} from '../../../../tasks/web/components/containers/category';
 import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
 
+export {Landmark, NormalizedLandmark, Category};
+
 /**
  * Represents the hand landmarks detection results generated by `HandLandmarker`.
  */
diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts
index fc26680e0..76e77b4bf 100644
--- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts
+++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts
@@ -98,7 +98,8 @@ describe('HandLandmarker', () => {
   beforeEach(async () => {
     addJasmineCustomFloatEqualityTester();
     handLandmarker = new HandLandmarkerFake();
-    await handLandmarker.setOptions({});  // Initialize graph
+    await handLandmarker.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
   });
 
   it('initializes graph', async () => {
diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts
index 36e7311fb..9298a860c 100644
--- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts
+++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts
@@ -118,11 +118,10 @@ export class ImageClassifier extends VisionTaskRunner {
    *
    * @param options The options for the image classifier.
    */
-  override async setOptions(options: ImageClassifierOptions): Promise<void> {
-    await super.setOptions(options);
+  override setOptions(options: ImageClassifierOptions): Promise<void> {
     this.options.setClassifierOptions(convertClassifierOptionsToProto(
         options, this.options.getClassifierOptions()));
-    this.refreshGraph();
+    return this.applyOptions(options);
   }
 
   /**
@@ -163,7 +162,7 @@ export class ImageClassifier extends VisionTaskRunner {
   }
 
   /** Updates the MediaPipe graph configuration. */
-  private refreshGraph(): void {
+  protected override refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(INPUT_STREAM);
     graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts
index 2041a0cef..da4a01d02 100644
--- a/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts
+++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts
@@ -61,7 +61,8 @@ describe('ImageClassifier', () => {
   beforeEach(async () => {
     addJasmineCustomFloatEqualityTester();
     imageClassifier = new ImageClassifierFake();
-    await imageClassifier.setOptions({});  // Initialize graph
+    await imageClassifier.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
   });
 
   it('initializes graph', async () => {
diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
index 0c45ba5e7..cf0bd8c5d 100644
--- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
+++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
@@ -120,11 +120,10 @@ export class ImageEmbedder extends VisionTaskRunner {
    *
    * @param options The options for the image embedder.
    */
-  override async setOptions(options: ImageEmbedderOptions): Promise<void> {
-    await super.setOptions(options);
+  override setOptions(options: ImageEmbedderOptions): Promise<void> {
     this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
         options, this.options.getEmbedderOptions()));
-    this.refreshGraph();
+    return this.applyOptions(options);
   }
 
   /**
@@ -186,7 +185,7 @@ export class ImageEmbedder extends VisionTaskRunner {
   }
 
   /** Updates the MediaPipe graph configuration. */
-  private refreshGraph(): void {
+  protected override refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(INPUT_STREAM);
     graphConfig.addOutputStream(EMBEDDINGS_STREAM);
diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts
index cafe0f3d8..b63bb374c 100644
--- a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts
+++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts
@@ -57,7 +57,8 @@ describe('ImageEmbedder', () => {
   beforeEach(async () => {
     addJasmineCustomFloatEqualityTester();
     imageEmbedder = new ImageEmbedderFake();
-    await imageEmbedder.setOptions({});  // Initialize graph
+    await imageEmbedder.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
   });
 
   it('initializes graph', async () => {
diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts
index fbfaced12..e4c51de08 100644
--- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts
+++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts
@@ -117,9 +117,7 @@ export class ObjectDetector extends VisionTaskRunner {
    *
    * @param options The options for the object detector.
    */
-  override async setOptions(options: ObjectDetectorOptions): Promise<void> {
-    await super.setOptions(options);
-
+  override setOptions(options: ObjectDetectorOptions): Promise<void> {
     // Note that we have to support both JSPB and ProtobufJS, hence we
     // have to explicitly clear the values instead of setting them to
     // `undefined`.
@@ -153,7 +151,7 @@ export class ObjectDetector extends VisionTaskRunner {
       this.options.clearCategoryDenylistList();
     }
 
-    this.refreshGraph();
+    return this.applyOptions(options);
   }
 
   /**
@@ -226,7 +224,7 @@ export class ObjectDetector extends VisionTaskRunner {
   }
 
   /** Updates the MediaPipe graph configuration. */
-  private refreshGraph(): void {
+  protected override refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(INPUT_STREAM);
     graphConfig.addOutputStream(DETECTIONS_STREAM);
diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts
index fff1a1c48..43b7035d5 100644
--- a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts
+++ b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts
@@ -61,7 +61,8 @@ describe('ObjectDetector', () => {
   beforeEach(async () => {
     addJasmineCustomFloatEqualityTester();
     objectDetector = new ObjectDetectorFake();
-    await objectDetector.setOptions({});  // Initialize graph
+    await objectDetector.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
   });
 
   it('initializes graph', async () => {