From a4d0e68bee18b11d06af266fb96eae438c4725cc Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 17 May 2023 13:34:46 -0700 Subject: [PATCH 01/20] Internal change PiperOrigin-RevId: 532890317 --- third_party/flatbuffers/BUILD.bazel | 3 +-- third_party/flatbuffers/workspace.bzl | 8 ++++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/third_party/flatbuffers/BUILD.bazel b/third_party/flatbuffers/BUILD.bazel index d5264a026..8b814f8af 100644 --- a/third_party/flatbuffers/BUILD.bazel +++ b/third_party/flatbuffers/BUILD.bazel @@ -42,16 +42,15 @@ filegroup( "include/flatbuffers/allocator.h", "include/flatbuffers/array.h", "include/flatbuffers/base.h", - "include/flatbuffers/bfbs_generator.h", "include/flatbuffers/buffer.h", "include/flatbuffers/buffer_ref.h", "include/flatbuffers/code_generator.h", "include/flatbuffers/code_generators.h", "include/flatbuffers/default_allocator.h", "include/flatbuffers/detached_buffer.h", + "include/flatbuffers/file_manager.h", "include/flatbuffers/flatbuffer_builder.h", "include/flatbuffers/flatbuffers.h", - "include/flatbuffers/flatc.h", "include/flatbuffers/flex_flat_util.h", "include/flatbuffers/flexbuffers.h", "include/flatbuffers/grpc.h", diff --git a/third_party/flatbuffers/workspace.bzl b/third_party/flatbuffers/workspace.bzl index 02247268b..0edb7a6f6 100644 --- a/third_party/flatbuffers/workspace.bzl +++ b/third_party/flatbuffers/workspace.bzl @@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive") def repo(): third_party_http_archive( name = "flatbuffers", - strip_prefix = "flatbuffers-23.1.21", - sha256 = "d84cb25686514348e615163b458ae0767001b24b42325f426fd56406fd384238", + strip_prefix = "flatbuffers-23.5.8", + sha256 = "55b75dfa5b6f6173e4abf9c35284a10482ba65db886b39db511eba6c244f1e88", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v23.1.21.tar.gz", - "https://github.com/google/flatbuffers/archive/v23.1.21.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v23.5.8.tar.gz", + "https://github.com/google/flatbuffers/archive/v23.5.8.tar.gz", ], build_file = "//third_party/flatbuffers:BUILD.bazel", delete = ["build_defs.bzl", "BUILD.bazel"], From 1fb98f5ebd006f92e4061265dbdf7ce92a6990e5 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 17 May 2023 15:47:36 -0700 Subject: [PATCH 02/20] Don't double build ARM64 arch on M1 Macs PiperOrigin-RevId: 532934646 --- setup.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 892c6dca7..b61e5c296 100644 --- a/setup.py +++ b/setup.py @@ -357,7 +357,10 @@ class BuildExtension(build_ext.build_ext): for ext in self.extensions: target_name = self.get_ext_fullpath(ext.name) # Build x86 - self._build_binary(ext) + self._build_binary( + ext, + ['--cpu=darwin', '--ios_multi_cpus=i386,x86_64,armv7,arm64'], + ) x86_name = self.get_ext_fullpath(ext.name) # Build Arm64 ext.name = ext.name + '.arm64' From 02230f65d1005203fa066259220db53c0d75b225 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 17 May 2023 15:48:18 -0700 Subject: [PATCH 03/20] Internal change PiperOrigin-RevId: 532934867 --- mediapipe/framework/BUILD | 1 + mediapipe/framework/graph_service.h | 23 ++++++++++++++++--- .../framework/graph_service_manager_test.cc | 2 +- mediapipe/framework/graph_service_test.cc | 14 +++++++---- mediapipe/framework/test_service.cc | 9 -------- mediapipe/framework/test_service.h | 12 ++++++---- 6 files changed, 40 insertions(+), 21 deletions(-) diff --git 
a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 8224b73fc..a7d9e0a63 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -1099,6 +1099,7 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", ], + alwayslink = True, # Defines TestServiceCalculator ) cc_library( diff --git a/mediapipe/framework/graph_service.h b/mediapipe/framework/graph_service.h index 51caf31f2..12b2ccb3a 100644 --- a/mediapipe/framework/graph_service.h +++ b/mediapipe/framework/graph_service.h @@ -44,7 +44,6 @@ class GraphServiceBase { constexpr GraphServiceBase(const char* key) : key(key) {} - virtual ~GraphServiceBase() = default; inline virtual absl::StatusOr<Packet> CreateDefaultObject() const { return DefaultInitializationUnsupported(); } @@ -52,14 +51,32 @@ class GraphServiceBase { const char* key; protected: + // `GraphService` objects, which derive from `GraphServiceBase`, are designed + // to be global constants and are never deleted through `GraphServiceBase`. + // Hence the protected and non-virtual destructor, which helps to keep + // `GraphService` trivially destructible and properly defined as a global + // constant. + // + // A class with any virtual functions should have a destructor that is either + // public and virtual or else protected and non-virtual. + // https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#Rc-dtor-virtual + ~GraphServiceBase() = default; + absl::Status DefaultInitializationUnsupported() const { return absl::UnimplementedError(absl::StrCat( "Graph service '", key, "' does not support default initialization")); } }; +// A global constant used to refer to a service: +// - Requesting it via `CalculatorContract::UseService` from a calculator +// - Accessing it via `Calculator/SubgraphContext::Service` from a calculator/subgraph +// - Setting it via `CalculatorGraph::SetServiceObject` before graph initialization +// +// NOTE: In headers, define your graph service reference safely as follows: +// `inline constexpr GraphService<YourService> kYourService("YourService");` +// template <typename T> -class GraphService : public GraphServiceBase { +class GraphService final : public GraphServiceBase { public: using type = T; using packet_type = std::shared_ptr<T>; @@ -68,7 +85,7 @@ kDisallowDefaultInitialization) : GraphServiceBase(my_key), default_init_(default_init) {} - absl::StatusOr<Packet> CreateDefaultObject() const override { + absl::StatusOr<Packet> CreateDefaultObject() const final { if (default_init_ != kAllowDefaultInitialization) { return DefaultInitializationUnsupported(); } diff --git a/mediapipe/framework/graph_service_manager_test.cc b/mediapipe/framework/graph_service_manager_test.cc index 1895a6f70..23d4af0df 100644 --- a/mediapipe/framework/graph_service_manager_test.cc +++ b/mediapipe/framework/graph_service_manager_test.cc @@ -7,7 +7,7 @@ namespace mediapipe { namespace { -const GraphService<int> kIntService("mediapipe::IntService"); +constexpr GraphService<int> kIntService("mediapipe::IntService"); } // namespace TEST(GraphServiceManager, SetGetServiceObject) { diff --git a/mediapipe/framework/graph_service_test.cc b/mediapipe/framework/graph_service_test.cc index 69992f212..0556aac63 100644 --- a/mediapipe/framework/graph_service_test.cc +++ b/mediapipe/framework/graph_service_test.cc @@ -14,6 +14,8 @@ #include "mediapipe/framework/graph_service.h" +#include <type_traits> + #include "mediapipe/framework/calculator_contract.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/canonical_errors.h" @@ -159,7 +161,7 @@ 
TEST_F(GraphServiceTest, CreateDefault) { struct TestServiceData {}; -const GraphService<TestServiceData> kTestServiceAllowDefaultInitialization( +constexpr GraphService<TestServiceData> kTestServiceAllowDefaultInitialization( "kTestServiceAllowDefaultInitialization", GraphServiceBase::kAllowDefaultInitialization); @@ -272,9 +274,13 @@ TEST(AllowDefaultInitializationGraphServiceTest, HasSubstr("Service is unavailable."))); } -const GraphService<TestServiceData> kTestServiceDisallowDefaultInitialization( - "kTestServiceDisallowDefaultInitialization", - GraphServiceBase::kDisallowDefaultInitialization); +constexpr GraphService<TestServiceData> + kTestServiceDisallowDefaultInitialization( + "kTestServiceDisallowDefaultInitialization", + GraphServiceBase::kDisallowDefaultInitialization); + +static_assert(std::is_trivially_destructible_v<GraphService<TestServiceData>>, + "GraphService is not trivially destructible"); class FailOnUnavailableOptionalDisallowDefaultInitServiceCalculator : public CalculatorBase { diff --git a/mediapipe/framework/test_service.cc b/mediapipe/framework/test_service.cc index 4bafaf28c..e7233ebf9 100644 --- a/mediapipe/framework/test_service.cc +++ b/mediapipe/framework/test_service.cc @@ -16,15 +16,6 @@ namespace mediapipe { -const GraphService<TestServiceObject> kTestService( - "test_service", GraphServiceBase::kDisallowDefaultInitialization); -const GraphService<std::string> kAnotherService( - "another_service", GraphServiceBase::kAllowDefaultInitialization); -const GraphService<NoDefaultConstructor> kNoDefaultService( - "no_default_service", GraphServiceBase::kAllowDefaultInitialization); -const GraphService<NeedsCreateMethod> kNeedsCreateService( - "needs_create_service", GraphServiceBase::kAllowDefaultInitialization); - absl::Status TestServiceCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); diff --git a/mediapipe/framework/test_service.h b/mediapipe/framework/test_service.h index 2ff5a384a..42ebd8df8 100644 --- a/mediapipe/framework/test_service.h +++ b/mediapipe/framework/test_service.h @@ -22,14 +22,17 @@ namespace mediapipe { using TestServiceObject = std::map<std::string, int>; -extern const GraphService<TestServiceObject> kTestService; -extern const GraphService<std::string> kAnotherService; +inline constexpr GraphService<TestServiceObject> kTestService( + "test_service", GraphServiceBase::kDisallowDefaultInitialization); +inline constexpr GraphService<std::string> kAnotherService( + "another_service", GraphServiceBase::kAllowDefaultInitialization); class NoDefaultConstructor { public: NoDefaultConstructor() = delete; }; -extern const GraphService<NoDefaultConstructor> kNoDefaultService; +inline constexpr GraphService<NoDefaultConstructor> kNoDefaultService( + "no_default_service", GraphServiceBase::kAllowDefaultInitialization); class NeedsCreateMethod { public: @@ -40,7 +43,8 @@ class NeedsCreateMethod { private: NeedsCreateMethod() = default; }; -extern const GraphService<NeedsCreateMethod> kNeedsCreateService; +inline constexpr GraphService<NeedsCreateMethod> kNeedsCreateService( + "needs_create_service", GraphServiceBase::kAllowDefaultInitialization); // Use a service. class TestServiceCalculator : public CalculatorBase { From 25458138a99132dc8444b5f54270d1b2f5eeb242 Mon Sep 17 00:00:00 2001 From: Rachel Hornung Date: Wed, 17 May 2023 17:08:38 -0700 Subject: [PATCH 04/20] #MediaPipe Add ConcatenateStringVectorCalculator. 
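A minimal usage sketch, adapted from the unit test added in this change (the two input streams and their contents are illustrative):

  CalculatorRunner runner("ConcatenateStringVectorCalculator",
                          /*options_string=*/"", /*num_inputs=*/2,
                          /*num_outputs=*/1, /*num_side_packets=*/0);
  std::vector<std::string> first = {"a", "b"};
  std::vector<std::string> second = {"c"};
  runner.MutableInputs()->Index(0).packets.push_back(
      MakePacket<std::vector<std::string>>(first).At(Timestamp(1)));
  runner.MutableInputs()->Index(1).packets.push_back(
      MakePacket<std::vector<std::string>>(second).At(Timestamp(1)));
  MP_ASSERT_OK(runner.Run());
  // The single output packet at Timestamp(1) holds {"a", "b", "c"}.
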
PiperOrigin-RevId: 532956844 --- .../core/concatenate_vector_calculator.cc | 4 +++ .../concatenate_vector_calculator_test.cc | 25 ++++++++++++++++--- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/mediapipe/calculators/core/concatenate_vector_calculator.cc b/mediapipe/calculators/core/concatenate_vector_calculator.cc index 0079aa98d..4d0d66206 100644 --- a/mediapipe/calculators/core/concatenate_vector_calculator.cc +++ b/mediapipe/calculators/core/concatenate_vector_calculator.cc @@ -55,6 +55,10 @@ MEDIAPIPE_REGISTER_NODE(ConcatenateUInt64VectorCalculator); typedef ConcatenateVectorCalculator<bool> ConcatenateBoolVectorCalculator; MEDIAPIPE_REGISTER_NODE(ConcatenateBoolVectorCalculator); +typedef ConcatenateVectorCalculator<std::string> + ConcatenateStringVectorCalculator; +MEDIAPIPE_REGISTER_NODE(ConcatenateStringVectorCalculator); + // Example config: // node { // calculator: "ConcatenateTfLiteTensorVectorCalculator" diff --git a/mediapipe/calculators/core/concatenate_vector_calculator_test.cc b/mediapipe/calculators/core/concatenate_vector_calculator_test.cc index 5510b98a3..3fccf58fd 100644 --- a/mediapipe/calculators/core/concatenate_vector_calculator_test.cc +++ b/mediapipe/calculators/core/concatenate_vector_calculator_test.cc @@ -30,13 +30,15 @@ namespace mediapipe { typedef ConcatenateVectorCalculator<int> TestConcatenateIntVectorCalculator; MEDIAPIPE_REGISTER_NODE(TestConcatenateIntVectorCalculator); -void AddInputVector(int index, const std::vector<int>& input, int64_t timestamp, +template <typename T> +void AddInputVector(int index, const std::vector<T>& input, int64_t timestamp, CalculatorRunner* runner) { runner->MutableInputs()->Index(index).packets.push_back( - MakePacket<std::vector<int>>(input).At(Timestamp(timestamp))); + MakePacket<std::vector<T>>(input).At(Timestamp(timestamp))); } -void AddInputVectors(const std::vector<std::vector<int>>& inputs, +template <typename T> +void AddInputVectors(const std::vector<std::vector<T>>& inputs, int64_t timestamp, CalculatorRunner* runner) { for (int i = 0; i < inputs.size(); ++i) { AddInputVector(i, inputs[i], timestamp, runner); } @@ -382,6 +384,23 @@ TEST(ConcatenateFloatVectorCalculatorTest, OneEmptyStreamNoOutput) { EXPECT_EQ(0, outputs.size()); } +TEST(ConcatenateStringVectorCalculatorTest, OneTimestamp) { + CalculatorRunner runner("ConcatenateStringVectorCalculator", + /*options_string=*/"", /*num_inputs=*/3, + /*num_outputs=*/1, /*num_side_packets=*/0); + + std::vector<std::vector<std::string>> inputs = { + {"a", "b"}, {"c"}, {"d", "e", "f"}}; + AddInputVectors(inputs, /*timestamp=*/1, &runner); + MP_ASSERT_OK(runner.Run()); + + const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets; + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(Timestamp(1), outputs[0].Timestamp()); + std::vector<std::string> expected_vector = {"a", "b", "c", "d", "e", "f"}; + EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<std::string>>()); +} + typedef ConcatenateVectorCalculator<std::unique_ptr<int>> TestConcatenateUniqueIntPtrCalculator; MEDIAPIPE_REGISTER_NODE(TestConcatenateUniqueIntPtrCalculator); From 03b901a443f5635210590e941e8d1d8436f68ffd Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 18 May 2023 09:14:06 -0700 Subject: [PATCH 05/20] Internal change PiperOrigin-RevId: 533150010 --- mediapipe/tasks/python/test/vision/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 9efec7f2a..ae3d53d61 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -201,6 +201,7 @@ py_test( "//mediapipe/tasks/testdata/vision:test_images", "//mediapipe/tasks/testdata/vision:test_models", ], + 
tags = ["not_run:arm"], deps = [ "//mediapipe/python:_framework_bindings", "//mediapipe/tasks/python/components/containers:rect", From a1755044ea8ef5a4161065e59e5c0dd010c105b6 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 18 May 2023 11:08:13 -0700 Subject: [PATCH 06/20] Internal change PiperOrigin-RevId: 533187060 --- mediapipe/gpu/gpu_buffer_format.cc | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mediapipe/gpu/gpu_buffer_format.cc b/mediapipe/gpu/gpu_buffer_format.cc index a820f04d6..00ee9e248 100644 --- a/mediapipe/gpu/gpu_buffer_format.cc +++ b/mediapipe/gpu/gpu_buffer_format.cc @@ -28,6 +28,12 @@ namespace mediapipe { #define GL_HALF_FLOAT 0x140B #endif // GL_HALF_FLOAT +#ifdef __EMSCRIPTEN__ +#ifndef GL_HALF_FLOAT_OES +#define GL_HALF_FLOAT_OES 0x8D61 +#endif // GL_HALF_FLOAT_OES +#endif // __EMSCRIPTEN__ + #if !MEDIAPIPE_DISABLE_GPU #ifdef GL_ES_VERSION_2_0 static void AdaptGlTextureInfoForGLES2(GlTextureInfo* info) { @@ -48,6 +54,12 @@ static void AdaptGlTextureInfoForGLES2(GlTextureInfo* info) { case GL_RG8: info->gl_internal_format = info->gl_format = GL_RG_EXT; return; +#ifdef __EMSCRIPTEN__ + case GL_RGBA16F: + info->gl_internal_format = GL_RGBA; + info->gl_type = GL_HALF_FLOAT_OES; + return; +#endif // __EMSCRIPTEN__ default: return; } From c248525eeb17da110346ec76a9865de5a20d4c4d Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 18 May 2023 11:37:13 -0700 Subject: [PATCH 07/20] internal update PiperOrigin-RevId: 533197055 --- .../tensors_to_segmentation_calculator.cc | 32 ++++++++++++++--- .../vision/image_segmenter/image_segmenter.cc | 18 ++++++++-- .../image_segmenter/image_segmenter_graph.cc | 36 ++++++++++++------- .../image_segmenter/image_segmenter_result.h | 4 +++ .../interactive_segmenter.cc | 8 ++++- .../interactive_segmenter_graph.cc | 3 ++ .../vision/imagesegmenter/ImageSegmenter.java | 14 ++++++-- .../imagesegmenter/ImageSegmenterResult.java | 15 ++++++-- .../InteractiveSegmenter.java | 13 +++++++ 9 files changed, 119 insertions(+), 24 deletions(-) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc index 660dc59b7..f77855587 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc @@ -291,8 +291,11 @@ class TensorsToSegmentationCalculator : public Node { static constexpr Output::Multiple kConfidenceMaskOut{ "CONFIDENCE_MASK"}; static constexpr Output::Optional kCategoryMaskOut{"CATEGORY_MASK"}; + static constexpr Output>::Optional kQualityScoresOut{ + "QUALITY_SCORES"}; MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut, - kConfidenceMaskOut, kCategoryMaskOut); + kConfidenceMaskOut, kCategoryMaskOut, + kQualityScoresOut); static absl::Status UpdateContract(CalculatorContract* cc); @@ -345,12 +348,33 @@ absl::Status TensorsToSegmentationCalculator::Open( absl::Status TensorsToSegmentationCalculator::Process( mediapipe::CalculatorContext* cc) { - RET_CHECK_EQ(kTensorsIn(cc).Get().size(), 1) - << "Expect a vector of single Tensor."; - const auto& input_tensor = kTensorsIn(cc).Get()[0]; + const auto& input_tensors = kTensorsIn(cc).Get(); + if (input_tensors.size() != 1 && input_tensors.size() != 2) { + return absl::InvalidArgumentError( + "Expect input tensor vector of size 1 or 2."); + } + const auto& input_tensor = 
*input_tensors.rbegin(); ASSIGN_OR_RETURN(const Shape input_shape, GetImageLikeTensorShape(input_tensor)); // TODO: should use tensor signature to get the correct output // tensor. + if (input_tensors.size() == 2) { + const auto& quality_tensor = input_tensors[0]; + const float* quality_score_buffer = + quality_tensor.GetCpuReadView().buffer<float>(); + const std::vector<float> quality_scores( + quality_score_buffer, + quality_score_buffer + + (quality_tensor.bytes() / quality_tensor.element_size())); + kQualityScoresOut(cc).Send(quality_scores); + } else { + // If the input_tensors don't contain quality scores, send the default + // quality scores as 1. + const std::vector<float> quality_scores(input_shape.channels, 1.0f); + kQualityScoresOut(cc).Send(quality_scores); + } + // Category mask does not require activation function. if (options_.segmenter_options().output_type() == SegmenterOptions::CONFIDENCE_MASK && diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index a67843258..99faa1064 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -46,6 +46,8 @@ constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageTag[] = "IMAGE"; constexpr char kNormRectStreamName[] = "norm_rect_in"; constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kQualityScoresStreamName[] = "quality_scores"; +constexpr char kQualityScoresTag[] = "QUALITY_SCORES"; constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; @@ -77,6 +79,8 @@ CalculatorGraphConfig CreateGraphConfig( task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >> graph.Out(kCategoryMaskTag); } + task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >> + graph.Out(kQualityScoresTag); task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); if (enable_flow_limiting) { @@ -172,9 +176,13 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create( category_mask = status_or_packets.value()[kCategoryMaskStreamName].Get<Image>(); } + const std::vector<float>& quality_scores = + status_or_packets.value()[kQualityScoresStreamName] + .Get<std::vector<float>>(); Packet image_packet = status_or_packets.value()[kImageOutStreamName]; result_callback( - {{confidence_masks, category_mask}}, image_packet.Get<Image>(), + {{confidence_masks, category_mask, quality_scores}}, + image_packet.Get<Image>(), image_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); }; } @@ -227,7 +235,9 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment( if (output_category_mask_) { category_mask = output_packets[kCategoryMaskStreamName].Get<Image>(); } - return {{confidence_masks, category_mask}}; + const std::vector<float>& quality_scores = + output_packets[kQualityScoresStreamName].Get<std::vector<float>>(); + return {{confidence_masks, category_mask, quality_scores}}; } absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo( @@ -260,7 +270,9 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo( if (output_category_mask_) { category_mask = output_packets[kCategoryMaskStreamName].Get<Image>(); } - return {{confidence_masks, category_mask}}; + const std::vector<float>& quality_scores = + output_packets[kQualityScoresStreamName].Get<std::vector<float>>(); + return {{confidence_masks, category_mask, quality_scores}}; } absl::Status ImageSegmenter::SegmentAsync( diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc 
b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index 6ecfa3685..0ae47ffd1 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <cstdint> #include #include #include @@ -81,6 +82,7 @@ constexpr char kImageGpuTag[] = "IMAGE_GPU"; constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kTensorsTag[] = "TENSORS"; constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; +constexpr char kQualityScoresTag[] = "QUALITY_SCORES"; constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA"; // Struct holding the different output streams produced by the image segmenter @@ -90,6 +92,7 @@ struct ImageSegmenterOutputs { std::optional<std::vector<Source<Image>>> confidence_masks; std::optional<Source<Image>> category_mask; // The same as the input image, mainly used for live stream mode. + std::optional<Source<std::vector<float>>> quality_scores; Source<Image> image; }; @@ -191,19 +194,12 @@ absl::Status ConfigureTensorsToSegmentationCalculator( "Segmentation tflite models are assumed to have a single subgraph.", MediaPipeTasksStatus::kInvalidArgumentError); } - const auto* primary_subgraph = (*model.subgraphs())[0]; - if (primary_subgraph->outputs()->size() != 1) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "Segmentation tflite models are assumed to have a single output.", - MediaPipeTasksStatus::kInvalidArgumentError); - } - ASSIGN_OR_RETURN( *options->mutable_label_items(), - GetLabelItemsIfAny(*metadata_extractor, - *metadata_extractor->GetOutputTensorMetadata()->Get(0), - segmenter_option.display_names_locale())); + GetLabelItemsIfAny( + *metadata_extractor, + **metadata_extractor->GetOutputTensorMetadata()->crbegin(), + segmenter_option.display_names_locale())); return absl::OkStatus(); } @@ -213,10 +209,16 @@ absl::StatusOr<const tflite::Tensor*> GetOutputTensor( const tflite::Model& model = *model_resources.GetTfLiteModel(); const auto* primary_subgraph = (*model.subgraphs())[0]; const auto* output_tensor = - (*primary_subgraph->tensors())[(*primary_subgraph->outputs())[0]]; + (*primary_subgraph->tensors())[*(*primary_subgraph->outputs()).rbegin()]; return output_tensor; } +uint32_t GetOutputTensorsSize(const core::ModelResources& model_resources) { + const tflite::Model& model = *model_resources.GetTfLiteModel(); + const auto* primary_subgraph = (*model.subgraphs())[0]; + return primary_subgraph->outputs()->size(); +} + // Get the input tensor from the tflite model of given model resources. 
absl::StatusOr<const tflite::Tensor*> GetInputTensor( const core::ModelResources& model_resources) { @@ -433,6 +435,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { *output_streams.category_mask >> graph[Output<Image>(kCategoryMaskTag)]; } } + if (output_streams.quality_scores) { + *output_streams.quality_scores >> + graph[Output<std::vector<float>>::Optional(kQualityScoresTag)]; + } output_streams.image >> graph[Output<Image>(kImageTag)]; return graph.GetConfig(); } @@ -530,9 +536,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i])); } } + auto quality_scores = + tensor_to_images[Output<std::vector<float>>(kQualityScoresTag)]; return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks, /*confidence_masks=*/std::nullopt, /*category_mask=*/std::nullopt, + /*quality_scores=*/quality_scores, /*image=*/image_and_tensors.image}; } else { std::optional<std::vector<Source<Image>>> confidence_masks; @@ -552,9 +561,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { if (output_category_mask_) { category_mask = tensor_to_images[Output<Image>(kCategoryMaskTag)]; } + auto quality_scores = + tensor_to_images[Output<std::vector<float>>(kQualityScoresTag)]; return ImageSegmenterOutputs{/*segmented_masks=*/std::nullopt, /*confidence_masks=*/confidence_masks, /*category_mask=*/category_mask, + /*quality_scores=*/quality_scores, /*image=*/image_and_tensors.image}; } } diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h index 1e7968ebd..a203718f4 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h @@ -33,6 +33,10 @@ struct ImageSegmenterResult { // A category mask of uint8 image in GRAY8 format where each pixel represents // the class which the pixel in the original image was predicted to belong to. std::optional<Image> category_mask; + // The quality scores of the result masks, in the range of [0, 1]. Default to + // `1` if the model doesn't output quality scores. Each element corresponds to + // the score of the category in the model outputs. 
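+ // For example, a model that outputs one confidence mask per category also + // yields one quality score per category here.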
+ std::vector<float> quality_scores; }; } // namespace image_segmenter diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc index c0d89c87d..38bbf3baf 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc @@ -51,12 +51,14 @@ constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageOutStreamName[] = "image_out"; constexpr char kRoiStreamName[] = "roi_in"; constexpr char kNormRectStreamName[] = "norm_rect_in"; +constexpr char kQualityScoresStreamName[] = "quality_scores"; constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"}; constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"}; constexpr absl::string_view kImageTag{"IMAGE"}; constexpr absl::string_view kRoiTag{"ROI"}; constexpr absl::string_view kNormRectTag{"NORM_RECT"}; +constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"}; constexpr absl::string_view kSubgraphTypeName{ "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"}; @@ -91,6 +93,8 @@ CalculatorGraphConfig CreateGraphConfig( task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >> graph.Out(kCategoryMaskTag); } + task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >> + graph.Out(kQualityScoresTag); task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); graph.In(kImageTag) >> task_subgraph.In(kImageTag); @@ -201,7 +205,9 @@ absl::StatusOr<ImageSegmenterResult> InteractiveSegmenter::Segment( if (output_category_mask_) { category_mask = output_packets[kCategoryMaskStreamName].Get<Image>(); } - return {{confidence_masks, category_mask}}; + const std::vector<float>& quality_scores = + output_packets[kQualityScoresStreamName].Get<std::vector<float>>(); + return {{confidence_masks, category_mask, quality_scores}}; } } // namespace interactive_segmenter diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc index a765997d8..5bb3e8ece 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc @@ -58,6 +58,7 @@ constexpr absl::string_view kAlphaTag{"ALPHA"}; constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"}; constexpr absl::string_view kNormRectTag{"NORM_RECT"}; constexpr absl::string_view kRoiTag{"ROI"}; +constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"}; // Updates the graph to return `roi` stream which has same dimension as // `image`, and rendered with `roi`. 
If `use_gpu` is true, returned `Source<Image>` is @@ -200,6 +201,8 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph { graph[Output<Image>(kCategoryMaskTag)]; } } + image_segmenter.Out(kQualityScoresTag) >> + graph[Output<std::vector<float>>::Optional(kQualityScoresTag)]; image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)]; return graph.GetConfig(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java index 3d6df3022..f977c0159 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java @@ -115,6 +115,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi { segmenterOptions.outputCategoryMask() ? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask") : -1; + final int qualityScoresOutStreamIndex = + getStreamIndex.apply(outputStreams, "QUALITY_SCORES:quality_scores"); final int imageOutStreamIndex = getStreamIndex.apply(outputStreams, "IMAGE:image_out"); // TODO: Consolidate OutputHandler and TaskRunner. @@ -128,6 +130,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi { return ImageSegmenterResult.create( Optional.empty(), Optional.empty(), + new ArrayList<>(), packets.get(imageOutStreamIndex).getTimestamp()); } boolean copyImage = !segmenterOptions.resultListener().isPresent(); @@ -182,9 +185,16 @@ public final class ImageSegmenter extends BaseVisionTaskApi { new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA); categoryMask = Optional.of(builder.build()); } + float[] qualityScores = + PacketGetter.getFloat32Vector(packets.get(qualityScoresOutStreamIndex)); + List<Float> qualityScoresList = new ArrayList<>(qualityScores.length); + for (float score : qualityScores) { + qualityScoresList.add(score); + } return ImageSegmenterResult.create( confidenceMasks, categoryMask, + qualityScoresList, BaseVisionTaskApi.generateResultTimestampMs( segmenterOptions.runningMode(), packets.get(imageOutStreamIndex))); } @@ -592,8 +602,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi { public abstract Builder setOutputCategoryMask(boolean value); /** - * Sets an optional {@link ResultListener} to receive the segmentation results when the graph - * pipeline is done processing an image. + * Sets an optional {@link ResultListener} to receive the segmentation results when the + * graph pipeline is done processing an image. */ public abstract Builder setResultListener( ResultListener<ImageSegmenterResult, MPImage> value); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java index cbc5211cc..e4ac85c2f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java @@ -34,19 +34,30 @@ public abstract class ImageSegmenterResult implements TaskResult { * @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a * category mask, where each pixel represents the class which the pixel in the original image * was predicted to belong to. + * @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Default to + * `1` if the model doesn't output quality scores. 
Each element corresponds to the score of + * the category in the model outputs. * @param timestampMs a timestamp for this result. */ // TODO: consolidate output formats across platforms. public static ImageSegmenterResult create( - Optional<List<MPImage>> confidenceMasks, Optional<MPImage> categoryMask, long timestampMs) { + Optional<List<MPImage>> confidenceMasks, + Optional<MPImage> categoryMask, + List<Float> qualityScores, + long timestampMs) { return new AutoValue_ImageSegmenterResult( - confidenceMasks.map(Collections::unmodifiableList), categoryMask, timestampMs); + confidenceMasks.map(Collections::unmodifiableList), + categoryMask, + Collections.unmodifiableList(qualityScores), + timestampMs); } public abstract Optional<List<MPImage>> confidenceMasks(); public abstract Optional<MPImage> categoryMask(); + public abstract List<Float> qualityScores(); + @Override public abstract long timestampMs(); } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java index e9ff1f2b5..fe0ce0c3f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java @@ -127,6 +127,10 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { outputStreams.add("CATEGORY_MASK:category_mask"); } final int categoryMaskOutStreamIndex = outputStreams.size() - 1; + + outputStreams.add("QUALITY_SCORES:quality_scores"); + final int qualityScoresOutStreamIndex = outputStreams.size() - 1; + outputStreams.add("IMAGE:image_out"); // TODO: add test for stream indices. final int imageOutStreamIndex = outputStreams.size() - 1; @@ -142,6 +146,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { return ImageSegmenterResult.create( Optional.empty(), Optional.empty(), + new ArrayList<>(), packets.get(imageOutStreamIndex).getTimestamp()); } // If resultListener is not provided, the resulted MPImage is deep copied from @@ -199,9 +204,17 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { categoryMask = Optional.of(builder.build()); } + float[] qualityScores = + PacketGetter.getFloat32Vector(packets.get(qualityScoresOutStreamIndex)); + List<Float> qualityScoresList = new ArrayList<>(qualityScores.length); + for (float score : qualityScores) { + qualityScoresList.add(score); + } + return ImageSegmenterResult.create( confidenceMasks, categoryMask, + qualityScoresList, BaseVisionTaskApi.generateResultTimestampMs( RunningMode.IMAGE, packets.get(imageOutStreamIndex))); } From 937a6b14228591e30f74142ba4f51e456224bb35 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 18 May 2023 20:17:07 -0700 Subject: [PATCH 08/20] Internal change PiperOrigin-RevId: 533327411 --- mediapipe/gpu/BUILD | 6 +++++- mediapipe/gpu/gpu_service.h | 3 ++- mediapipe/gpu/gpu_shared_data_internal.cc | 3 ++- mediapipe/tasks/cc/core/model_resources_cache.h | 4 ++-- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index c785e5624..e7c65b7c6 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -38,7 +38,10 @@ cc_library( srcs = ["gpu_service.cc"], hdrs = ["gpu_service.h"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:graph_service"] + select({ + deps = [ + "//mediapipe/framework:graph_service", + "@com_google_absl//absl/base:core_headers", + ] + select({ "//conditions:default": [ 
":gpu_shared_data_internal", ], @@ -630,6 +633,7 @@ cc_library( "//mediapipe/framework:executor", "//mediapipe/framework/deps:no_destructor", "//mediapipe/framework/port:ret_check", + "@com_google_absl//absl/base:core_headers", ] + select({ "//conditions:default": [], "//mediapipe:apple": [ diff --git a/mediapipe/gpu/gpu_service.h b/mediapipe/gpu/gpu_service.h index 65fecd0b8..dd3bd3bf5 100644 --- a/mediapipe/gpu/gpu_service.h +++ b/mediapipe/gpu/gpu_service.h @@ -15,6 +15,7 @@ #ifndef MEDIAPIPE_GPU_GPU_SERVICE_H_ #define MEDIAPIPE_GPU_GPU_SERVICE_H_ +#include "absl/base/attributes.h" #include "mediapipe/framework/graph_service.h" #if !MEDIAPIPE_DISABLE_GPU @@ -29,7 +30,7 @@ class GpuResources { }; #endif // MEDIAPIPE_DISABLE_GPU -extern const GraphService kGpuService; +ABSL_CONST_INIT extern const GraphService kGpuService; } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_shared_data_internal.cc b/mediapipe/gpu/gpu_shared_data_internal.cc index f542f0bb2..1098c82ec 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.cc +++ b/mediapipe/gpu/gpu_shared_data_internal.cc @@ -14,6 +14,7 @@ #include "mediapipe/gpu/gpu_shared_data_internal.h" +#include "absl/base/attributes.h" #include "mediapipe/framework/deps/no_destructor.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/gpu/gl_context.h" @@ -116,7 +117,7 @@ GpuResources::~GpuResources() { #endif // __APPLE__ } -extern const GraphService kGpuService; +ABSL_CONST_INIT extern const GraphService kGpuService; absl::Status GpuResources::PrepareGpuNode(CalculatorNode* node) { CHECK(node->Contract().ServiceRequests().contains(kGpuService.key)); diff --git a/mediapipe/tasks/cc/core/model_resources_cache.h b/mediapipe/tasks/cc/core/model_resources_cache.h index 75c24f344..113cfb2d4 100644 --- a/mediapipe/tasks/cc/core/model_resources_cache.h +++ b/mediapipe/tasks/cc/core/model_resources_cache.h @@ -103,8 +103,8 @@ class ModelResourcesCache { }; // Global service for mediapipe task model resources cache. -const mediapipe::GraphService kModelResourcesCacheService( - "mediapipe::tasks::ModelResourcesCacheService"); +inline constexpr mediapipe::GraphService + kModelResourcesCacheService("mediapipe::tasks::ModelResourcesCacheService"); } // namespace core } // namespace tasks From f219829b1d353d8410d7226aa594741acff9898e Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 19 May 2023 17:42:57 +0530 Subject: [PATCH 09/20] Removed support for Delegates from iOS --- .../tasks/ios/core/sources/MPPBaseOptions.h | 17 ----------------- .../tasks/ios/core/sources/MPPBaseOptions.m | 1 - .../utils/sources/MPPBaseOptions+Helpers.mm | 15 +-------------- 3 files changed, 1 insertion(+), 32 deletions(-) diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h index 603be803d..bef6bb9ee 100644 --- a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h +++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h @@ -16,17 +16,6 @@ NS_ASSUME_NONNULL_BEGIN -/** - * MediaPipe Tasks delegate. - */ -typedef NS_ENUM(NSUInteger, MPPDelegate) { - /** CPU. */ - MPPDelegateCPU, - - /** GPU. */ - MPPDelegateGPU -} NS_SWIFT_NAME(Delegate); - /** * Holds the base options that is used for creation of any type of task. It has fields with * important information acceleration configuration, TFLite model source etc. @@ -37,12 +26,6 @@ NS_SWIFT_NAME(BaseOptions) /** The path to the model asset to open and mmap in memory. 
*/ @property(nonatomic, copy) NSString *modelAssetPath; -/** - * Device delegate to run the MediaPipe pipeline. If the delegate is not set, the default - * delegate CPU is used. - */ -@property(nonatomic) MPPDelegate delegate; - @end NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m index c3571c4b4..a43119ad8 100644 --- a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m +++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m @@ -28,7 +28,6 @@ MPPBaseOptions *baseOptions = [[MPPBaseOptions alloc] init]; baseOptions.modelAssetPath = self.modelAssetPath; - baseOptions.delegate = self.delegate; return baseOptions; } diff --git a/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm index f7f8e5a55..eceed4998 100644 --- a/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm +++ b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm @@ -33,20 +33,7 @@ using BaseOptionsProto = ::mediapipe::tasks::core::proto::BaseOptions; if (self.modelAssetPath) { baseOptionsProto->mutable_model_asset()->set_file_name(self.modelAssetPath.UTF8String); } - - switch (self.delegate) { - case MPPDelegateCPU: { - baseOptionsProto->mutable_acceleration()->mutable_tflite(); - break; - } - case MPPDelegateGPU: { - // TODO: Provide an implementation for GPU Delegate. - [NSException raise:@"Invalid value for delegate" format:@"GPU Delegate is not implemented."]; - break; - } - default: - break; - } + } @end From 7c28c5d58ffbcb72043cbe8c9cc32b40aaebac41 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 19 May 2023 14:24:19 -0700 Subject: [PATCH 10/20] Fix rendering of MPMask and MPImage clone PiperOrigin-RevId: 533551170 --- mediapipe/tasks/web/vision/core/image.test.ts | 4 +++ mediapipe/tasks/web/vision/core/image.ts | 31 +++++++++++++------ mediapipe/tasks/web/vision/core/mask.test.ts | 5 ++- mediapipe/tasks/web/vision/core/mask.ts | 27 +++++++++++----- 4 files changed, 48 insertions(+), 19 deletions(-) diff --git a/mediapipe/tasks/web/vision/core/image.test.ts b/mediapipe/tasks/web/vision/core/image.test.ts index e92debc2e..3c30c7293 100644 --- a/mediapipe/tasks/web/vision/core/image.test.ts +++ b/mediapipe/tasks/web/vision/core/image.test.ts @@ -60,6 +60,10 @@ class MPImageTestContext { this.webGLTexture = gl.createTexture()!; gl.bindTexture(gl.TEXTURE_2D, this.webGLTexture); + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE); + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE); + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR); + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR); gl.texImage2D( gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, this.imageBitmap); gl.bindTexture(gl.TEXTURE_2D, null); diff --git a/mediapipe/tasks/web/vision/core/image.ts b/mediapipe/tasks/web/vision/core/image.ts index 3b067bd78..9a5d0de86 100644 --- a/mediapipe/tasks/web/vision/core/image.ts +++ b/mediapipe/tasks/web/vision/core/image.ts @@ -187,10 +187,11 @@ export class MPImage { destinationContainer = assertNotNull(gl.createTexture(), 'Failed to create texture'); gl.bindTexture(gl.TEXTURE_2D, destinationContainer); - + this.configureTextureParams(); gl.texImage2D( gl.TEXTURE_2D, 0, gl.RGBA, this.width, this.height, 0, gl.RGBA, gl.UNSIGNED_BYTE, null); + gl.bindTexture(gl.TEXTURE_2D, null); shaderContext.bindFramebuffer(gl, destinationContainer); 
shaderContext.run(gl, /* flipVertically= */ false, () => { @@ -302,6 +303,20 @@ export class MPImage { return webGLTexture; } + /** Sets texture params for the currently bound texture. */ + private configureTextureParams() { + const gl = this.getGL(); + // `gl.LINEAR` might break rendering for some textures, but it allows us to + // do smooth resizing. Ideally, this would be user-configurable, but for now + // we hard-code the value here to `gl.LINEAR` (versus `gl.NEAREST` for + // `MPMask` where we do not want to interpolate mask values, especially for + // category masks). + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE); + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE); + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR); + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR); + } + /** * Binds the backing texture to the canvas. If the texture does not yet * exist, creates it first. @@ -318,16 +333,12 @@ export class MPImage { assertNotNull(gl.createTexture(), 'Failed to create texture'); this.containers.push(webGLTexture); this.ownsWebGLTexture = true; + + gl.bindTexture(gl.TEXTURE_2D, webGLTexture); + this.configureTextureParams(); + } else { + gl.bindTexture(gl.TEXTURE_2D, webGLTexture); } - - gl.bindTexture(gl.TEXTURE_2D, webGLTexture); - // TODO: Ideally, we would only set these once per texture and - // not once every frame. - gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE); - gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE); - gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR); - gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR); - return webGLTexture; } diff --git a/mediapipe/tasks/web/vision/core/mask.test.ts b/mediapipe/tasks/web/vision/core/mask.test.ts index b632f2dc5..d2f5ddb09 100644 --- a/mediapipe/tasks/web/vision/core/mask.test.ts +++ b/mediapipe/tasks/web/vision/core/mask.test.ts @@ -60,8 +60,11 @@ class MPMaskTestContext { } this.webGLTexture = gl.createTexture()!; - gl.bindTexture(gl.TEXTURE_2D, this.webGLTexture); + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE); + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE); + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST); + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST); gl.texImage2D( gl.TEXTURE_2D, 0, gl.R32F, width, height, 0, gl.RED, gl.FLOAT, new Float32Array(pixels).map(v => v / 255)); diff --git a/mediapipe/tasks/web/vision/core/mask.ts b/mediapipe/tasks/web/vision/core/mask.ts index d7cf59e5f..3f37e804f 100644 --- a/mediapipe/tasks/web/vision/core/mask.ts +++ b/mediapipe/tasks/web/vision/core/mask.ts @@ -175,6 +175,7 @@ export class MPMask { destinationContainer = assertNotNull(gl.createTexture(), 'Failed to create texture'); gl.bindTexture(gl.TEXTURE_2D, destinationContainer); + this.configureTextureParams(); gl.texImage2D( gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED, gl.FLOAT, null); @@ -283,6 +284,19 @@ export class MPMask { return webGLTexture; } + /** Sets texture params for the currently bound texture. */ + private configureTextureParams() { + const gl = this.getGL(); + // `gl.NEAREST` ensures that we do not get interpolated values for + // masks. In some cases, the user might want interpolation (e.g. for + // confidence masks), so we might want to make this user-configurable. + // Note that `MPImage` uses `gl.LINEAR`. 
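+    // (With `gl.LINEAR`, a sample taken on the border between category IDs 1 +    // and 2 could come back as 1.5, i.e. a category that does not exist.)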
+ gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE); + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE); + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST); + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST); + } + /** * Binds the backing texture to the canvas. If the texture does not yet * exist, creates it first. @@ -299,15 +313,12 @@ export class MPMask { assertNotNull(gl.createTexture(), 'Failed to create texture'); this.containers.push(webGLTexture); this.ownsWebGLTexture = true; - } - gl.bindTexture(gl.TEXTURE_2D, webGLTexture); - // TODO: Ideally, we would only set these once per texture and - // not once every frame. - gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE); - gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE); - gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST); - gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST); + gl.bindTexture(gl.TEXTURE_2D, webGLTexture); + this.configureTextureParams(); + } else { + gl.bindTexture(gl.TEXTURE_2D, webGLTexture); + } return webGLTexture; } From 102cffdf4cc42a5403e0683424c716a89c766bd1 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 22 May 2023 05:01:36 -0700 Subject: [PATCH 11/20] Add some helpful error messages in case GL texture creation fails. PiperOrigin-RevId: 534029187 --- mediapipe/gpu/BUILD | 1 + mediapipe/gpu/gl_texture_buffer.cc | 6 +++++- mediapipe/gpu/gpu_buffer.h | 7 +++++-- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index e7c65b7c6..ee32b91e2 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -295,6 +295,7 @@ cc_library( "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/port:logging", "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", ] + select({ diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index f1497f741..4e5ce4ee4 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -47,6 +47,7 @@ std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Create(int width, int height, auto buf = absl::make_unique<GlTextureBuffer>(GL_TEXTURE_2D, 0, width, height, format, nullptr); if (!buf->CreateInternal(data, alignment)) { + LOG(WARNING) << "Failed to create a GL texture"; return nullptr; } return buf; @@ -106,7 +107,10 @@ GlTextureBuffer::GlTextureBuffer(GLenum target, GLuint name, int width, bool GlTextureBuffer::CreateInternal(const void* data, int alignment) { auto context = GlContext::GetCurrent(); - if (!context) return false; + if (!context) { + LOG(WARNING) << "Cannot create a GL texture without a valid context"; + return false; + } producer_context_ = context; // Save creation GL context. diff --git a/mediapipe/gpu/gpu_buffer.h b/mediapipe/gpu/gpu_buffer.h index b9a88aa53..93eb1460e 100644 --- a/mediapipe/gpu/gpu_buffer.h +++ b/mediapipe/gpu/gpu_buffer.h @@ -20,6 +20,7 @@ #include #include +#include "absl/log/check.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/gpu/gpu_buffer_format.h" @@ -72,8 +73,10 @@ class GpuBuffer { // are not portable. Applications and calculators should normally obtain // GpuBuffers in a portable way from the framework, e.g. using // GpuBufferMultiPool. 
- explicit GpuBuffer(std::shared_ptr storage) - : holder_(std::make_shared(std::move(storage))) {} + explicit GpuBuffer(std::shared_ptr storage) { + CHECK(storage) << "Cannot construct GpuBuffer with null storage"; + holder_ = std::make_shared(std::move(storage)); + } #if !MEDIAPIPE_DISABLE_GPU && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER // This is used to support backward-compatible construction of GpuBuffer from From 51730ec25c785c82fd2e92c48d9721627eb9acb0 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 22 May 2023 12:57:32 -0700 Subject: [PATCH 12/20] Add iOS support for MPMask PiperOrigin-RevId: 534155657 --- mediapipe/tasks/web/vision/core/BUILD | 5 +- mediapipe/tasks/web/vision/core/mask.ts | 67 ++++++++++++++++---- mediapipe/web/graph_runner/platform_utils.ts | 13 ++++ 3 files changed, 71 insertions(+), 14 deletions(-) diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index 325603353..fa28e04a5 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -62,7 +62,10 @@ jasmine_node_test( mediapipe_ts_library( name = "mask", srcs = ["mask.ts"], - deps = [":image"], + deps = [ + ":image", + "//mediapipe/web/graph_runner:platform_utils", + ], ) mediapipe_ts_library( diff --git a/mediapipe/tasks/web/vision/core/mask.ts b/mediapipe/tasks/web/vision/core/mask.ts index 3f37e804f..9622b638f 100644 --- a/mediapipe/tasks/web/vision/core/mask.ts +++ b/mediapipe/tasks/web/vision/core/mask.ts @@ -15,6 +15,7 @@ */ import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context'; +import {isIOS} from '../../../../web/graph_runner/platform_utils'; /** Number of instances a user can keep alive before we raise a warning. */ const INSTANCE_COUNT_WARNING_THRESHOLD = 250; @@ -32,6 +33,8 @@ enum MPMaskType { /** The supported mask formats. For internal usage. */ export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture; + + /** * The wrapper class for MediaPipe segmentation masks. * @@ -56,6 +59,9 @@ export class MPMask { */ private static instancesBeforeWarning = INSTANCE_COUNT_WARNING_THRESHOLD; + /** The format used to write pixel values from textures. */ + private static texImage2DFormat?: GLenum; + /** @hideconstructor */ constructor( private readonly containers: MPMaskContainer[], @@ -127,6 +133,29 @@ export class MPMask { return this.convertToWebGLTexture(); } + /** + * Returns the texture format used for writing float textures on this + * platform. + */ + getTexImage2DFormat(): GLenum { + const gl = this.getGL(); + if (!MPMask.texImage2DFormat) { + // Note: This is the same check we use in + // `SegmentationPostprocessorGl::GetSegmentationResultGpu()`. 
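+      // All three extensions are needed for renderable, filterable, and +      // blendable float32 R textures; where full float32 support is missing +      // (e.g. some WebKit builds), half float is used as the fallback.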
+ if (gl.getExtension('EXT_color_buffer_float') && + gl.getExtension('OES_texture_float_linear') && + gl.getExtension('EXT_float_blend')) { + MPMask.texImage2DFormat = gl.R32F; + } else if (gl.getExtension('EXT_color_buffer_half_float')) { + MPMask.texImage2DFormat = gl.R16F; + } else { + throw new Error( + 'GPU does not fully support 4-channel float32 or float16 formats'); + } + } + return MPMask.texImage2DFormat; + } + private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined; private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined; private getContainer(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture|undefined; @@ -176,8 +205,9 @@ export class MPMask { assertNotNull(gl.createTexture(), 'Failed to create texture'); gl.bindTexture(gl.TEXTURE_2D, destinationContainer); this.configureTextureParams(); + const format = this.getTexImage2DFormat(); gl.texImage2D( - gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED, + gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED, gl.FLOAT, null); gl.bindTexture(gl.TEXTURE_2D, null); @@ -208,7 +238,7 @@ export class MPMask { if (!this.canvas) { throw new Error( 'Conversion to different image formats require that a canvas ' + - 'is passed when iniitializing the image.'); + 'is passed when initializing the image.'); } if (!this.gl) { this.gl = assertNotNull( @@ -216,11 +246,6 @@ export class MPMask { 'You cannot use a canvas that is already bound to a different ' + 'type of rendering context.'); } - const ext = this.gl.getExtension('EXT_color_buffer_float'); - if (!ext) { - // TODO: Ensure this works on iOS - throw new Error('Missing required EXT_color_buffer_float extension'); - } return this.gl; } @@ -238,18 +263,34 @@ export class MPMask { if (uint8Array) { float32Array = new Float32Array(uint8Array).map(v => v / 255); } else { + float32Array = new Float32Array(this.width * this.height); + const gl = this.getGL(); const shaderContext = this.getShaderContext(); - float32Array = new Float32Array(this.width * this.height); // Create texture if needed const webGlTexture = this.convertToWebGLTexture(); // Create a framebuffer from the texture and read back pixels shaderContext.bindFramebuffer(gl, webGlTexture); + + if (isIOS()) { + // WebKit on iOS only supports gl.HALF_FLOAT for single channel reads + // (as tested on iOS 16.4). HALF_FLOAT requires reading data into a + // Uint16Array, however, and requires a manual bitwise conversion from + // Uint16 to floating point numbers. This conversion is more expensive + // than reading back a Float32Array from the RGBA image and dropping + // the superfluous data, so we do this instead. 
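+        // Each RGBA pixel read back yields four floats; the copy loop below +        // keeps only the R channel.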
+ const outputArray = new Float32Array(this.width * this.height * 4); + gl.readPixels( + 0, 0, this.width, this.height, gl.RGBA, gl.FLOAT, outputArray); + for (let i = 0, j = 0; i < float32Array.length; ++i, j += 4) { + float32Array[i] = outputArray[j]; + } + } else { + gl.readPixels( + 0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array); + } } this.containers.push(float32Array); } @@ -274,9 +315,9 @@ export class MPMask { webGLTexture = this.bindTexture(); const data = this.convertToFloat32Array(); - // TODO: Add support for R16F to support iOS + const format = this.getTexImage2DFormat(); gl.texImage2D( - gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED, + gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED, gl.FLOAT, data); this.unbindTexture(); } diff --git a/mediapipe/web/graph_runner/platform_utils.ts b/mediapipe/web/graph_runner/platform_utils.ts index 71239abab..7e1decf34 100644 --- a/mediapipe/web/graph_runner/platform_utils.ts +++ b/mediapipe/web/graph_runner/platform_utils.ts @@ -21,3 +21,16 @@ export function isWebKit(browser = navigator) { // it uses "CriOS". return userAgent.includes('Safari') && !userAgent.includes('Chrome'); } + +/** Detect if code is running on iOS. */ +export function isIOS() { + // Source: + // https://stackoverflow.com/questions/9038625/detect-if-device-is-ios + return [ + 'iPad Simulator', 'iPhone Simulator', 'iPod Simulator', 'iPad', 'iPhone', + 'iPod' + // tslint:disable-next-line:deprecation + ].includes(navigator.platform) + // iPad on iOS 13 detection + || (navigator.userAgent.includes('Mac') && 'ontouchend' in document); +} From 87f525c76b1b133c3caa0b033a097bb268c37b06 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 22 May 2023 19:59:19 -0700 Subject: [PATCH 13/20] Internal change PiperOrigin-RevId: 534264040 --- mediapipe/tasks/ios/BUILD | 4 +- .../object_detector/MPPObjectDetectorTests.m | 127 +++++++++--------- .../tasks/ios/vision/object_detector/BUILD | 12 +- .../sources/MPPObjectDetector.h | 17 ++- .../sources/MPPObjectDetector.mm | 34 ++--- .../sources/MPPObjectDetectorOptions.h | 6 +- ...tionResult.h => MPPObjectDetectorResult.h} | 8 +- ...tionResult.m => MPPObjectDetectorResult.m} | 4 +- .../ios/vision/object_detector/utils/BUILD | 8 +- ...rs.h => MPPObjectDetectorResult+Helpers.h} | 10 +- ....mm => MPPObjectDetectorResult+Helpers.mm} | 14 +- 11 files changed, 121 insertions(+), 123 deletions(-) rename mediapipe/tasks/ios/vision/object_detector/sources/{MPPObjectDetectionResult.h => MPPObjectDetectorResult.h} (86%) rename mediapipe/tasks/ios/vision/object_detector/sources/{MPPObjectDetectionResult.m => MPPObjectDetectorResult.m} (93%) rename mediapipe/tasks/ios/vision/object_detector/utils/sources/{MPPObjectDetectionResult+Helpers.h => MPPObjectDetectorResult+Helpers.h} (75%) rename mediapipe/tasks/ios/vision/object_detector/utils/sources/{MPPObjectDetectionResult+Helpers.mm => MPPObjectDetectorResult+Helpers.mm} (75%) diff --git a/mediapipe/tasks/ios/BUILD b/mediapipe/tasks/ios/BUILD index fb3d57ddd..d9be847f0 100644 --- a/mediapipe/tasks/ios/BUILD +++ b/mediapipe/tasks/ios/BUILD @@ -81,7 +81,7 @@ strip_api_include_path_prefix( "//mediapipe/tasks/ios/vision/image_classifier:sources/MPPImageClassifierResult.h", "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetector.h", "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorOptions.h", - "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectionResult.h", + 
"//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorResult.h", ], ) @@ -162,7 +162,7 @@ apple_static_xcframework( ":MPPImageClassifierResult.h", ":MPPObjectDetector.h", ":MPPObjectDetectorOptions.h", - ":MPPObjectDetectionResult.h", + ":MPPObjectDetectorResult.h", ], deps = [ "//mediapipe/tasks/ios/vision/image_classifier:MPPImageClassifier", diff --git a/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m b/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m index 700df65a5..164159ed6 100644 --- a/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m +++ b/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m @@ -70,7 +70,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation"; #pragma mark Results -+ (MPPObjectDetectionResult *)expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds: ++ (MPPObjectDetectorResult *)expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds: (NSInteger)timestampInMilliseconds { NSArray *detections = @[ [[MPPDetection alloc] initWithCategories:@[ @@ -95,8 +95,8 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation"; keypoints:nil], ]; - return [[MPPObjectDetectionResult alloc] initWithDetections:detections - timestampInMilliseconds:timestampInMilliseconds]; + return [[MPPObjectDetectorResult alloc] initWithDetections:detections + timestampInMilliseconds:timestampInMilliseconds]; } - (void)assertDetections:(NSArray *)detections @@ -112,25 +112,25 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation"; } } -- (void)assertObjectDetectionResult:(MPPObjectDetectionResult *)objectDetectionResult - isEqualToExpectedResult:(MPPObjectDetectionResult *)expectedObjectDetectionResult - expectedDetectionsCount:(NSInteger)expectedDetectionsCount { - XCTAssertNotNil(objectDetectionResult); +- (void)assertObjectDetectorResult:(MPPObjectDetectorResult *)objectDetectorResult + isEqualToExpectedResult:(MPPObjectDetectorResult *)expectedObjectDetectorResult + expectedDetectionsCount:(NSInteger)expectedDetectionsCount { + XCTAssertNotNil(objectDetectorResult); NSArray *detectionsSubsetToCompare; - XCTAssertEqual(objectDetectionResult.detections.count, expectedDetectionsCount); - if (objectDetectionResult.detections.count > expectedObjectDetectionResult.detections.count) { - detectionsSubsetToCompare = [objectDetectionResult.detections - subarrayWithRange:NSMakeRange(0, expectedObjectDetectionResult.detections.count)]; + XCTAssertEqual(objectDetectorResult.detections.count, expectedDetectionsCount); + if (objectDetectorResult.detections.count > expectedObjectDetectorResult.detections.count) { + detectionsSubsetToCompare = [objectDetectorResult.detections + subarrayWithRange:NSMakeRange(0, expectedObjectDetectorResult.detections.count)]; } else { - detectionsSubsetToCompare = objectDetectionResult.detections; + detectionsSubsetToCompare = objectDetectorResult.detections; } [self assertDetections:detectionsSubsetToCompare - isEqualToExpectedDetections:expectedObjectDetectionResult.detections]; + isEqualToExpectedDetections:expectedObjectDetectorResult.detections]; - XCTAssertEqual(objectDetectionResult.timestampInMilliseconds, - expectedObjectDetectionResult.timestampInMilliseconds); + XCTAssertEqual(objectDetectorResult.timestampInMilliseconds, + expectedObjectDetectorResult.timestampInMilliseconds); } #pragma mark File @@ -195,28 +195,27 @@ static NSString *const 
kLiveStreamTestsDictExpectationKey = @"expectation"; - (void)assertResultsOfDetectInImage:(MPPImage *)mppImage usingObjectDetector:(MPPObjectDetector *)objectDetector maxResults:(NSInteger)maxResults - equalsObjectDetectionResult:(MPPObjectDetectionResult *)expectedObjectDetectionResult { - MPPObjectDetectionResult *objectDetectionResult = [objectDetector detectInImage:mppImage - error:nil]; + equalsObjectDetectorResult:(MPPObjectDetectorResult *)expectedObjectDetectorResult { + MPPObjectDetectorResult *ObjectDetectorResult = [objectDetector detectInImage:mppImage error:nil]; - [self assertObjectDetectionResult:objectDetectionResult - isEqualToExpectedResult:expectedObjectDetectionResult - expectedDetectionsCount:maxResults > 0 ? maxResults - : objectDetectionResult.detections.count]; + [self assertObjectDetectorResult:ObjectDetectorResult + isEqualToExpectedResult:expectedObjectDetectorResult + expectedDetectionsCount:maxResults > 0 ? maxResults + : ObjectDetectorResult.detections.count]; } - (void)assertResultsOfDetectInImageWithFileInfo:(NSDictionary *)fileInfo usingObjectDetector:(MPPObjectDetector *)objectDetector maxResults:(NSInteger)maxResults - equalsObjectDetectionResult: - (MPPObjectDetectionResult *)expectedObjectDetectionResult { + equalsObjectDetectorResult: + (MPPObjectDetectorResult *)expectedObjectDetectorResult { MPPImage *mppImage = [self imageWithFileInfo:fileInfo]; [self assertResultsOfDetectInImage:mppImage usingObjectDetector:objectDetector maxResults:maxResults - equalsObjectDetectionResult:expectedObjectDetectionResult]; + equalsObjectDetectorResult:expectedObjectDetectorResult]; } #pragma mark General Tests @@ -266,10 +265,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation"; [self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage usingObjectDetector:objectDetector maxResults:-1 - equalsObjectDetectionResult: - [MPPObjectDetectorTests - expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds: - 0]]; + equalsObjectDetectorResult: + [MPPObjectDetectorTests + expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds: + 0]]; } - (void)testDetectWithOptionsSucceeds { @@ -280,10 +279,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation"; [self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage usingObjectDetector:objectDetector maxResults:-1 - equalsObjectDetectionResult: - [MPPObjectDetectorTests - expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds: - 0]]; + equalsObjectDetectorResult: + [MPPObjectDetectorTests + expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds: + 0]]; } - (void)testDetectWithMaxResultsSucceeds { @@ -297,10 +296,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation"; [self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage usingObjectDetector:objectDetector maxResults:maxResults - equalsObjectDetectionResult: - [MPPObjectDetectorTests - expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds: - 0]]; + equalsObjectDetectorResult: + [MPPObjectDetectorTests + expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds: + 0]]; } - (void)testDetectWithScoreThresholdSucceeds { @@ -316,13 +315,13 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation"; boundingBox:CGRectMake(608, 161, 381, 439) keypoints:nil], ]; - MPPObjectDetectionResult *expectedObjectDetectionResult = - [[MPPObjectDetectionResult alloc] initWithDetections:detections 
timestampInMilliseconds:0]; + MPPObjectDetectorResult *expectedObjectDetectorResult = + [[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0]; [self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage usingObjectDetector:objectDetector maxResults:-1 - equalsObjectDetectionResult:expectedObjectDetectionResult]; + equalsObjectDetectorResult:expectedObjectDetectorResult]; } - (void)testDetectWithCategoryAllowlistSucceeds { @@ -359,13 +358,13 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation"; keypoints:nil], ]; - MPPObjectDetectionResult *expectedDetectionResult = - [[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0]; + MPPObjectDetectorResult *expectedDetectionResult = + [[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0]; [self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage usingObjectDetector:objectDetector maxResults:-1 - equalsObjectDetectionResult:expectedDetectionResult]; + equalsObjectDetectorResult:expectedDetectionResult]; } - (void)testDetectWithCategoryDenylistSucceeds { @@ -414,13 +413,13 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation"; keypoints:nil], ]; - MPPObjectDetectionResult *expectedDetectionResult = - [[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0]; + MPPObjectDetectorResult *expectedDetectionResult = + [[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0]; [self assertResultsOfDetectInImageWithFileInfo:kCatsAndDogsImage usingObjectDetector:objectDetector maxResults:-1 - equalsObjectDetectionResult:expectedDetectionResult]; + equalsObjectDetectorResult:expectedDetectionResult]; } - (void)testDetectWithOrientationSucceeds { @@ -437,8 +436,8 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation"; keypoints:nil], ]; - MPPObjectDetectionResult *expectedDetectionResult = - [[MPPObjectDetectionResult alloc] initWithDetections:detections timestampInMilliseconds:0]; + MPPObjectDetectorResult *expectedDetectionResult = + [[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0]; MPPImage *image = [self imageWithFileInfo:kCatsAndDogsRotatedImage orientation:UIImageOrientationRight]; @@ -446,7 +445,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation"; [self assertResultsOfDetectInImage:image usingObjectDetector:objectDetector maxResults:1 - equalsObjectDetectionResult:expectedDetectionResult]; + equalsObjectDetectorResult:expectedDetectionResult]; } #pragma mark Running Mode Tests @@ -613,15 +612,15 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation"; MPPImage *image = [self imageWithFileInfo:kCatsAndDogsImage]; for (int i = 0; i < 3; i++) { - MPPObjectDetectionResult *objectDetectionResult = [objectDetector detectInVideoFrame:image - timestampInMilliseconds:i - error:nil]; + MPPObjectDetectorResult *ObjectDetectorResult = [objectDetector detectInVideoFrame:image + timestampInMilliseconds:i + error:nil]; - [self assertObjectDetectionResult:objectDetectionResult - isEqualToExpectedResult: - [MPPObjectDetectorTests - expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:i] - expectedDetectionsCount:maxResults]; + [self assertObjectDetectorResult:ObjectDetectorResult + isEqualToExpectedResult: + [MPPObjectDetectorTests + expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds:i] + 
expectedDetectionsCount:maxResults]; } } @@ -714,16 +713,16 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation"; #pragma mark MPPObjectDetectorLiveStreamDelegate Methods - (void)objectDetector:(MPPObjectDetector *)objectDetector - didFinishDetectionWithResult:(MPPObjectDetectionResult *)objectDetectionResult + didFinishDetectionWithResult:(MPPObjectDetectorResult *)ObjectDetectorResult timestampInMilliseconds:(NSInteger)timestampInMilliseconds error:(NSError *)error { NSInteger maxResults = 4; - [self assertObjectDetectionResult:objectDetectionResult - isEqualToExpectedResult: - [MPPObjectDetectorTests - expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds: - timestampInMilliseconds] - expectedDetectionsCount:maxResults]; + [self assertObjectDetectorResult:ObjectDetectorResult + isEqualToExpectedResult: + [MPPObjectDetectorTests + expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds: + timestampInMilliseconds] + expectedDetectionsCount:maxResults]; if (objectDetector == outOfOrderTimestampTestDict[kLiveStreamTestsDictObjectDetectorKey]) { [outOfOrderTimestampTestDict[kLiveStreamTestsDictExpectationKey] fulfill]; diff --git a/mediapipe/tasks/ios/vision/object_detector/BUILD b/mediapipe/tasks/ios/vision/object_detector/BUILD index 81c97c894..002a59920 100644 --- a/mediapipe/tasks/ios/vision/object_detector/BUILD +++ b/mediapipe/tasks/ios/vision/object_detector/BUILD @@ -17,9 +17,9 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) objc_library( - name = "MPPObjectDetectionResult", - srcs = ["sources/MPPObjectDetectionResult.m"], - hdrs = ["sources/MPPObjectDetectionResult.h"], + name = "MPPObjectDetectorResult", + srcs = ["sources/MPPObjectDetectorResult.m"], + hdrs = ["sources/MPPObjectDetectorResult.h"], deps = [ "//mediapipe/tasks/ios/components/containers:MPPDetection", "//mediapipe/tasks/ios/core:MPPTaskResult", @@ -31,7 +31,7 @@ objc_library( srcs = ["sources/MPPObjectDetectorOptions.m"], hdrs = ["sources/MPPObjectDetectorOptions.h"], deps = [ - ":MPPObjectDetectionResult", + ":MPPObjectDetectorResult", "//mediapipe/tasks/ios/core:MPPTaskOptions", "//mediapipe/tasks/ios/vision/core:MPPRunningMode", ], @@ -47,8 +47,8 @@ objc_library( "-x objective-c++", ], deps = [ - ":MPPObjectDetectionResult", ":MPPObjectDetectorOptions", + ":MPPObjectDetectorResult", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", "//mediapipe/tasks/ios/common/utils:NSStringHelpers", @@ -56,7 +56,7 @@ objc_library( "//mediapipe/tasks/ios/vision/core:MPPImage", "//mediapipe/tasks/ios/vision/core:MPPVisionPacketCreator", "//mediapipe/tasks/ios/vision/core:MPPVisionTaskRunner", - "//mediapipe/tasks/ios/vision/object_detector/utils:MPPObjectDetectionResultHelpers", "//mediapipe/tasks/ios/vision/object_detector/utils:MPPObjectDetectorOptionsHelpers", + "//mediapipe/tasks/ios/vision/object_detector/utils:MPPObjectDetectorResultHelpers", ], ) diff --git a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.h b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.h index ae18bf58d..d3f946bbe 100644 --- a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.h +++ b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.h @@ -15,8 +15,8 @@ #import #import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h" -#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h" 
#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.h" +#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h" NS_ASSUME_NONNULL_BEGIN @@ -109,14 +109,13 @@ NS_SWIFT_NAME(ObjectDetector) * @param error An optional error parameter populated when there is an error in performing object * detection on the input image. * - * @return An `MPPObjectDetectionResult` object that contains a list of detections, each detection + * @return An `MPPObjectDetectorResult` object that contains a list of detections, each detection * has a bounding box that is expressed in the unrotated input frame of reference coordinates * system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the underlying * image data. */ -- (nullable MPPObjectDetectionResult *)detectInImage:(MPPImage *)image - error:(NSError **)error - NS_SWIFT_NAME(detect(image:)); +- (nullable MPPObjectDetectorResult *)detectInImage:(MPPImage *)image + error:(NSError **)error NS_SWIFT_NAME(detect(image:)); /** * Performs object detection on the provided video frame of type `MPPImage` using the whole @@ -139,14 +138,14 @@ NS_SWIFT_NAME(ObjectDetector) * @param error An optional error parameter populated when there is an error in performing object * detection on the input image. * - * @return An `MPPObjectDetectionResult` object that contains a list of detections, each detection + * @return An `MPPObjectDetectorResult` object that contains a list of detections, each detection * has a bounding box that is expressed in the unrotated input frame of reference coordinates * system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the underlying * image data. */ -- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image - timestampInMilliseconds:(NSInteger)timestampInMilliseconds - error:(NSError **)error +- (nullable MPPObjectDetectorResult *)detectInVideoFrame:(MPPImage *)image + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + error:(NSError **)error NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:)); /** diff --git a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.mm b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.mm index a5b4077be..27b196d7f 100644 --- a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.mm +++ b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.mm @@ -19,8 +19,8 @@ #import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" #import "mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h" #import "mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h" -#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.h" #import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorOptions+Helpers.h" +#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorResult+Helpers.h" namespace { using ::mediapipe::NormalizedRect; @@ -118,9 +118,9 @@ static NSString *const kTaskName = @"objectDetector"; return; } - MPPObjectDetectionResult *result = [MPPObjectDetectionResult - objectDetectionResultWithDetectionsPacket:statusOrPackets.value()[kDetectionsStreamName - .cppString]]; + MPPObjectDetectorResult *result = [MPPObjectDetectorResult + objectDetectorResultWithDetectionsPacket:statusOrPackets + .value()[kDetectionsStreamName.cppString]]; NSInteger timeStampInMilliseconds = 
outputPacketMap[kImageOutStreamName.cppString].Timestamp().Value() /
@@ -184,9 +184,9 @@ static NSString *const kTaskName = @"objectDetector";
 return inputPacketMap;
 }
-- (nullable MPPObjectDetectionResult *)detectInImage:(MPPImage *)image
- regionOfInterest:(CGRect)roi
- error:(NSError **)error {
+- (nullable MPPObjectDetectorResult *)detectInImage:(MPPImage *)image
+ regionOfInterest:(CGRect)roi
+ error:(NSError **)error {
 std::optional<NormalizedRect> rect =
 [_visionTaskRunner normalizedRectFromRegionOfInterest:roi
 imageSize:CGSizeMake(image.width, image.height)
@@ -213,18 +213,18 @@ static NSString *const kTaskName = @"objectDetector";
 return nil;
 }
- return [MPPObjectDetectionResult
- objectDetectionResultWithDetectionsPacket:outputPacketMap
- .value()[kDetectionsStreamName.cppString]];
+ return [MPPObjectDetectorResult
+ objectDetectorResultWithDetectionsPacket:outputPacketMap
+ .value()[kDetectionsStreamName.cppString]];
 }
-- (nullable MPPObjectDetectionResult *)detectInImage:(MPPImage *)image error:(NSError **)error {
+- (nullable MPPObjectDetectorResult *)detectInImage:(MPPImage *)image error:(NSError **)error {
 return [self detectInImage:image regionOfInterest:CGRectZero error:error];
 }
-- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
- timestampInMilliseconds:(NSInteger)timestampInMilliseconds
- error:(NSError **)error {
+- (nullable MPPObjectDetectorResult *)detectInVideoFrame:(MPPImage *)image
+ timestampInMilliseconds:(NSInteger)timestampInMilliseconds
+ error:(NSError **)error {
 std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
 timestampInMilliseconds:timestampInMilliseconds
 error:error];
@@ -239,9 +239,9 @@ static NSString *const kTaskName = @"objectDetector";
 return nil;
 }
- return [MPPObjectDetectionResult
- objectDetectionResultWithDetectionsPacket:outputPacketMap
- .value()[kDetectionsStreamName.cppString]];
+ return [MPPObjectDetectorResult
+ objectDetectorResultWithDetectionsPacket:outputPacketMap
+ .value()[kDetectionsStreamName.cppString]];
 }
 - (BOOL)detectAsyncInImage:(MPPImage *)image
diff --git a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.h b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.h
index 182714c03..33d7bdbbb 100644
--- a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.h
+++ b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.h
@@ -16,7 +16,7 @@
 #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
 #import "mediapipe/tasks/ios/vision/core/sources/MPPRunningMode.h"
-#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h"
+#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"
 NS_ASSUME_NONNULL_BEGIN
@@ -44,7 +44,7 @@ NS_SWIFT_NAME(ObjectDetectorLiveStreamDelegate)
 *
 * @param objectDetector The object detector which performed the object detection.
 * This is useful to test equality when there are multiple instances of `MPPObjectDetector`.
- * @param result The `MPPObjectDetectionResult` object that contains a list of detections, each
+ * @param result The `MPPObjectDetectorResult` object that contains a list of detections, each
 * detection has a bounding box that is expressed in the unrotated input frame of reference
 * coordinates system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the
 * underlying image data.
@@ -54,7 +54,7 @@ NS_SWIFT_NAME(ObjectDetectorLiveStreamDelegate) * detection on the input live stream image data. */ - (void)objectDetector:(MPPObjectDetector *)objectDetector - didFinishDetectionWithResult:(nullable MPPObjectDetectionResult *)result + didFinishDetectionWithResult:(nullable MPPObjectDetectorResult *)result timestampInMilliseconds:(NSInteger)timestampInMilliseconds error:(nullable NSError *)error NS_SWIFT_NAME(objectDetector(_:didFinishDetection:timestampInMilliseconds:error:)); diff --git a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h similarity index 86% rename from mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h rename to mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h index da9899d40..2641b6b4e 100644 --- a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h +++ b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h @@ -19,8 +19,8 @@ NS_ASSUME_NONNULL_BEGIN /** Represents the detection results generated by `MPPObjectDetector`. */ -NS_SWIFT_NAME(ObjectDetectionResult) -@interface MPPObjectDetectionResult : MPPTaskResult +NS_SWIFT_NAME(ObjectDetectorResult) +@interface MPPObjectDetectorResult : MPPTaskResult /** * The array of `MPPDetection` objects each of which has a bounding box that is expressed in the @@ -30,7 +30,7 @@ NS_SWIFT_NAME(ObjectDetectionResult) @property(nonatomic, readonly) NSArray *detections; /** - * Initializes a new `MPPObjectDetectionResult` with the given array of detections and timestamp (in + * Initializes a new `MPPObjectDetectorResult` with the given array of detections and timestamp (in * milliseconds). * * @param detections An array of `MPPDetection` objects each of which has a bounding box that is @@ -38,7 +38,7 @@ NS_SWIFT_NAME(ObjectDetectionResult) * x [0,image_height)`, which are the dimensions of the underlying image data. * @param timestampInMilliseconds The timestamp (in milliseconds) for this result. * - * @return An instance of `MPPObjectDetectionResult` initialized with the given array of detections + * @return An instance of `MPPObjectDetectorResult` initialized with the given array of detections * and timestamp (in milliseconds). */ - (instancetype)initWithDetections:(NSArray *)detections diff --git a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.m b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.m similarity index 93% rename from mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.m rename to mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.m index 47902bba4..568fbcff7 100644 --- a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.m +++ b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.m @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h"
+#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"
-@implementation MPPObjectDetectionResult
+@implementation MPPObjectDetectorResult
 - (instancetype)initWithDetections:(NSArray *)detections
 timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
diff --git a/mediapipe/tasks/ios/vision/object_detector/utils/BUILD b/mediapipe/tasks/ios/vision/object_detector/utils/BUILD
index d4597b239..63ddbd0b8 100644
--- a/mediapipe/tasks/ios/vision/object_detector/utils/BUILD
+++ b/mediapipe/tasks/ios/vision/object_detector/utils/BUILD
@@ -31,12 +31,12 @@ objc_library(
 )
 objc_library(
- name = "MPPObjectDetectionResultHelpers",
- srcs = ["sources/MPPObjectDetectionResult+Helpers.mm"],
- hdrs = ["sources/MPPObjectDetectionResult+Helpers.h"],
+ name = "MPPObjectDetectorResultHelpers",
+ srcs = ["sources/MPPObjectDetectorResult+Helpers.mm"],
+ hdrs = ["sources/MPPObjectDetectorResult+Helpers.h"],
 deps = [
 "//mediapipe/framework:packet",
 "//mediapipe/tasks/ios/components/containers/utils:MPPDetectionHelpers",
- "//mediapipe/tasks/ios/vision/object_detector:MPPObjectDetectionResult",
+ "//mediapipe/tasks/ios/vision/object_detector:MPPObjectDetectorResult",
 ],
 )
diff --git a/mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.h b/mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorResult+Helpers.h
similarity index 75%
rename from mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.h
rename to mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorResult+Helpers.h
index 645f5c565..377e6e323 100644
--- a/mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.h
+++ b/mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorResult+Helpers.h
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h"
+#import "mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorResult.h"
 #include "mediapipe/framework/packet.h"
@@ -20,17 +20,17 @@ NS_ASSUME_NONNULL_BEGIN
 static const int kMicroSecondsPerMilliSecond = 1000;
-@interface MPPObjectDetectionResult (Helpers)
+@interface MPPObjectDetectorResult (Helpers)
 /**
- * Creates an `MPPObjectDetectionResult` from a MediaPipe packet containing a
+ * Creates an `MPPObjectDetectorResult` from a MediaPipe packet containing a
 * `std::vector<Detection>`.
 *
 * @param packet a MediaPipe packet wrapping a `std::vector<Detection>`.
 *
- * @return An `MPPObjectDetectionResult` object that contains a list of detections.
+ * @return An `MPPObjectDetectorResult` object that contains a list of detections.
 */
-+ (nullable MPPObjectDetectionResult *)objectDetectionResultWithDetectionsPacket:
++ (nullable MPPObjectDetectorResult *)objectDetectorResultWithDetectionsPacket:
 (const mediapipe::Packet &)packet;
 @end
diff --git a/mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.mm b/mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorResult+Helpers.mm
similarity index 75%
rename from mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.mm
rename to mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorResult+Helpers.mm
index 225a6993d..b2f9cfc08 100644
--- a/mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.mm
+++ b/mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorResult+Helpers.mm
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.h"
+#import "mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectorResult+Helpers.h"
 #import "mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.h"
@@ -21,9 +21,9 @@ using DetectionProto = ::mediapipe::Detection;
 using ::mediapipe::Packet;
 } // namespace
-@implementation MPPObjectDetectionResult (Helpers)
+@implementation MPPObjectDetectorResult (Helpers)
-+ (nullable MPPObjectDetectionResult *)objectDetectionResultWithDetectionsPacket:
++ (nullable MPPObjectDetectorResult *)objectDetectorResultWithDetectionsPacket:
 (const Packet &)packet {
 if (!packet.ValidateAsType<std::vector<DetectionProto>>().ok()) {
 return nil;
@@ -37,10 +37,10 @@ using ::mediapipe::Packet;
 [detections addObject:[MPPDetection detectionWithProto:detectionProto]];
 }
- return [[MPPObjectDetectionResult alloc]
- initWithDetections:detections
- timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
- kMicroSecondsPerMilliSecond)];
+ return
+ [[MPPObjectDetectorResult alloc] initWithDetections:detections
+ timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
+ kMicroSecondsPerMilliSecond)];
 }
 @end
From fda001b666c8f46aa5072a8d42ecdcee77fec1a7 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Tue, 23 May 2023 19:31:14 +0530
Subject: [PATCH 14/20] Updated error tests to use XCTAssertEqualObjects
---
 .../ios/test/text/text_classifier/MPPTextClassifierTests.m | 4 +---
 .../ios/test/text/text_embedder/MPPTextEmbedderTests.m | 4 +---
 .../test/vision/image_classifier/MPPImageClassifierTests.m | 6 ++----
 .../test/vision/object_detector/MPPObjectDetectorTests.m | 4 +---
 4 files changed, 5 insertions(+), 13 deletions(-)
diff --git a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m
index a80fb8824..2ffccd2b7 100644
--- a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m
+++ b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m
@@ -28,9 +28,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
 XCTAssertNotNil(error); \
 XCTAssertEqualObjects(error.domain, expectedError.domain); \
 XCTAssertEqual(error.code, expectedError.code); \
- XCTAssertNotEqual( \
- [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
- NSNotFound)
+ XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
#define
AssertEqualCategoryArrays(categories, expectedCategories) \ XCTAssertEqual(categories.count, expectedCategories.count); \ diff --git a/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m b/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m index b32bb076c..51d37667c 100644 --- a/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m +++ b/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m @@ -29,9 +29,7 @@ static const float kSimilarityDiffTolerance = 1e-4; XCTAssertNotNil(error); \ XCTAssertEqualObjects(error.domain, expectedError.domain); \ XCTAssertEqual(error.code, expectedError.code); \ - XCTAssertNotEqual( \ - [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \ - NSNotFound) + XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription) \ #define AssertTextEmbedderResultHasOneEmbedding(textEmbedderResult) \ XCTAssertNotNil(textEmbedderResult); \ diff --git a/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m b/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m index 8db71a11b..62cfb1f84 100644 --- a/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m +++ b/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m @@ -34,10 +34,8 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation"; XCTAssertNotNil(error); \ XCTAssertEqualObjects(error.domain, expectedError.domain); \ XCTAssertEqual(error.code, expectedError.code); \ - XCTAssertNotEqual( \ - [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \ - NSNotFound) - + XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription) \ + #define AssertEqualCategoryArrays(categories, expectedCategories) \ XCTAssertEqual(categories.count, expectedCategories.count); \ for (int i = 0; i < categories.count; i++) { \ diff --git a/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m b/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m index 164159ed6..c438a789a 100644 --- a/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m +++ b/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m @@ -32,9 +32,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation"; XCTAssertNotNil(error); \ XCTAssertEqualObjects(error.domain, expectedError.domain); \ XCTAssertEqual(error.code, expectedError.code); \ - XCTAssertNotEqual( \ - [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \ - NSNotFound) + XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription) #define AssertEqualCategories(category, expectedCategory, detectionIndex, categoryIndex) \ XCTAssertEqual(category.index, expectedCategory.index, \ From 3eb97ae1ff13660ef4a1110378075f3b7268b26b Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Tue, 23 May 2023 20:54:42 +0530 Subject: [PATCH 15/20] Updated Image classifier result to return empty results if packet can't be validated. 
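A packet that fails validation no longer yields a nil result: the helper can always read the packet's timestamp, even when the payload is of the wrong type, so it now returns an `MPPImageClassifierResult` with an empty classification list that still carries that timestamp. A minimal C++ sketch of the control flow, with illustrative stand-in types rather than the MediaPipe API:

    #include <cstdint>
    #include <iostream>
    #include <optional>
    #include <vector>

    // Stand-ins for a packet and a task result; only the shape of the
    // control flow mirrors this change.
    struct FakePacket {
      std::optional<std::vector<int>> payload;  // nullopt models failed validation
      int64_t timestamp_us = 0;
    };

    struct FakeResult {
      std::vector<int> classifications;  // stays empty when validation fails
      int64_t timestamp_ms = 0;
    };

    FakeResult MakeResult(const FakePacket& packet) {
      constexpr int64_t kMicrosPerMilli = 1000;
      FakeResult result;
      // The timestamp does not depend on the payload, so it is always kept.
      result.timestamp_ms = packet.timestamp_us / kMicrosPerMilli;
      if (packet.payload) {
        result.classifications = *packet.payload;
      }
      return result;
    }

    int main() {
      FakePacket invalid{std::nullopt, 42000};
      FakeResult r = MakeResult(invalid);
      // Prints "0 42": an empty result that still knows its timestamp.
      std::cout << r.classifications.size() << " " << r.timestamp_ms << "\n";
    }

Callers can then rely on every result carrying a timestamp; an empty classification list signals that no valid payload arrived.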
---
 .../sources/MPPImageClassifierResult+Helpers.mm | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)
diff --git a/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.mm b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.mm
index 61ae785d1..e9f74a0cd 100644
--- a/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.mm
+++ b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.mm
@@ -29,13 +29,20 @@ using ::mediapipe::Packet;
 + (nullable MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket:
 (const Packet &)packet {
- MPPClassificationResult *classificationResult;
+ // Even if the packet does not validate as the expected type, you can safely access the timestamp.
+ NSInteger timestampInMilliSeconds =
+ (NSInteger)(packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
 if (!packet.ValidateAsType<ClassificationResultProto>().ok()) {
- return nil;
+ // MPPClassificationResult's timestamp is populated from `ClassificationResultProto`'s
+ // timestamp_ms(). It is 0 since the packet can't be validated as a `ClassificationResultProto`.
+ return [[MPPImageClassifierResult alloc]
+ initWithClassificationResult:[[MPPClassificationResult alloc] initWithClassifications:@[]
+ timestampInMilliseconds:0]
+ timestampInMilliseconds:timestampInMilliSeconds];
 }
- classificationResult = [MPPClassificationResult
+ MPPClassificationResult *classificationResult = [MPPClassificationResult
 classificationResultWithProto:packet.Get<ClassificationResultProto>()];
 return [[MPPImageClassifierResult alloc]
From 1fe78180c86ef400c2f9b1af70ccb63a0175cee2 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Tue, 23 May 2023 11:33:56 -0700
Subject: [PATCH 16/20] Add quality scores to Segmenter tasks
PiperOrigin-RevId: 534497957
---
 .../image_segmenter/image_segmenter_result.h | 2 +-
 .../imagesegmenter/ImageSegmenterResult.java | 4 ++--
 .../vision/image_segmenter/image_segmenter.ts | 21 +++++++++++++++++--
 .../image_segmenter/image_segmenter_result.ts | 8 ++++++-
 .../image_segmenter/image_segmenter_test.ts | 16 +++++++++++++-
 .../interactive_segmenter.ts | 19 ++++++++++++++++-
 .../interactive_segmenter_result.ts | 8 ++++++-
 .../interactive_segmenter_test.ts | 18 ++++++++++++++--
 8 files changed, 85 insertions(+), 11 deletions(-)
diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h
index a203718f4..7f159cc39 100644
--- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h
+++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h
@@ -33,7 +33,7 @@ struct ImageSegmenterResult {
 // A category mask of uint8 image in GRAY8 format where each pixel represents
 // the class which the pixel in the original image was predicted to belong to.
 std::optional<Image> category_mask;
- // The quality scores of the result masks, in the range of [0, 1]. Default to
+ // The quality scores of the result masks, in the range of [0, 1]. Defaults to
 // `1` if the model doesn't output quality scores. Each element corresponds to
 // the score of the category in the model outputs.
 std::vector<float> quality_scores;
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java
index e4ac85c2f..7f567c1a4 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java
@@ -34,8 +34,8 @@ public abstract class ImageSegmenterResult implements TaskResult {
 * @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a
 * category mask, where each pixel represents the class which the pixel in the original image
 * was predicted to belong to.
- * @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Default to
- * `1` if the model doesn't output quality scores. Each element corresponds to the score of
+ * @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Defaults
+ * to `1` if the model doesn't output quality scores. Each element corresponds to the score of
 * the category in the model outputs.
 * @param timestampMs a timestamp for this result.
 */
diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts
index ee9caaa1f..3dd2d03ef 100644
--- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts
+++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts
@@ -39,6 +39,7 @@ const IMAGE_STREAM = 'image_in';
 const NORM_RECT_STREAM = 'norm_rect';
 const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
 const CATEGORY_MASK_STREAM = 'category_mask';
+const QUALITY_SCORES_STREAM = 'quality_scores';
 const IMAGE_SEGMENTER_GRAPH =
 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
 const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
@@ -61,6 +62,7 @@ export type ImageSegmenterCallback = (result: ImageSegmenterResult) => void;
 export class ImageSegmenter extends VisionTaskRunner {
 private categoryMask?: MPMask;
 private confidenceMasks?: MPMask[];
+ private qualityScores?: number[];
 private labels: string[] = [];
 private userCallback?: ImageSegmenterCallback;
 private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
@@ -367,12 +369,13 @@ export class ImageSegmenter extends VisionTaskRunner {
 private reset(): void {
 this.categoryMask = undefined;
 this.confidenceMasks = undefined;
+ this.qualityScores = undefined;
 }
 private processResults(): ImageSegmenterResult|void {
 try {
- const result =
- new ImageSegmenterResult(this.confidenceMasks, this.categoryMask);
+ const result = new ImageSegmenterResult(
+ this.confidenceMasks, this.categoryMask, this.qualityScores);
 if (this.userCallback) {
 this.userCallback(result);
 } else {
@@ -442,6 +445,20 @@ export class ImageSegmenter extends VisionTaskRunner {
 });
 }
+ graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
+ segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);
+
+ this.graphRunner.attachFloatVectorListener(
+ QUALITY_SCORES_STREAM, (scores, timestamp) => {
+ this.qualityScores = scores;
+ this.setLatestOutputTimestamp(timestamp);
+ });
+ this.graphRunner.attachEmptyPacketListener(
+ QUALITY_SCORES_STREAM, timestamp => {
+ this.categoryMask = undefined;
+ this.setLatestOutputTimestamp(timestamp);
+ });
+
 const binaryGraph = graphConfig.serializeBinary();
 this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
 }
diff --git 
a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.ts index 9107a5c80..363cc213d 100644 --- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.ts +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.ts @@ -30,7 +30,13 @@ export class ImageSegmenterResult { * `WebGLTexture`-backed `MPImage` where each pixel represents the class * which the pixel in the original image was predicted to belong to. */ - readonly categoryMask?: MPMask) {} + readonly categoryMask?: MPMask, + /** + * The quality scores of the result masks, in the range of [0, 1]. + * Defaults to `1` if the model doesn't output quality scores. Each + * element corresponds to the score of the category in the model outputs. + */ + readonly qualityScores?: number[]) {} /** Frees the resources held by the category and confidence masks. */ close(): void { diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts index 10983b488..c245001b2 100644 --- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts @@ -35,6 +35,8 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake { ((images: WasmImage, timestamp: number) => void)|undefined; confidenceMasksListener: ((images: WasmImage[], timestamp: number) => void)|undefined; + qualityScoresListener: + ((data: number[], timestamp: number) => void)|undefined; constructor() { super(createSpyWasmModule(), /* glCanvas= */ null); @@ -52,6 +54,12 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake { expect(stream).toEqual('confidence_masks'); this.confidenceMasksListener = listener; }); + this.attachListenerSpies[2] = + spyOn(this.graphRunner, 'attachFloatVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('quality_scores'); + this.qualityScoresListener = listener; + }); spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); }); @@ -266,6 +274,7 @@ describe('ImageSegmenter', () => { it('invokes listener after masks are available', async () => { const categoryMask = new Uint8Array([1]); const confidenceMask = new Float32Array([0.0]); + const qualityScores = [1.0]; let listenerCalled = false; await imageSegmenter.setOptions( @@ -283,11 +292,16 @@ describe('ImageSegmenter', () => { ], 1337); expect(listenerCalled).toBeFalse(); + imageSegmenter.qualityScoresListener!(qualityScores, 1337); + expect(listenerCalled).toBeFalse(); }); return new Promise(resolve => { - imageSegmenter.segment({} as HTMLImageElement, () => { + imageSegmenter.segment({} as HTMLImageElement, result => { listenerCalled = true; + expect(result.categoryMask).toBeInstanceOf(MPMask); + expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask); + expect(result.qualityScores).toEqual(qualityScores); resolve(); }); }); diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts index 16bf10eeb..662eaf09a 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts @@ -42,6 +42,7 @@ const NORM_RECT_IN_STREAM = 'norm_rect_in'; const ROI_IN_STREAM = 'roi_in'; const 
CONFIDENCE_MASKS_STREAM = 'confidence_masks'; const CATEGORY_MASK_STREAM = 'category_mask'; +const QUALITY_SCORES_STREAM = 'quality_scores'; const IMAGEA_SEGMENTER_GRAPH = 'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph'; const DEFAULT_OUTPUT_CATEGORY_MASK = false; @@ -86,6 +87,7 @@ export type InteractiveSegmenterCallback = export class InteractiveSegmenter extends VisionTaskRunner { private categoryMask?: MPMask; private confidenceMasks?: MPMask[]; + private qualityScores?: number[]; private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK; private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS; private userCallback?: InteractiveSegmenterCallback; @@ -284,12 +286,13 @@ export class InteractiveSegmenter extends VisionTaskRunner { private reset(): void { this.confidenceMasks = undefined; this.categoryMask = undefined; + this.qualityScores = undefined; } private processResults(): InteractiveSegmenterResult|void { try { const result = new InteractiveSegmenterResult( - this.confidenceMasks, this.categoryMask); + this.confidenceMasks, this.categoryMask, this.qualityScores); if (this.userCallback) { this.userCallback(result); } else { @@ -361,6 +364,20 @@ export class InteractiveSegmenter extends VisionTaskRunner { }); } + graphConfig.addOutputStream(QUALITY_SCORES_STREAM); + segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM); + + this.graphRunner.attachFloatVectorListener( + QUALITY_SCORES_STREAM, (scores, timestamp) => { + this.qualityScores = scores; + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachEmptyPacketListener( + QUALITY_SCORES_STREAM, timestamp => { + this.categoryMask = undefined; + this.setLatestOutputTimestamp(timestamp); + }); + const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); } diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.ts index 5da7e4df3..3b45d09e7 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.ts @@ -30,7 +30,13 @@ export class InteractiveSegmenterResult { * `WebGLTexture`-backed `MPImage` where each pixel represents the class * which the pixel in the original image was predicted to belong to. */ - readonly categoryMask?: MPMask) {} + readonly categoryMask?: MPMask, + /** + * The quality scores of the result masks, in the range of [0, 1]. + * Defaults to `1` if the model doesn't output quality scores. Each + * element corresponds to the score of the category in the model outputs. + */ + readonly qualityScores?: number[]) {} /** Frees the resources held by the category and confidence masks. 
*/ close(): void { diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts index 6550202e0..fe2d27157 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts @@ -46,6 +46,8 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements ((images: WasmImage, timestamp: number) => void)|undefined; confidenceMasksListener: ((images: WasmImage[], timestamp: number) => void)|undefined; + qualityScoresListener: + ((data: number[], timestamp: number) => void)|undefined; lastRoi?: RenderDataProto; constructor() { @@ -64,6 +66,12 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements expect(stream).toEqual('confidence_masks'); this.confidenceMasksListener = listener; }); + this.attachListenerSpies[2] = + spyOn(this.graphRunner, 'attachFloatVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('quality_scores'); + this.qualityScoresListener = listener; + }); spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); }); @@ -277,9 +285,10 @@ describe('InteractiveSegmenter', () => { }); }); - it('invokes listener after masks are avaiblae', async () => { + it('invokes listener after masks are available', async () => { const categoryMask = new Uint8Array([1]); const confidenceMask = new Float32Array([0.0]); + const qualityScores = [1.0]; let listenerCalled = false; await interactiveSegmenter.setOptions( @@ -297,11 +306,16 @@ describe('InteractiveSegmenter', () => { ], 1337); expect(listenerCalled).toBeFalse(); + interactiveSegmenter.qualityScoresListener!(qualityScores, 1337); + expect(listenerCalled).toBeFalse(); }); return new Promise(resolve => { - interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, () => { + interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, result => { listenerCalled = true; + expect(result.categoryMask).toBeInstanceOf(MPMask); + expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask); + expect(result.qualityScores).toEqual(qualityScores); resolve(); }); }); From e8ee934bf99a5cb209c1132d08879f34c8c13894 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 23 May 2023 15:07:06 -0700 Subject: [PATCH 17/20] Use empty keypoint array for Detection if no keypoints are detected PiperOrigin-RevId: 534572162 --- .../web/components/containers/detection_result.d.ts | 13 +++++++------ .../components/processors/detection_result.test.ts | 3 ++- .../web/components/processors/detection_result.ts | 3 +-- .../web/vision/face_detector/face_detector_test.ts | 3 ++- .../vision/object_detector/object_detector_test.ts | 3 ++- 5 files changed, 14 insertions(+), 11 deletions(-) diff --git a/mediapipe/tasks/web/components/containers/detection_result.d.ts b/mediapipe/tasks/web/components/containers/detection_result.d.ts index 287632f2d..590bd7340 100644 --- a/mediapipe/tasks/web/components/containers/detection_result.d.ts +++ b/mediapipe/tasks/web/components/containers/detection_result.d.ts @@ -27,13 +27,14 @@ export declare interface Detection { boundingBox?: BoundingBox; /** - * Optional list of keypoints associated with the detection. Keypoints - * represent interesting points related to the detection. For example, the - * keypoints represent the eye, ear and mouth from face detection model. 
Or - * in the template matching detection, e.g. KNIFT, they can represent the - * feature points for template matching. + * List of keypoints associated with the detection. Keypoints represent + * interesting points related to the detection. For example, the keypoints + * represent the eye, ear and mouth from face detection model. Or in the + * template matching detection, e.g. KNIFT, they can represent the feature + * points for template matching. Contains an empty list if no keypoints are + * detected. */ - keypoints?: NormalizedKeypoint[]; + keypoints: NormalizedKeypoint[]; } /** Detection results of a model. */ diff --git a/mediapipe/tasks/web/components/processors/detection_result.test.ts b/mediapipe/tasks/web/components/processors/detection_result.test.ts index 28b37ab0e..0fa8156ba 100644 --- a/mediapipe/tasks/web/components/processors/detection_result.test.ts +++ b/mediapipe/tasks/web/components/processors/detection_result.test.ts @@ -85,7 +85,8 @@ describe('convertFromDetectionProto()', () => { categoryName: '', displayName: '', }], - boundingBox: {originX: 0, originY: 0, width: 0, height: 0} + boundingBox: {originX: 0, originY: 0, width: 0, height: 0}, + keypoints: [] }); }); }); diff --git a/mediapipe/tasks/web/components/processors/detection_result.ts b/mediapipe/tasks/web/components/processors/detection_result.ts index 304b51a8e..4999ed31b 100644 --- a/mediapipe/tasks/web/components/processors/detection_result.ts +++ b/mediapipe/tasks/web/components/processors/detection_result.ts @@ -26,7 +26,7 @@ export function convertFromDetectionProto(source: DetectionProto): Detection { const labels = source.getLabelList(); const displayNames = source.getDisplayNameList(); - const detection: Detection = {categories: []}; + const detection: Detection = {categories: [], keypoints: []}; for (let i = 0; i < scores.length; i++) { detection.categories.push({ score: scores[i], @@ -47,7 +47,6 @@ export function convertFromDetectionProto(source: DetectionProto): Detection { } if (source.getLocationData()?.getRelativeKeypointsList().length) { - detection.keypoints = []; for (const keypoint of source.getLocationData()!.getRelativeKeypointsList()) { detection.keypoints.push({ diff --git a/mediapipe/tasks/web/vision/face_detector/face_detector_test.ts b/mediapipe/tasks/web/vision/face_detector/face_detector_test.ts index 28f602965..dfe84bb17 100644 --- a/mediapipe/tasks/web/vision/face_detector/face_detector_test.ts +++ b/mediapipe/tasks/web/vision/face_detector/face_detector_test.ts @@ -191,7 +191,8 @@ describe('FaceDetector', () => { categoryName: '', displayName: '', }], - boundingBox: {originX: 0, originY: 0, width: 0, height: 0} + boundingBox: {originX: 0, originY: 0, width: 0, height: 0}, + keypoints: [] }); }); }); diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts index 1613f27d7..9c63eaba1 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts @@ -210,7 +210,8 @@ describe('ObjectDetector', () => { categoryName: '', displayName: '', }], - boundingBox: {originX: 0, originY: 0, width: 0, height: 0} + boundingBox: {originX: 0, originY: 0, width: 0, height: 0}, + keypoints: [] }); }); }); From 3dcfca3a731b0bfe7c86f0354fa4ecfa39b83fc2 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 23 May 2023 17:56:25 -0700 Subject: [PATCH 18/20] Fix deprecated usages * In status_builder.h to use absl::Status 
directly
* In type_map.h to use kTypeId.hash_code() directly
PiperOrigin-RevId: 534622923
---
 mediapipe/framework/deps/status_builder.cc | 4 ++--
 mediapipe/framework/deps/status_builder.h | 4 ++--
 mediapipe/framework/type_map.h | 15 +++++++++------
 3 files changed, 13 insertions(+), 10 deletions(-)
diff --git a/mediapipe/framework/deps/status_builder.cc b/mediapipe/framework/deps/status_builder.cc
index 0202b8689..041b8608b 100644
--- a/mediapipe/framework/deps/status_builder.cc
+++ b/mediapipe/framework/deps/status_builder.cc
@@ -68,11 +68,11 @@ StatusBuilder&& StatusBuilder::SetNoLogging() && {
 return std::move(SetNoLogging());
 }
-StatusBuilder::operator Status() const& {
+StatusBuilder::operator absl::Status() const& {
 return StatusBuilder(*this).JoinMessageToStatus();
 }
-StatusBuilder::operator Status() && { return JoinMessageToStatus(); }
+StatusBuilder::operator absl::Status() && { return JoinMessageToStatus(); }
 absl::Status StatusBuilder::JoinMessageToStatus() {
 if (!impl_) {
diff --git a/mediapipe/framework/deps/status_builder.h b/mediapipe/framework/deps/status_builder.h
index ae11699d2..935ab7776 100644
--- a/mediapipe/framework/deps/status_builder.h
+++ b/mediapipe/framework/deps/status_builder.h
@@ -83,8 +83,8 @@ class ABSL_MUST_USE_RESULT StatusBuilder {
 return std::move(*this << msg);
 }
- operator Status() const&;
- operator Status() &&;
+ operator absl::Status() const&;
+ operator absl::Status() &&;
 absl::Status JoinMessageToStatus();
diff --git a/mediapipe/framework/type_map.h b/mediapipe/framework/type_map.h
index e26efa039..8fb324e98 100644
--- a/mediapipe/framework/type_map.h
+++ b/mediapipe/framework/type_map.h
@@ -272,17 +272,20 @@ DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeStringToMediaPipeTypeData, std::string);
 #define MEDIAPIPE_REGISTER_TYPE(type, type_name, serialize_fn, deserialize_fn) \
 SET_MEDIAPIPE_TYPE_MAP_VALUE( \
 mediapipe::PacketTypeIdToMediaPipeTypeData, \
- mediapipe::tool::GetTypeHash< \
- mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \
+ mediapipe::TypeId::Of< \
+ mediapipe::type_map_internal::ReflectType<void(type*)>::Type>() \
+ .hash_code(), \
 (mediapipe::MediaPipeTypeData{ \
- mediapipe::tool::GetTypeHash< \
- mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \
+ mediapipe::TypeId::Of< \
+ mediapipe::type_map_internal::ReflectType<void(type*)>::Type>() \
+ .hash_code(), \
 type_name, serialize_fn, deserialize_fn})); \
 SET_MEDIAPIPE_TYPE_MAP_VALUE( \
 mediapipe::PacketTypeStringToMediaPipeTypeData, type_name, \
 (mediapipe::MediaPipeTypeData{ \
- mediapipe::tool::GetTypeHash< \
- mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \
+ mediapipe::TypeId::Of< \
+ mediapipe::type_map_internal::ReflectType<void(type*)>::Type>() \
+ .hash_code(), \
 type_name, serialize_fn, deserialize_fn}));
 // End define MEDIAPIPE_REGISTER_TYPE.
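The status_builder hunks above keep the same implicit conversion to a status object but spell it with the fully qualified `absl::Status` instead of the deprecated unqualified alias. A self-contained toy sketch of that const&/&& conversion-operator pair (this is not MediaPipe's `StatusBuilder`; it only assumes Abseil is available):

    #include <iostream>
    #include <string>
    #include <utility>

    #include "absl/status/status.h"

    class TinyStatusBuilder {
     public:
      explicit TinyStatusBuilder(absl::StatusCode code) : code_(code) {}

      TinyStatusBuilder& operator<<(const std::string& msg) {
        message_ += msg;
        return *this;
      }

      // Lvalue and rvalue overloads, mirroring the const& and && pair
      // declared in status_builder.h.
      operator absl::Status() const& { return absl::Status(code_, message_); }
      operator absl::Status() && { return absl::Status(code_, message_); }

     private:
      absl::StatusCode code_;
      std::string message_;
    };

    absl::Status Fail() {
      // operator<< returns an lvalue reference, so the const& overload
      // performs the conversion on return.
      return TinyStatusBuilder(absl::StatusCode::kInvalidArgument) << "bad input";
    }

    int main() { std::cout << Fail().ToString() << "\n"; }

The type_map.h hunks follow the same idea: `TypeId::Of<T>().hash_code()` produces the hash through the supported `TypeId` API instead of the deprecated `tool::GetTypeHash<T>()` helper.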
From 201b2d739d994c51bb1b72336a09b8a93ed982d2 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 23 May 2023 18:00:54 -0700 Subject: [PATCH 19/20] Fix c++98-compat-extra-semi warnings PiperOrigin-RevId: 534624086 --- mediapipe/framework/deps/strong_int.h | 34 +++++++++++++-------------- mediapipe/framework/timestamp.h | 2 +- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/mediapipe/framework/deps/strong_int.h b/mediapipe/framework/deps/strong_int.h index 6f102238f..3ddb6d0be 100644 --- a/mediapipe/framework/deps/strong_int.h +++ b/mediapipe/framework/deps/strong_int.h @@ -403,11 +403,11 @@ std::ostream &operator<<(std::ostream &os, lhs op## = rhs; \ return lhs; \ } -STRONG_INT_VS_STRONG_INT_BINARY_OP(+); -STRONG_INT_VS_STRONG_INT_BINARY_OP(-); -STRONG_INT_VS_STRONG_INT_BINARY_OP(&); -STRONG_INT_VS_STRONG_INT_BINARY_OP(|); -STRONG_INT_VS_STRONG_INT_BINARY_OP(^); +STRONG_INT_VS_STRONG_INT_BINARY_OP(+) +STRONG_INT_VS_STRONG_INT_BINARY_OP(-) +STRONG_INT_VS_STRONG_INT_BINARY_OP(&) +STRONG_INT_VS_STRONG_INT_BINARY_OP(|) +STRONG_INT_VS_STRONG_INT_BINARY_OP(^) #undef STRONG_INT_VS_STRONG_INT_BINARY_OP // Define operators that take one StrongInt and one native integer argument. @@ -431,12 +431,12 @@ STRONG_INT_VS_STRONG_INT_BINARY_OP(^); rhs op## = lhs; \ return rhs; \ } -STRONG_INT_VS_NUMERIC_BINARY_OP(*); -NUMERIC_VS_STRONG_INT_BINARY_OP(*); -STRONG_INT_VS_NUMERIC_BINARY_OP(/); -STRONG_INT_VS_NUMERIC_BINARY_OP(%); -STRONG_INT_VS_NUMERIC_BINARY_OP(<<); // NOLINT(whitespace/operators) -STRONG_INT_VS_NUMERIC_BINARY_OP(>>); // NOLINT(whitespace/operators) +STRONG_INT_VS_NUMERIC_BINARY_OP(*) +NUMERIC_VS_STRONG_INT_BINARY_OP(*) +STRONG_INT_VS_NUMERIC_BINARY_OP(/) +STRONG_INT_VS_NUMERIC_BINARY_OP(%) +STRONG_INT_VS_NUMERIC_BINARY_OP(<<) // NOLINT(whitespace/operators) +STRONG_INT_VS_NUMERIC_BINARY_OP(>>) // NOLINT(whitespace/operators) #undef STRONG_INT_VS_NUMERIC_BINARY_OP #undef NUMERIC_VS_STRONG_INT_BINARY_OP @@ -447,12 +447,12 @@ STRONG_INT_VS_NUMERIC_BINARY_OP(>>); // NOLINT(whitespace/operators) StrongInt rhs) { \ return lhs.value() op rhs.value(); \ } -STRONG_INT_COMPARISON_OP(==); // NOLINT(whitespace/operators) -STRONG_INT_COMPARISON_OP(!=); // NOLINT(whitespace/operators) -STRONG_INT_COMPARISON_OP(<); // NOLINT(whitespace/operators) -STRONG_INT_COMPARISON_OP(<=); // NOLINT(whitespace/operators) -STRONG_INT_COMPARISON_OP(>); // NOLINT(whitespace/operators) -STRONG_INT_COMPARISON_OP(>=); // NOLINT(whitespace/operators) +STRONG_INT_COMPARISON_OP(==) // NOLINT(whitespace/operators) +STRONG_INT_COMPARISON_OP(!=) // NOLINT(whitespace/operators) +STRONG_INT_COMPARISON_OP(<) // NOLINT(whitespace/operators) +STRONG_INT_COMPARISON_OP(<=) // NOLINT(whitespace/operators) +STRONG_INT_COMPARISON_OP(>) // NOLINT(whitespace/operators) +STRONG_INT_COMPARISON_OP(>=) // NOLINT(whitespace/operators) #undef STRONG_INT_COMPARISON_OP } // namespace intops diff --git a/mediapipe/framework/timestamp.h b/mediapipe/framework/timestamp.h index b8c3a69a2..966ec1839 100644 --- a/mediapipe/framework/timestamp.h +++ b/mediapipe/framework/timestamp.h @@ -57,7 +57,7 @@ namespace mediapipe { // have underflow/overflow etc. This type is used internally by Timestamp // and TimestampDiff. 
From bc035d914667e94d8ab5a03f8b4fdb53e79c5ce4 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 23 May 2023 21:47:20 -0700
Subject: [PATCH 20/20] Fix multiple typos in the tasks files.

PiperOrigin-RevId: 534679277
---
 mediapipe/model_maker/python/core/utils/loss_functions.py | 2 +-
 .../python/vision/face_stylizer/face_stylizer.py          | 4 ++--
 .../vision/image_classifier/MPPImageClassifierTests.m     | 4 ++--
 .../test/vision/object_detector/MPPObjectDetectorTests.m  | 8 ++++----
 .../mediapipe/tasks/vision/core/BaseVisionTaskApi.java    | 2 +-
 mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts | 6 +++---
 .../tasks/web/vision/image_segmenter/image_segmenter.ts   | 6 +++---
 7 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/mediapipe/model_maker/python/core/utils/loss_functions.py b/mediapipe/model_maker/python/core/utils/loss_functions.py
index a60bd2ed4..504ba91ef 100644
--- a/mediapipe/model_maker/python/core/utils/loss_functions.py
+++ b/mediapipe/model_maker/python/core/utils/loss_functions.py
@@ -187,7 +187,7 @@ class PerceptualLoss(tf.keras.Model, metaclass=abc.ABCMeta):
     """Instantiates perceptual loss.
 
     Args:
-      feature_weight: The weight coeffcients of multiple model extracted
+      feature_weight: The weight coefficients of multiple model extracted
         features used for calculating the perceptual loss.
       loss_weight: The weight coefficients between `style_loss` and
         `content_loss`.
diff --git a/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py b/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py
index 85b567ca3..5758ac7b5 100644
--- a/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py
+++ b/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py
@@ -105,7 +105,7 @@ class FaceStylizer(object):
     self._train_model(train_data=train_data, preprocessor=self._preprocessor)
 
   def _create_model(self):
-    """Creates the componenets of face stylizer."""
+    """Creates the components of face stylizer."""
     self._encoder = model_util.load_keras_model(
         constants.FACE_STYLIZER_ENCODER_MODEL_FILES.get_path()
     )
@@ -138,7 +138,7 @@ class FaceStylizer(object):
     """
     train_dataset = train_data.gen_tf_dataset(preprocess=preprocessor)
 
-    # TODO: Support processing mulitple input style images. The
+    # TODO: Support processing multiple input style images. The
     # input style images are expected to have similar style.
     # style_sample represents a tuple of (style_image, style_label).
     style_sample = next(iter(train_dataset))
diff --git a/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m b/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m
index 332a217ca..59383dad6 100644
--- a/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m
+++ b/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m
@@ -668,10 +668,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
 
   // Because of flow limiting, we cannot ensure that the callback will be
   // invoked `iterationCount` times.
-  // An normal expectation will fail if expectation.fullfill() is not called
+  // A normal expectation will fail if expectation.fulfill() is not called
   // `expectation.expectedFulfillmentCount` times.
  // If `expectation.isInverted = true`, the test will only succeed if
-  // expectation is not fullfilled for the specified `expectedFulfillmentCount`.
+  // expectation is not fulfilled for the specified `expectedFulfillmentCount`.
   // Since in our case we cannot predict how many times the expectation is
   // supposed to be fullfilled setting,
   // `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
diff --git a/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m b/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m
index 1b717ba48..2ef5a0957 100644
--- a/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m
+++ b/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m
@@ -673,15 +673,15 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
 
   // Because of flow limiting, we cannot ensure that the callback will be
   // invoked `iterationCount` times.
-  // An normal expectation will fail if expectation.fullfill() is not called
+  // A normal expectation will fail if expectation.fulfill() is not called
   // `expectation.expectedFulfillmentCount` times.
   // If `expectation.isInverted = true`, the test will only succeed if
-  // expectation is not fullfilled for the specified `expectedFulfillmentCount`.
+  // expectation is not fulfilled for the specified `expectedFulfillmentCount`.
   // Since in our case we cannot predict how many times the expectation is
-  // supposed to be fullfilled setting,
+  // supposed to be fulfilled setting,
   // `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
   // `expectation.isInverted = true` ensures that test succeeds if
-  // expectation is fullfilled <= `iterationCount` times.
+  // expectation is fulfilled <= `iterationCount` times.
   XCTestExpectation *expectation = [[XCTestExpectation alloc]
       initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
   expectation.expectedFulfillmentCount = iterationCount + 1;
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java
index 5964cef2c..9ea057b0d 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java
@@ -166,7 +166,7 @@ public class BaseVisionTaskApi implements AutoCloseable {
     // For 90° and 270° rotations, we need to swap width and height.
     // This is due to the internal behavior of ImageToTensorCalculator, which:
     // - first denormalizes the provided rect by multiplying the rect width or
-    //   height by the image width or height, repectively.
+    //   height by the image width or height, respectively.
     // - then rotates this by denormalized rect by the provided rotation, and
    //   uses this for cropping,
     // - then finally rotates this back.
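The Java comment fixed above describes a real invariant worth illustrating: because ImageToTensorCalculator denormalizes the crop rect against the image and then rotates it, quarter-turn rotations must swap which image dimension scales which rect dimension. A small sketch of that swap follows; this helper is illustrative only and not part of the MediaPipe API:

#include <utility>

// Returns the (width, height) pair to denormalize against. For rotations of
// 90 or 270 degrees the dimensions trade places; for 0 or 180 degrees they
// are unchanged.
std::pair<int, int> EffectiveDimensions(int width, int height,
                                        int rotation_degrees) {
  if (rotation_degrees % 180 == 90) {
    return {height, width};
  }
  return {width, height};
}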
diff --git a/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts b/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts
index 8169e6775..2d99dc54d 100644
--- a/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts
+++ b/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts
@@ -171,7 +171,7 @@ export class FaceStylizer extends VisionTaskRunner {
   /**
    * Performs face stylization on the provided single image and returns the
    * result. This method creates a copy of the resulting image and should not be
-   * used in high-throughput applictions. Only use this method when the
+   * used in high-throughput applications. Only use this method when the
    * FaceStylizer is created with the image running mode.
    *
    * @param image An image to process.
@@ -182,7 +182,7 @@ export class FaceStylizer extends VisionTaskRunner {
   /**
    * Performs face stylization on the provided single image and returns the
    * result. This method creates a copy of the resulting image and should not be
-   * used in high-throughput applictions. Only use this method when the
+   * used in high-throughput applications. Only use this method when the
    * FaceStylizer is created with the image running mode.
    *
    * The 'imageProcessingOptions' parameter can be used to specify one or all
@@ -275,7 +275,7 @@ export class FaceStylizer extends VisionTaskRunner {
   /**
    * Performs face stylization on the provided video frame. This method creates
    * a copy of the resulting image and should not be used in high-throughput
-   * applictions. Only use this method when the FaceStylizer is created with the
+   * applications. Only use this method when the FaceStylizer is created with the
    * video running mode.
    *
    * The input frame can be of any size. It's required to provide the video
diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts
index 3dd2d03ef..6d295aaa8 100644
--- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts
+++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts
@@ -231,7 +231,7 @@ export class ImageSegmenter extends VisionTaskRunner {
   /**
    * Performs image segmentation on the provided single image and returns the
    * segmentation result. This method creates a copy of the resulting masks and
-   * should not be used in high-throughput applictions. Only use this method
+   * should not be used in high-throughput applications. Only use this method
    * when the ImageSegmenter is created with running mode `image`.
    *
    * @param image An image to process.
@@ -242,7 +242,7 @@ export class ImageSegmenter extends VisionTaskRunner {
   /**
    * Performs image segmentation on the provided single image and returns the
    * segmentation result. This method creates a copy of the resulting masks and
-   * should not be used in high-v applictions. Only use this method when
+   * should not be used in high-throughput applications. Only use this method when
    * the ImageSegmenter is created with running mode `image`.
    *
    * @param image An image to process.
@@ -320,7 +320,7 @@ export class ImageSegmenter extends VisionTaskRunner {
   /**
    * Performs image segmentation on the provided video frame and returns the
    * segmentation result. This method creates a copy of the resulting masks and
-   * should not be used in high-v applictions. Only use this method when
+   * should not be used in high-throughput applications. Only use this method when
    * the ImageSegmenter is created with running mode `video`.
    *
    * @param videoFrame A video frame to process.