From 367ccbfdf3c59e9c4baa891dbdcd0c1d944d93f9 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 4 Apr 2023 00:23:03 -0700 Subject: [PATCH] update ImageSegmenterGraph to always output confidence mask and optionally output category mask PiperOrigin-RevId: 521679910 --- .../tasks/cc/vision/image_segmenter/BUILD | 9 ++ .../tensors_to_segmentation_calculator.cc | 96 +++++++++---- ...tensors_to_segmentation_calculator_test.cc | 86 +++++------- .../vision/image_segmenter/image_segmenter.cc | 80 ++++++----- .../vision/image_segmenter/image_segmenter.h | 60 +++------ .../image_segmenter/image_segmenter_graph.cc | 127 +++++++++++++----- .../image_segmenter/image_segmenter_result.h | 43 ++++++ .../image_segmenter/image_segmenter_test.cc | 124 ++++++++--------- .../proto/segmenter_options.proto | 2 +- 9 files changed, 370 insertions(+), 257 deletions(-) create mode 100644 mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 69833a5f6..ee1cd3693 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -16,6 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +cc_library( + name = "image_segmenter_result", + hdrs = ["image_segmenter_result.h"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework/formats:image"], +) + # Docs for Mediapipe Tasks Image Segmenter # https://developers.google.com/mediapipe/solutions/vision/image_segmenter cc_library( @@ -25,6 +32,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":image_segmenter_graph", + ":image_segmenter_result", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:image", @@ -82,6 +90,7 @@ cc_library( "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", 
"//mediapipe/tasks/metadata:image_segmenter_metadata_schema_cc", "//mediapipe/tasks/metadata:metadata_schema_cc", + "//mediapipe/util:graph_builder_utils", "//mediapipe/util:label_map_cc_proto", "//mediapipe/util:label_map_util", "@com_google_absl//absl/status", diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc index 0cdc8fe0f..49ad18029 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc @@ -80,10 +80,10 @@ void Sigmoid(absl::Span values, [](float value) { return 1. / (1 + std::exp(-value)); }); } -std::vector ProcessForCategoryMaskCpu(const Shape& input_shape, - const Shape& output_shape, - const SegmenterOptions& options, - const float* tensors_buffer) { +Image ProcessForCategoryMaskCpu(const Shape& input_shape, + const Shape& output_shape, + const SegmenterOptions& options, + const float* tensors_buffer) { cv::Mat resized_tensors_mat; cv::Mat tensors_mat_view( input_shape.height, input_shape.width, CV_32FC(input_shape.channels), @@ -135,7 +135,7 @@ std::vector ProcessForCategoryMaskCpu(const Shape& input_shape, pixel = maximum_category_idx; } }); - return {category_mask}; + return category_mask; } std::vector ProcessForConfidenceMaskCpu(const Shape& input_shape, @@ -209,7 +209,9 @@ std::vector ProcessForConfidenceMaskCpu(const Shape& input_shape, } // namespace -// Converts Tensors from a vector of Tensor to Segmentation. +// Converts Tensors from a vector of Tensor to Segmentation masks. The +// calculator always outputs confidence masks, and an optional category mask if +// CATEGORY_MASK is connected. // // Performs optional resizing to OUTPUT_SIZE dimension if provided, // otherwise the segmented masks is the same size as input tensor. 
@@ -221,7 +223,12 @@ std::vector ProcessForConfidenceMaskCpu(const Shape& input_shape, // the size to resize masks to. // // Output: -// Segmentation: Segmentation proto. +// CONFIDENCE_MASK @Multiple: Multiple masks of float image where, for each +// mask, each pixel represents the prediction confidence, usually in the [0, +// 1] range. +// CATEGORY_MASK @Optional: A category mask of uint8 image where each pixel +// represents the class which the pixel in the original image was predicted to +// belong to. // // Options: // See tensors_to_segmentation_calculator.proto @@ -231,13 +238,13 @@ std::vector ProcessForConfidenceMaskCpu(const Shape& input_shape, // calculator: "TensorsToSegmentationCalculator" // input_stream: "TENSORS:tensors" // input_stream: "OUTPUT_SIZE:size" -// output_stream: "SEGMENTATION:0:segmentation" -// output_stream: "SEGMENTATION:1:segmentation" +// output_stream: "CONFIDENCE_MASK:0:confidence_mask" +// output_stream: "CONFIDENCE_MASK:1:confidence_mask" +// output_stream: "CATEGORY_MASK:category_mask" // options { // [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { // segmenter_options { // activation: SOFTMAX -// output_type: CONFIDENCE_MASK // } // } // } @@ -248,7 +255,11 @@ class TensorsToSegmentationCalculator : public Node { static constexpr Input>::Optional kOutputSizeIn{ "OUTPUT_SIZE"}; static constexpr Output::Multiple kSegmentationOut{"SEGMENTATION"}; - MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut); + static constexpr Output::Multiple kConfidenceMaskOut{ + "CONFIDENCE_MASK"}; + static constexpr Output::Optional kCategoryMaskOut{"CATEGORY_MASK"}; + MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut, + kConfidenceMaskOut, kCategoryMaskOut); static absl::Status UpdateContract(CalculatorContract* cc); @@ -279,9 +290,13 @@ absl::Status TensorsToSegmentationCalculator::UpdateContract( absl::Status TensorsToSegmentationCalculator::Open( mediapipe::CalculatorContext* cc) { options_ = 
cc->Options(); - RET_CHECK_NE(options_.segmenter_options().output_type(), - SegmenterOptions::UNSPECIFIED) - << "Must specify output_type as one of [CONFIDENCE_MASK|CATEGORY_MASK]."; + // TODO: remove deprecated output type support. + if (options_.segmenter_options().has_output_type()) { + RET_CHECK_NE(options_.segmenter_options().output_type(), + SegmenterOptions::UNSPECIFIED) + << "Must specify output_type as one of " + "[CONFIDENCE_MASK|CATEGORY_MASK]."; + } #ifdef __EMSCRIPTEN__ MP_RETURN_IF_ERROR(postprocessor_.Initialize(cc, options_)); #endif // __EMSCRIPTEN__ @@ -309,6 +324,10 @@ absl::Status TensorsToSegmentationCalculator::Process( if (cc->Inputs().HasTag("OUTPUT_SIZE")) { std::tie(output_width, output_height) = kOutputSizeIn(cc).Get(); } + + // Use GPU postprocessing on web when Tensor is there already and has <= 12 + // categories. +#ifdef __EMSCRIPTEN__ Shape output_shape = { /* height= */ output_height, /* width= */ output_width, @@ -316,10 +335,6 @@ absl::Status TensorsToSegmentationCalculator::Process( SegmenterOptions::CATEGORY_MASK ? 1 : input_shape.channels}; - - // Use GPU postprocessing on web when Tensor is there already and has <= 12 - // categories. -#ifdef __EMSCRIPTEN__ if (input_tensor.ready_as_opengl_texture_2d() && input_shape.channels <= 12) { std::vector> segmented_masks = postprocessor_.GetSegmentationResultGpu(input_shape, output_shape, @@ -332,10 +347,41 @@ absl::Status TensorsToSegmentationCalculator::Process( #endif // __EMSCRIPTEN__ // Otherwise, use CPU postprocessing. - std::vector segmented_masks = GetSegmentationResultCpu( - input_shape, output_shape, input_tensor.GetCpuReadView().buffer()); - for (int i = 0; i < segmented_masks.size(); ++i) { - kSegmentationOut(cc)[i].Send(std::move(segmented_masks[i])); + const float* tensors_buffer = input_tensor.GetCpuReadView().buffer(); + + // TODO: remove deprecated output type support. 
+ if (options_.segmenter_options().has_output_type()) { + std::vector segmented_masks = GetSegmentationResultCpu( + input_shape, + {/* height= */ output_height, + /* width= */ output_width, + /* channels= */ options_.segmenter_options().output_type() == + SegmenterOptions::CATEGORY_MASK + ? 1 + : input_shape.channels}, + input_tensor.GetCpuReadView().buffer()); + for (int i = 0; i < segmented_masks.size(); ++i) { + kSegmentationOut(cc)[i].Send(std::move(segmented_masks[i])); + } + return absl::OkStatus(); + } + + std::vector confidence_masks = + ProcessForConfidenceMaskCpu(input_shape, + {/* height= */ output_height, + /* width= */ output_width, + /* channels= */ input_shape.channels}, + options_.segmenter_options(), tensors_buffer); + for (int i = 0; i < confidence_masks.size(); ++i) { + kConfidenceMaskOut(cc)[i].Send(std::move(confidence_masks[i])); + } + if (cc->Outputs().HasTag("CATEGORY_MASK")) { + kCategoryMaskOut(cc).Send(ProcessForCategoryMaskCpu( + input_shape, + {/* height= */ output_height, + /* width= */ output_width, + /* channels= */ 1}, + options_.segmenter_options(), tensors_buffer)); } return absl::OkStatus(); } @@ -345,9 +391,9 @@ std::vector TensorsToSegmentationCalculator::GetSegmentationResultCpu( const float* tensors_buffer) { if (options_.segmenter_options().output_type() == SegmenterOptions::CATEGORY_MASK) { - return ProcessForCategoryMaskCpu(input_shape, output_shape, - options_.segmenter_options(), - tensors_buffer); + return {ProcessForCategoryMaskCpu(input_shape, output_shape, + options_.segmenter_options(), + tensors_buffer)}; } else { return ProcessForConfidenceMaskCpu(input_shape, output_shape, options_.segmenter_options(), diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc index 54fb9b816..d6a2f3fd9 100644 --- 
a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc @@ -79,8 +79,9 @@ void PushTensorsToRunner(int tensor_height, int tensor_width, std::vector GetPackets(const CalculatorRunner& runner) { std::vector mask_packets; for (int i = 0; i < runner.Outputs().NumEntries(); ++i) { - EXPECT_EQ(runner.Outputs().Get("SEGMENTATION", i).packets.size(), 1); - mask_packets.push_back(runner.Outputs().Get("SEGMENTATION", i).packets[0]); + EXPECT_EQ(runner.Outputs().Get("CONFIDENCE_MASK", i).packets.size(), 1); + mask_packets.push_back( + runner.Outputs().Get("CONFIDENCE_MASK", i).packets[0]); } return mask_packets; } @@ -118,13 +119,10 @@ TEST(TensorsToSegmentationCalculatorTest, FailsInvalidTensorDimensionOne) { R"pb( calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" - output_stream: "SEGMENTATION:segmentation" + output_stream: "CONFIDENCE_MASK:segmentation" options { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { - segmenter_options { - activation: SOFTMAX - output_type: CONFIDENCE_MASK - } + segmenter_options { activation: SOFTMAX } } } )pb")); @@ -145,13 +143,10 @@ TEST(TensorsToSegmentationCalculatorTest, FailsInvalidTensorDimensionFive) { R"pb( calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" - output_stream: "SEGMENTATION:segmentation" + output_stream: "CONFIDENCE_MASK:segmentation" options { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { - segmenter_options { - activation: SOFTMAX - output_type: CONFIDENCE_MASK - } + segmenter_options { activation: SOFTMAX } } } )pb")); @@ -173,16 +168,13 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithSoftmax) { R"pb( calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" - output_stream: 
"SEGMENTATION:0:segmented_mask_0" - output_stream: "SEGMENTATION:1:segmented_mask_1" - output_stream: "SEGMENTATION:2:segmented_mask_2" - output_stream: "SEGMENTATION:3:segmented_mask_3" + output_stream: "CONFIDENCE_MASK:0:segmented_mask_0" + output_stream: "CONFIDENCE_MASK:1:segmented_mask_1" + output_stream: "CONFIDENCE_MASK:2:segmented_mask_2" + output_stream: "CONFIDENCE_MASK:3:segmented_mask_3" options { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { - segmenter_options { - activation: SOFTMAX - output_type: CONFIDENCE_MASK - } + segmenter_options { activation: SOFTMAX } } } )pb")); @@ -218,16 +210,13 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithNone) { R"pb( calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" - output_stream: "SEGMENTATION:0:segmented_mask_0" - output_stream: "SEGMENTATION:1:segmented_mask_1" - output_stream: "SEGMENTATION:2:segmented_mask_2" - output_stream: "SEGMENTATION:3:segmented_mask_3" + output_stream: "CONFIDENCE_MASK:0:segmented_mask_0" + output_stream: "CONFIDENCE_MASK:1:segmented_mask_1" + output_stream: "CONFIDENCE_MASK:2:segmented_mask_2" + output_stream: "CONFIDENCE_MASK:3:segmented_mask_3" options { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { - segmenter_options { - activation: NONE - output_type: CONFIDENCE_MASK - } + segmenter_options { activation: NONE } } } )pb")); @@ -259,16 +248,13 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithSigmoid) { R"pb( calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" - output_stream: "SEGMENTATION:0:segmented_mask_0" - output_stream: "SEGMENTATION:1:segmented_mask_1" - output_stream: "SEGMENTATION:2:segmented_mask_2" - output_stream: "SEGMENTATION:3:segmented_mask_3" + output_stream: "CONFIDENCE_MASK:0:segmented_mask_0" + output_stream: "CONFIDENCE_MASK:1:segmented_mask_1" + output_stream: "CONFIDENCE_MASK:2:segmented_mask_2" + 
output_stream: "CONFIDENCE_MASK:3:segmented_mask_3" options { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { - segmenter_options { - activation: SIGMOID - output_type: CONFIDENCE_MASK - } + segmenter_options { activation: SIGMOID } } } )pb")); @@ -301,13 +287,14 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMask) { R"pb( calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" - output_stream: "SEGMENTATION:segmentation" + output_stream: "CONFIDENCE_MASK:0:segmented_mask_0" + output_stream: "CONFIDENCE_MASK:1:segmented_mask_1" + output_stream: "CONFIDENCE_MASK:2:segmented_mask_2" + output_stream: "CONFIDENCE_MASK:3:segmented_mask_3" + output_stream: "CATEGORY_MASK:segmentation" options { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { - segmenter_options { - activation: NONE - output_type: CATEGORY_MASK - } + segmenter_options { activation: NONE } } } )pb")); @@ -318,11 +305,11 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMask) { tensor_height, tensor_width, std::vector(kTestValues.begin(), kTestValues.end()), &runner); MP_ASSERT_OK(runner.Run()); - ASSERT_EQ(runner.Outputs().NumEntries(), 1); + ASSERT_EQ(runner.Outputs().NumEntries(), 5); // Largest element index is 3. 
const int expected_index = 3; const std::vector buffer_indices = {0}; - std::vector packets = GetPackets(runner); + std::vector packets = runner.Outputs().Tag("CATEGORY_MASK").packets; EXPECT_THAT(packets, testing::ElementsAre( Uint8ImagePacket(tensor_height, tensor_width, expected_index, buffer_indices))); @@ -335,13 +322,14 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMaskResize) { calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" input_stream: "OUTPUT_SIZE:size" - output_stream: "SEGMENTATION:segmentation" + output_stream: "CONFIDENCE_MASK:0:segmented_mask_0" + output_stream: "CONFIDENCE_MASK:1:segmented_mask_1" + output_stream: "CONFIDENCE_MASK:2:segmented_mask_2" + output_stream: "CONFIDENCE_MASK:3:segmented_mask_3" + output_stream: "CATEGORY_MASK:segmentation" options { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { - segmenter_options { - activation: NONE - output_type: CATEGORY_MASK - } + segmenter_options { activation: NONE } } } )pb")); @@ -367,7 +355,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMaskResize) { const std::vector buffer_indices = { 0 * output_width + 0, 0 * output_width + 1, 1 * output_width + 0, 1 * output_width + 1}; - std::vector packets = GetPackets(runner); + std::vector packets = runner.Outputs().Tag("CATEGORY_MASK").packets; EXPECT_THAT(packets, testing::ElementsAre( Uint8ImagePacket(output_height, output_width, expected_index, buffer_indices))); diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index ab1d3c84b..8f03ff086 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -37,8 +37,10 @@ namespace vision { namespace image_segmenter { namespace { -constexpr char kSegmentationStreamName[] = "segmented_mask_out"; -constexpr char kGroupedSegmentationTag[] = 
"GROUPED_SEGMENTATION"; +constexpr char kConfidenceMasksTag[] = "CONFIDENCE_MASKS"; +constexpr char kConfidenceMasksStreamName[] = "confidence_masks"; +constexpr char kCategoryMaskTag[] = "CATEGORY_MASK"; +constexpr char kCategoryMaskStreamName[] = "category_mask"; constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageTag[] = "IMAGE"; @@ -51,7 +53,6 @@ constexpr int kMicroSecondsPerMilliSecond = 1000; using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::Image; using ::mediapipe::NormalizedRect; -using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: image_segmenter::proto::ImageSegmenterGraphOptions; @@ -59,21 +60,24 @@ using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: // "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph". CalculatorGraphConfig CreateGraphConfig( std::unique_ptr options, - bool enable_flow_limiting) { + bool output_category_mask, bool enable_flow_limiting) { api2::builder::Graph graph; auto& task_subgraph = graph.AddNode(kSubgraphTypeName); task_subgraph.GetOptions().Swap( options.get()); graph.In(kImageTag).SetName(kImageInStreamName); graph.In(kNormRectTag).SetName(kNormRectStreamName); - task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >> - graph.Out(kGroupedSegmentationTag); + task_subgraph.Out(kConfidenceMasksTag).SetName(kConfidenceMasksStreamName) >> + graph.Out(kConfidenceMasksTag); + if (output_category_mask) { + task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >> + graph.Out(kCategoryMaskTag); + } task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); if (enable_flow_limiting) { - return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph, - {kImageTag, kNormRectTag}, - kGroupedSegmentationTag); + return tasks::core::AddFlowLimiterCalculator( + graph, task_subgraph, 
{kImageTag, kNormRectTag}, kConfidenceMasksTag); } graph.In(kImageTag) >> task_subgraph.In(kImageTag); graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag); @@ -91,16 +95,6 @@ ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) { options_proto->mutable_base_options()->set_use_stream_mode( options->running_mode != core::RunningMode::IMAGE); options_proto->set_display_names_locale(options->display_names_locale); - switch (options->output_type) { - case ImageSegmenterOptions::OutputType::CATEGORY_MASK: - options_proto->mutable_segmenter_options()->set_output_type( - SegmenterOptions::CATEGORY_MASK); - break; - case ImageSegmenterOptions::OutputType::CONFIDENCE_MASK: - options_proto->mutable_segmenter_options()->set_output_type( - SegmenterOptions::CONFIDENCE_MASK); - break; - } return options_proto; } @@ -145,6 +139,7 @@ absl::StatusOr> ImageSegmenter::Create( tasks::core::PacketsCallback packets_callback = nullptr; if (options->result_callback) { auto result_callback = options->result_callback; + bool output_category_mask = options->output_category_mask; packets_callback = [=](absl::StatusOr status_or_packets) { if (!status_or_packets.ok()) { @@ -156,34 +151,41 @@ absl::StatusOr> ImageSegmenter::Create( if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { return; } - Packet segmented_masks = - status_or_packets.value()[kSegmentationStreamName]; + Packet confidence_masks = + status_or_packets.value()[kConfidenceMasksStreamName]; + std::optional category_mask; + if (output_category_mask) { + category_mask = + status_or_packets.value()[kCategoryMaskStreamName].Get(); + } Packet image_packet = status_or_packets.value()[kImageOutStreamName]; - result_callback(segmented_masks.Get>(), - image_packet.Get(), - segmented_masks.Timestamp().Value() / - kMicroSecondsPerMilliSecond); + result_callback( + {{confidence_masks.Get>(), category_mask}}, + image_packet.Get(), + confidence_masks.Timestamp().Value() / + kMicroSecondsPerMilliSecond); }; } - 
auto image_segmenter = core::VisionTaskApiFactory::Create( CreateGraphConfig( - std::move(options_proto), + std::move(options_proto), options->output_category_mask, options->running_mode == core::RunningMode::LIVE_STREAM), std::move(options->base_options.op_resolver), options->running_mode, std::move(packets_callback)); if (!image_segmenter.ok()) { return image_segmenter.status(); } + image_segmenter.value()->output_category_mask_ = + options->output_category_mask; ASSIGN_OR_RETURN( (*image_segmenter)->labels_, GetLabelsFromGraphConfig((*image_segmenter)->runner_->GetGraphConfig())); return image_segmenter; } -absl::StatusOr> ImageSegmenter::Segment( +absl::StatusOr ImageSegmenter::Segment( mediapipe::Image image, std::optional image_processing_options) { if (image.UsesGpu()) { @@ -201,11 +203,17 @@ absl::StatusOr> ImageSegmenter::Segment( {{kImageInStreamName, mediapipe::MakePacket(std::move(image))}, {kNormRectStreamName, MakePacket(std::move(norm_rect))}})); - return output_packets[kSegmentationStreamName].Get>(); + std::vector confidence_masks = + output_packets[kConfidenceMasksStreamName].Get>(); + std::optional category_mask; + if (output_category_mask_) { + category_mask = output_packets[kCategoryMaskStreamName].Get(); + } + return {{confidence_masks, category_mask}}; } -absl::StatusOr> ImageSegmenter::SegmentForVideo( - mediapipe::Image image, int64 timestamp_ms, +absl::StatusOr ImageSegmenter::SegmentForVideo( + mediapipe::Image image, int64_t timestamp_ms, std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( @@ -225,11 +233,17 @@ absl::StatusOr> ImageSegmenter::SegmentForVideo( {kNormRectStreamName, MakePacket(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); - return output_packets[kSegmentationStreamName].Get>(); + std::vector confidence_masks = + output_packets[kConfidenceMasksStreamName].Get>(); + std::optional category_mask; + if (output_category_mask_) { + category_mask 
= output_packets[kCategoryMaskStreamName].Get(); + } + return {{confidence_masks, category_mask}}; } absl::Status ImageSegmenter::SegmentAsync( - Image image, int64 timestamp_ms, + Image image, int64_t timestamp_ms, std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index 076a5016c..1d18e3903 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -26,6 +26,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h" #include "tensorflow/lite/kernels/register.h" namespace mediapipe { @@ -52,23 +53,14 @@ struct ImageSegmenterOptions { // Metadata, if any. Defaults to English. std::string display_names_locale = "en"; - // The output type of segmentation results. - enum OutputType { - // Gives a single output mask where each pixel represents the class which - // the pixel in the original image was predicted to belong to. - CATEGORY_MASK = 0, - // Gives a list of output masks where, for each mask, each pixel represents - // the prediction confidence, usually in the [0, 1] range. - CONFIDENCE_MASK = 1, - }; - - OutputType output_type = OutputType::CATEGORY_MASK; + // Whether to output category mask. + bool output_category_mask = false; // The user-defined result callback for processing live stream data. // The result callback should only be specified when the running mode is set // to RunningMode::LIVE_STREAM. 
- std::function>, - const Image&, int64)> + std::function, const Image&, + int64_t)> result_callback = nullptr; }; @@ -84,13 +76,9 @@ struct ImageSegmenterOptions { // 1 or 3). // - if type is kTfLiteFloat32, NormalizationOptions are required to be // attached to the metadata for input normalization. -// Output tensors: -// (kTfLiteUInt8/kTfLiteFloat32) -// - list of segmented masks. -// - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. -// - if `output_type` is CONFIDENCE_MASK, float32 Image list of size -// `channels`. -// - batch is always 1 +// Output ImageSegmenterResult: +// Provides confidence masks and an optional category mask if +// `output_category_mask` is set true. // An example of such model can be found at: // https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2 class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { @@ -114,12 +102,8 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // setting its 'rotation_degrees' field. Note that specifying a // region-of-interest using the 'region_of_interest' field is NOT supported // and will result in an invalid argument error being returned. - // - // If the output_type is CATEGORY_MASK, the returned vector of images is - // per-category segmented image mask. - // If the output_type is CONFIDENCE_MASK, the returned vector of images - // contains only one confidence image mask. - absl::StatusOr> Segment( + + absl::StatusOr Segment( mediapipe::Image image, std::optional image_processing_options = std::nullopt); @@ -137,13 +121,8 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // setting its 'rotation_degrees' field. Note that specifying a // region-of-interest using the 'region_of_interest' field is NOT supported // and will result in an invalid argument error being returned. - // - // If the output_type is CATEGORY_MASK, the returned vector of images is - // per-category segmented image mask. 
- // If the output_type is CONFIDENCE_MASK, the returned vector of images - // contains only one confidence image mask. - absl::StatusOr> SegmentForVideo( - mediapipe::Image image, int64 timestamp_ms, + absl::StatusOr SegmentForVideo( + mediapipe::Image image, int64_t timestamp_ms, std::optional image_processing_options = std::nullopt); @@ -164,17 +143,13 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // and will result in an invalid argument error being returned. // // The "result_callback" prvoides - // - A vector of segmented image masks. - // If the output_type is CATEGORY_MASK, the returned vector of images is - // per-category segmented image mask. - // If the output_type is CONFIDENCE_MASK, the returned vector of images - // contains only one confidence image mask. + // - An ImageSegmenterResult. // - The const reference to the corresponding input image that the image // segmentation runs on. Note that the const reference to the image will // no longer be valid when the callback returns. To access the image data // outside of the callback, callers need to make a copy of the image. // - The input timestamp in milliseconds. - absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms, + absl::Status SegmentAsync(mediapipe::Image image, int64_t timestamp_ms, std::optional image_processing_options = std::nullopt); @@ -182,9 +157,9 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { absl::Status Close() { return runner_->Close(); } // Get the category label list of the ImageSegmenter can recognize. For - // CATEGORY_MASK type, the index in the category mask corresponds to the - // category in the label list. For CONFIDENCE_MASK type, the output mask list - // at index corresponds to the category in the label list. + // CATEGORY_MASK, the index in the category mask corresponds to the category + // in the label list. For CONFIDENCE_MASK, the output mask list at index + // corresponds to the category in the label list. 
// // If there is no labelmap provided in the model file, empty label list is // returned. @@ -192,6 +167,7 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { private: std::vector labels_; + bool output_category_mask_; }; } // namespace image_segmenter diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index fe6265b73..4b9e7618b 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include @@ -42,6 +43,7 @@ limitations under the License. #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" #include "mediapipe/tasks/metadata/image_segmenter_metadata_schema_generated.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" +#include "mediapipe/util/graph_builder_utils.h" #include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map_util.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -65,10 +67,13 @@ using ::mediapipe::tasks::vision::image_segmenter::proto:: ImageSegmenterGraphOptions; using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ::tflite::TensorMetadata; -using LabelItems = mediapipe::proto_ns::Map; +using LabelItems = mediapipe::proto_ns::Map; constexpr char kSegmentationTag[] = "SEGMENTATION"; constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; +constexpr char kConfidenceMaskTag[] = "CONFIDENCE_MASK"; +constexpr char kConfidenceMasksTag[] = "CONFIDENCE_MASKS"; +constexpr char kCategoryMaskTag[] = "CATEGORY_MASK"; constexpr char kImageTag[] = "IMAGE"; constexpr char kImageCpuTag[] = "IMAGE_CPU"; constexpr char kImageGpuTag[] = "IMAGE_GPU"; @@ -80,7 +85,9 @@ constexpr char kSegmentationMetadataName[] = 
"SEGMENTER_METADATA"; // Struct holding the different output streams produced by the image segmenter // subgraph. struct ImageSegmenterOutputs { - std::vector> segmented_masks; + std::optional>> segmented_masks; + std::optional>> confidence_masks; + std::optional> category_mask; // The same as the input image, mainly used for live stream mode. Source image; }; @@ -95,8 +102,10 @@ struct ImageAndTensorsOnDevice { } // namespace absl::Status SanityCheckOptions(const ImageSegmenterGraphOptions& options) { - if (options.segmenter_options().output_type() == - SegmenterOptions::UNSPECIFIED) { + // TODO: remove deprecated output type support. + if (options.segmenter_options().has_output_type() && + options.segmenter_options().output_type() == + SegmenterOptions::UNSPECIFIED) { return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, "`output_type` must not be UNSPECIFIED", MediaPipeTasksStatus::kInvalidArgumentError); @@ -133,9 +142,8 @@ absl::Status ConfigureTensorsToSegmentationCalculator( const core::ModelResources& model_resources, TensorsToSegmentationCalculatorOptions* options) { // Set default activation function NONE - options->mutable_segmenter_options()->set_output_type( - segmenter_option.segmenter_options().output_type()); - options->mutable_segmenter_options()->set_activation(SegmenterOptions::NONE); + options->mutable_segmenter_options()->CopyFrom( + segmenter_option.segmenter_options()); // Find the custom metadata of ImageSegmenterOptions type in model metadata. const auto* metadata_extractor = model_resources.GetMetadataExtractor(); bool found_activation_in_metadata = false; @@ -317,12 +325,14 @@ absl::StatusOr ConvertImageToTensors( } } -// An "mediapipe.tasks.vision.ImageSegmenterGraph" performs semantic -// segmentation. -// Two kinds of outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION. 
-// Users can retrieve segmented mask of only particular category/channel from -// SEGMENTATION, and users can also get all segmented masks from -// GROUPED_SEGMENTATION. +// A "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph" performs +// semantic segmentation. The graph always outputs confidence masks, and an +// optional category mask if CATEGORY_MASK is connected. +// +// Two kinds of outputs for confidence mask are provided: CONFIDENCE_MASK and +// CONFIDENCE_MASKS. Users can retrieve segmented mask of only particular +// category/channel from CONFIDENCE_MASK, and users can also get all segmented +// confidence masks from CONFIDENCE_MASKS. // - Accepts CPU input images and outputs segmented masks on CPU. // // Inputs: @@ -334,11 +344,13 @@ absl::StatusOr ConvertImageToTensors( // @Optional: rect covering the whole image is used if not specified. // // Outputs: -// SEGMENTATION - mediapipe::Image @Multiple -// Segmented masks for individual category. Segmented mask of single +// CONFIDENCE_MASK - mediapipe::Image @Multiple +// Confidence masks for individual category. Confidence mask of single // category can be accessed by index based output stream. -// GROUPED_SEGMENTATION - std::vector -// The output segmented masks grouped in a vector. +// CONFIDENCE_MASKS - std::vector +// The output confidence masks grouped in a vector. +// CATEGORY_MASK - mediapipe::Image @Optional +// Optional Category mask. // IMAGE - mediapipe::Image // The image that image segmenter runs on. 
// @@ -369,23 +381,39 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { ASSIGN_OR_RETURN(const auto* model_resources, CreateModelResources(sc)); Graph graph; + const auto& options = sc->Options(); ASSIGN_OR_RETURN( auto output_streams, BuildSegmentationTask( - sc->Options(), *model_resources, - graph[Input(kImageTag)], - graph[Input::Optional(kNormRectTag)], graph)); + options, *model_resources, graph[Input(kImageTag)], + graph[Input::Optional(kNormRectTag)], + HasOutput(sc->OriginalNode(), kCategoryMaskTag), graph)); auto& merge_images_to_vector = graph.AddNode("MergeImagesToVectorCalculator"); - for (int i = 0; i < output_streams.segmented_masks.size(); ++i) { - output_streams.segmented_masks[i] >> - merge_images_to_vector[Input::Multiple("")][i]; - output_streams.segmented_masks[i] >> - graph[Output::Multiple(kSegmentationTag)][i]; + // TODO: remove deprecated output type support. + if (options.segmenter_options().has_output_type()) { + for (int i = 0; i < output_streams.segmented_masks->size(); ++i) { + output_streams.segmented_masks->at(i) >> + merge_images_to_vector[Input::Multiple("")][i]; + output_streams.segmented_masks->at(i) >> + graph[Output::Multiple(kSegmentationTag)][i]; + } + merge_images_to_vector.Out("") >> + graph[Output>(kGroupedSegmentationTag)]; + } else { + for (int i = 0; i < output_streams.confidence_masks->size(); ++i) { + output_streams.confidence_masks->at(i) >> + merge_images_to_vector[Input::Multiple("")][i]; + output_streams.confidence_masks->at(i) >> + graph[Output::Multiple(kConfidenceMaskTag)][i]; + } + merge_images_to_vector.Out("") >> + graph[Output>(kConfidenceMasksTag)]; + if (output_streams.category_mask) { + *output_streams.category_mask >> graph[Output(kCategoryMaskTag)]; + } } - merge_images_to_vector.Out("") >> - graph[Output>(kGroupedSegmentationTag)]; output_streams.image >> graph[Output(kImageTag)]; return graph.GetConfig(); } @@ -403,7 +431,8 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { 
absl::StatusOr BuildSegmentationTask( const ImageSegmenterGraphOptions& task_options, const core::ModelResources& model_resources, Source image_in, - Source norm_rect_in, Graph& graph) { + Source norm_rect_in, bool output_category_mask, + Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); // Adds preprocessing calculators and connects them to the graph input image @@ -435,22 +464,46 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { image_properties.Out("SIZE") >> tensor_to_images.In(kOutputSizeTag); // Exports multiple segmented masks. - std::vector> segmented_masks; - if (task_options.segmenter_options().output_type() == - SegmenterOptions::CATEGORY_MASK) { - segmented_masks.push_back( - Source(tensor_to_images[Output(kSegmentationTag)])); + // TODO: remove deprecated output type support. + if (task_options.segmenter_options().has_output_type()) { + std::vector> segmented_masks; + if (task_options.segmenter_options().output_type() == + SegmenterOptions::CATEGORY_MASK) { + segmented_masks.push_back( + Source(tensor_to_images[Output(kSegmentationTag)])); + } else { + ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor, + GetOutputTensor(model_resources)); + int segmentation_streams_num = *output_tensor->shape()->rbegin(); + for (int i = 0; i < segmentation_streams_num; ++i) { + segmented_masks.push_back(Source( + tensor_to_images[Output::Multiple(kSegmentationTag)][i])); + } + } + return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks, + /*confidence_masks=*/std::nullopt, + /*category_mask=*/std::nullopt, + /*image=*/image_and_tensors.image}; } else { ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor, GetOutputTensor(model_resources)); int segmentation_streams_num = *output_tensor->shape()->rbegin(); + std::vector> confidence_masks; + confidence_masks.reserve(segmentation_streams_num); for (int i = 0; i < segmentation_streams_num; ++i) { - segmented_masks.push_back(Source( - 
tensor_to_images[Output::Multiple(kSegmentationTag)][i])); + confidence_masks.push_back(Source( + tensor_to_images[Output::Multiple(kConfidenceMaskTag)][i])); } + return ImageSegmenterOutputs{ + /*segmented_masks=*/std::nullopt, + /*confidence_masks=*/confidence_masks, + /*category_mask=*/ + output_category_mask + ? std::make_optional( + tensor_to_images[Output(kCategoryMaskTag)]) + : std::nullopt, + /*image=*/image_and_tensors.image}; } - return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks, - /*image=*/image_and_tensors.image}; } }; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h new file mode 100644 index 000000000..fb2ec05f1 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h @@ -0,0 +1,43 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_RESULT_H_ +#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_RESULT_H_ + +#include + +#include "mediapipe/framework/formats/image.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace image_segmenter { + +// The output result of ImageSegmenter +struct ImageSegmenterResult { + // Multiple masks of float image in VEC32F1 format where, for each mask, each + // pixel represents the prediction confidence, usually in the [0, 1] range. + std::vector confidence_masks; + // A category mask of uint8 image in GRAY8 format where each pixel represents + // the class which the pixel in the original image was predicted to belong to. + std::optional category_mask; +}; + +} // namespace image_segmenter +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_RESULT_H_ diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index 1d75a3fb7..1e4387491 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -36,6 +36,7 @@ limitations under the License. 
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" @@ -256,7 +257,6 @@ TEST(GetLabelsTest, SucceedsWithLabelsInModel) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -278,15 +278,14 @@ TEST_F(ImageModeTest, SucceedsWithCategoryMask) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; - + options->output_category_mask = true; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto category_masks, segmenter->Segment(image)); - EXPECT_EQ(category_masks.size(), 1); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_TRUE(result.category_mask.has_value()); cv::Mat actual_mask = mediapipe::formats::MatView( - category_masks[0].GetImageFrameSharedPtr().get()); + result.category_mask->GetImageFrameSharedPtr().get()); cv::Mat expected_mask = cv::imread( JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"), @@ -303,12 +302,11 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", 
kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); - EXPECT_EQ(confidence_masks.size(), 21); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_EQ(result.confidence_masks.size(), 21); cv::Mat expected_mask = cv::imread( JoinPath("./", kTestDataDirectory, "cat_mask.jpg"), cv::IMREAD_GRAYSCALE); @@ -317,7 +315,7 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { // Cat category index 8. cv::Mat cat_mask = mediapipe::formats::MatView( - confidence_masks[8].GetImageFrameSharedPtr().get()); + result.confidence_masks[8].GetImageFrameSharedPtr().get()); EXPECT_THAT(cat_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -331,15 +329,14 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); ImageProcessingOptions image_processing_options; image_processing_options.rotation_degrees = -90; - MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image, image_processing_options)); - EXPECT_EQ(confidence_masks.size(), 21); + EXPECT_EQ(result.confidence_masks.size(), 21); cv::Mat expected_mask = cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"), @@ -349,7 +346,7 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { // Cat category index 8. 
cv::Mat cat_mask = mediapipe::formats::MatView( - confidence_masks[8].GetImageFrameSharedPtr().get()); + result.confidence_masks[8].GetImageFrameSharedPtr().get()); EXPECT_THAT(cat_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -361,7 +358,6 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -384,12 +380,11 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); - EXPECT_EQ(confidence_masks.size(), 2); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_EQ(result.confidence_masks.size(), 2); cv::Mat expected_mask = cv::imread(JoinPath("./", kTestDataDirectory, @@ -400,7 +395,7 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) { // Selfie category index 1. 
cv::Mat selfie_mask = mediapipe::formats::MatView( - confidence_masks[1].GetImageFrameSharedPtr().get()); + result.confidence_masks[1].GetImageFrameSharedPtr().get()); EXPECT_THAT(selfie_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -411,11 +406,10 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); - EXPECT_EQ(confidence_masks.size(), 1); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_EQ(result.confidence_masks.size(), 1); cv::Mat expected_mask = cv::imread(JoinPath("./", kTestDataDirectory, @@ -425,7 +419,7 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); cv::Mat selfie_mask = mediapipe::formats::MatView( - confidence_masks[0].GetImageFrameSharedPtr().get()); + result.confidence_masks[0].GetImageFrameSharedPtr().get()); EXPECT_THAT(selfie_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -436,12 +430,11 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfieSegmentation); - options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); - EXPECT_EQ(confidence_masks.size(), 1); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_EQ(result.confidence_masks.size(), 1); 
MP_ASSERT_OK(segmenter->Close()); cv::Mat expected_mask = cv::imread( @@ -452,7 +445,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) { expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); cv::Mat selfie_mask = mediapipe::formats::MatView( - confidence_masks[0].GetImageFrameSharedPtr().get()); + result.confidence_masks[0].GetImageFrameSharedPtr().get()); EXPECT_THAT(selfie_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -463,16 +456,15 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationCategoryMask) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfieSegmentation); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; - + options->output_category_mask = true; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto category_mask, segmenter->Segment(image)); - EXPECT_EQ(category_mask.size(), 1); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_TRUE(result.category_mask.has_value()); MP_ASSERT_OK(segmenter->Close()); cv::Mat selfie_mask = mediapipe::formats::MatView( - category_mask[0].GetImageFrameSharedPtr().get()); + result.category_mask->GetImageFrameSharedPtr().get()); cv::Mat expected_mask = cv::imread( JoinPath("./", kTestDataDirectory, "portrait_selfie_segmentation_expected_category_mask.jpg"), @@ -487,16 +479,15 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfieSegmentationLandscape); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; - + options->output_category_mask = true; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto category_mask, 
segmenter->Segment(image)); - EXPECT_EQ(category_mask.size(), 1); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_TRUE(result.category_mask.has_value()); MP_ASSERT_OK(segmenter->Close()); cv::Mat selfie_mask = mediapipe::formats::MatView( - category_mask[0].GetImageFrameSharedPtr().get()); + result.category_mask->GetImageFrameSharedPtr().get()); cv::Mat expected_mask = cv::imread( JoinPath( "./", kTestDataDirectory, @@ -512,14 +503,13 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kHairSegmentationWithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); - EXPECT_EQ(confidence_masks.size(), 2); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_EQ(result.confidence_masks.size(), 2); cv::Mat hair_mask = mediapipe::formats::MatView( - confidence_masks[1].GetImageFrameSharedPtr().get()); + result.confidence_masks[1].GetImageFrameSharedPtr().get()); MP_ASSERT_OK(segmenter->Close()); cv::Mat expected_mask = cv::imread( JoinPath("./", kTestDataDirectory, "portrait_hair_expected_mask.jpg"), @@ -540,7 +530,6 @@ TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; options->running_mode = core::RunningMode::VIDEO; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, @@ -572,7 +561,7 @@ TEST_F(VideoModeTest, Succeeds) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = 
ImageSegmenterOptions::OutputType::CATEGORY_MASK; + options->output_category_mask = true; options->running_mode = core::RunningMode::VIDEO; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -580,11 +569,10 @@ TEST_F(VideoModeTest, Succeeds) { JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"), cv::IMREAD_GRAYSCALE); for (int i = 0; i < iterations; ++i) { - MP_ASSERT_OK_AND_ASSIGN(auto category_masks, - segmenter->SegmentForVideo(image, i)); - EXPECT_EQ(category_masks.size(), 1); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->SegmentForVideo(image, i)); + EXPECT_TRUE(result.category_mask.has_value()); cv::Mat actual_mask = mediapipe::formats::MatView( - category_masks[0].GetImageFrameSharedPtr().get()); + result.category_mask->GetImageFrameSharedPtr().get()); EXPECT_THAT(actual_mask, SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, kGoldenMaskMagnificationFactor)); @@ -601,11 +589,10 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; options->running_mode = core::RunningMode::LIVE_STREAM; options->result_callback = - [](absl::StatusOr> segmented_masks, const Image& image, - int64 timestamp_ms) {}; + [](absl::StatusOr segmented_masks, + const Image& image, int64_t timestamp_ms) {}; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -634,11 +621,9 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; options->running_mode = core::RunningMode::LIVE_STREAM; - options->result_callback = - [](absl::StatusOr> 
segmented_masks, const Image& image, - int64 timestamp_ms) {}; + options->result_callback = [](absl::StatusOr result, + const Image& image, int64_t timestamp_ms) {}; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK(segmenter->SegmentAsync(image, 1)); @@ -660,23 +645,23 @@ TEST_F(LiveStreamModeTest, Succeeds) { Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "segmentation_input_rotation0.jpg"))); - std::vector> segmented_masks_results; + std::vector segmented_masks_results; std::vector> image_sizes; - std::vector timestamps; + std::vector timestamps; auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; + options->output_category_mask = true; options->running_mode = core::RunningMode::LIVE_STREAM; - options->result_callback = - [&segmented_masks_results, &image_sizes, ×tamps]( - absl::StatusOr> segmented_masks, - const Image& image, int64 timestamp_ms) { - MP_ASSERT_OK(segmented_masks.status()); - segmented_masks_results.push_back(std::move(segmented_masks).value()); - image_sizes.push_back({image.width(), image.height()}); - timestamps.push_back(timestamp_ms); - }; + options->result_callback = [&segmented_masks_results, &image_sizes, + ×tamps]( + absl::StatusOr result, + const Image& image, int64_t timestamp_ms) { + MP_ASSERT_OK(result.status()); + segmented_masks_results.push_back(std::move(*result->category_mask)); + image_sizes.push_back({image.width(), image.height()}); + timestamps.push_back(timestamp_ms); + }; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); for (int i = 0; i < iterations; ++i) { @@ -690,10 +675,9 @@ TEST_F(LiveStreamModeTest, Succeeds) { cv::Mat expected_mask = cv::imread( JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"), 
cv::IMREAD_GRAYSCALE); - for (const auto& segmented_masks : segmented_masks_results) { - EXPECT_EQ(segmented_masks.size(), 1); + for (const auto& category_mask : segmented_masks_results) { cv::Mat actual_mask = mediapipe::formats::MatView( - segmented_masks[0].GetImageFrameSharedPtr().get()); + category_mask.GetImageFrameSharedPtr().get()); EXPECT_THAT(actual_mask, SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, kGoldenMaskMagnificationFactor)); @@ -702,7 +686,7 @@ TEST_F(LiveStreamModeTest, Succeeds) { EXPECT_EQ(image_size.first, image.width()); EXPECT_EQ(image_size.second, image.height()); } - int64 timestamp_ms = -1; + int64_t timestamp_ms = -1; for (const auto& timestamp : timestamps) { EXPECT_GT(timestamp, timestamp_ms); timestamp_ms = timestamp; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto b/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto index be2b8a589..b1ec529d0 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto @@ -33,7 +33,7 @@ message SegmenterOptions { CONFIDENCE_MASK = 2; } // Optional output mask type. - optional OutputType output_type = 1 [default = CATEGORY_MASK]; + optional OutputType output_type = 1 [deprecated = true]; // Supported activation functions for filtering. enum Activation {