From d5def9e24dd4d769759d763b62232ff726716789 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 5 Apr 2023 20:30:05 -0700 Subject: [PATCH] Image segmenter output both confidence masks and category mask optionally. PiperOrigin-RevId: 522227345 --- .../tensors_to_segmentation_calculator.cc | 31 ++++-- .../vision/image_segmenter/image_segmenter.cc | 49 +++++++--- .../vision/image_segmenter/image_segmenter.h | 10 +- .../image_segmenter/image_segmenter_graph.cc | 95 ++++++++++++------- .../image_segmenter/image_segmenter_result.h | 2 +- .../image_segmenter/image_segmenter_test.cc | 25 ++--- 6 files changed, 139 insertions(+), 73 deletions(-) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc index 49ad18029..790285546 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/status_macros.h" @@ -210,8 +211,9 @@ std::vector ProcessForConfidenceMaskCpu(const Shape& input_shape, } // namespace // Converts Tensors from a vector of Tensor to Segmentation masks. The -// calculator always output confidence masks, and an optional category mask if -// CATEGORY_MASK is connected. +// calculator can output optional confidence masks if CONFIDENCE_MASK is +// connected, and an optional category mask if CATEGORY_MASK is connected. At +// least one of CONFIDENCE_MASK and CATEGORY_MASK must be connected. // // Performs optional resizing to OUTPUT_SIZE dimension if provided, // otherwise the segmented masks is the same size as input tensor. 
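To make the new calculator contract concrete, a minimal node configuration that connects only CATEGORY_MASK might look like the sketch below (illustrative only, not part of this patch: the registered calculator name is assumed from its REGISTER_CALCULATOR macro, and the stream names are placeholders). With no CONFIDENCE_MASK output connected, the updated Open() check still passes and Process() skips the confidence-mask branch.

  #include "mediapipe/framework/calculator.pb.h"
  #include "mediapipe/framework/port/parse_text_proto.h"

  // Category-mask-only node: CONFIDENCE_MASK is left unconnected, which the new
  // Open() validation accepts because CATEGORY_MASK is present.
  const mediapipe::CalculatorGraphConfig::Node kCategoryMaskOnlyNode =
      mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(R"pb(
        calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
        input_stream: "TENSORS:segmentation_tensors"
        input_stream: "OUTPUT_SIZE:output_size"
        output_stream: "CATEGORY_MASK:category_mask"
      )pb");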
@@ -296,6 +298,13 @@ absl::Status TensorsToSegmentationCalculator::Open( SegmenterOptions::UNSPECIFIED) << "Must specify output_type as one of " "[CONFIDENCE_MASK|CATEGORY_MASK]."; + } else { + if (!cc->Outputs().HasTag("CONFIDENCE_MASK") && + !cc->Outputs().HasTag("CATEGORY_MASK")) { + return absl::InvalidArgumentError( + "At least one of CONFIDENCE_MASK and CATEGORY_MASK must be " + "connected."); + } } #ifdef __EMSCRIPTEN__ MP_RETURN_IF_ERROR(postprocessor_.Initialize(cc, options_)); @@ -366,14 +375,16 @@ absl::Status TensorsToSegmentationCalculator::Process( return absl::OkStatus(); } - std::vector confidence_masks = - ProcessForConfidenceMaskCpu(input_shape, - {/* height= */ output_height, - /* width= */ output_width, - /* channels= */ input_shape.channels}, - options_.segmenter_options(), tensors_buffer); - for (int i = 0; i < confidence_masks.size(); ++i) { - kConfidenceMaskOut(cc)[i].Send(std::move(confidence_masks[i])); + if (cc->Outputs().HasTag("CONFIDENCE_MASK")) { + std::vector confidence_masks = ProcessForConfidenceMaskCpu( + input_shape, + {/* height= */ output_height, + /* width= */ output_width, + /* channels= */ input_shape.channels}, + options_.segmenter_options(), tensors_buffer); + for (int i = 0; i < confidence_masks.size(); ++i) { + kConfidenceMaskOut(cc)[i].Send(std::move(confidence_masks[i])); + } } if (cc->Outputs().HasTag("CATEGORY_MASK")) { kCategoryMaskOut(cc).Send(ProcessForCategoryMaskCpu( diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index 8f03ff086..33c868e05 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -60,15 +60,19 @@ using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: // "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph". 
CalculatorGraphConfig CreateGraphConfig( std::unique_ptr options, - bool output_category_mask, bool enable_flow_limiting) { + bool output_confidence_masks, bool output_category_mask, + bool enable_flow_limiting) { api2::builder::Graph graph; auto& task_subgraph = graph.AddNode(kSubgraphTypeName); task_subgraph.GetOptions().Swap( options.get()); graph.In(kImageTag).SetName(kImageInStreamName); graph.In(kNormRectTag).SetName(kNormRectStreamName); - task_subgraph.Out(kConfidenceMasksTag).SetName(kConfidenceMasksStreamName) >> - graph.Out(kConfidenceMasksTag); + if (output_confidence_masks) { + task_subgraph.Out(kConfidenceMasksTag) + .SetName(kConfidenceMasksStreamName) >> + graph.Out(kConfidenceMasksTag); + } if (output_category_mask) { task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >> graph.Out(kCategoryMaskTag); @@ -135,11 +139,17 @@ absl::StatusOr> GetLabelsFromGraphConfig( absl::StatusOr> ImageSegmenter::Create( std::unique_ptr options) { + if (!options->output_confidence_masks && !options->output_category_mask) { + return absl::InvalidArgumentError( + "At least one of `output_confidence_masks` and `output_category_mask` " + "must be set."); + } auto options_proto = ConvertImageSegmenterOptionsToProto(options.get()); tasks::core::PacketsCallback packets_callback = nullptr; if (options->result_callback) { auto result_callback = options->result_callback; bool output_category_mask = options->output_category_mask; + bool output_confidence_masks = options->output_confidence_masks; packets_callback = [=](absl::StatusOr status_or_packets) { if (!status_or_packets.ok()) { @@ -151,8 +161,12 @@ absl::StatusOr> ImageSegmenter::Create( if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { return; } - Packet confidence_masks = - status_or_packets.value()[kConfidenceMasksStreamName]; + std::optional> confidence_masks; + if (output_confidence_masks) { + confidence_masks = + status_or_packets.value()[kConfidenceMasksStreamName] + .Get>(); + } std::optional category_mask; if (output_category_mask) { category_mask = @@ -160,23 +174,24 @@ absl::StatusOr> ImageSegmenter::Create( } Packet image_packet = status_or_packets.value()[kImageOutStreamName]; result_callback( - {{confidence_masks.Get>(), category_mask}}, - image_packet.Get(), - confidence_masks.Timestamp().Value() / - kMicroSecondsPerMilliSecond); + {{confidence_masks, category_mask}}, image_packet.Get(), + image_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); }; } auto image_segmenter = core::VisionTaskApiFactory::Create( CreateGraphConfig( - std::move(options_proto), options->output_category_mask, + std::move(options_proto), options->output_confidence_masks, + options->output_category_mask, options->running_mode == core::RunningMode::LIVE_STREAM), std::move(options->base_options.op_resolver), options->running_mode, std::move(packets_callback)); if (!image_segmenter.ok()) { return image_segmenter.status(); } + image_segmenter.value()->output_confidence_masks_ = + options->output_confidence_masks; image_segmenter.value()->output_category_mask_ = options->output_category_mask; ASSIGN_OR_RETURN( @@ -203,8 +218,11 @@ absl::StatusOr ImageSegmenter::Segment( {{kImageInStreamName, mediapipe::MakePacket(std::move(image))}, {kNormRectStreamName, MakePacket(std::move(norm_rect))}})); - std::vector confidence_masks = - output_packets[kConfidenceMasksStreamName].Get>(); + std::optional> confidence_masks; + if (output_confidence_masks_) { + confidence_masks = + output_packets[kConfidenceMasksStreamName].Get>(); + } 
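From the client side, the effect of the new options is easiest to see in blocking IMAGE mode. Below is a minimal sketch, not part of this patch, assuming the usual includes and using-declarations for the image_segmenter C++ API, a mediapipe::Image `image` already loaded, a placeholder model path, and a surrounding function that returns absl::Status.

  auto options = std::make_unique<ImageSegmenterOptions>();
  options->base_options.model_asset_path = "/path/to/model.tflite";  // placeholder
  options->output_confidence_masks = false;  // skip the confidence-mask outputs
  options->output_category_mask = true;      // request only the category mask

  // Create() now rejects the combination where both flags are false.
  ASSIGN_OR_RETURN(std::unique_ptr<ImageSegmenter> segmenter,
                   ImageSegmenter::Create(std::move(options)));
  ASSIGN_OR_RETURN(ImageSegmenterResult result, segmenter->Segment(image));

  // Only the requested output is populated; confidence_masks stays nullopt.
  if (result.category_mask.has_value()) {
    const Image& category_mask = *result.category_mask;  // GRAY8 class indices
    // ... consume the mask ...
  }
  MP_RETURN_IF_ERROR(segmenter->Close());
  return absl::OkStatus();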
std::optional category_mask; if (output_category_mask_) { category_mask = output_packets[kCategoryMaskStreamName].Get(); @@ -233,8 +251,11 @@ absl::StatusOr ImageSegmenter::SegmentForVideo( {kNormRectStreamName, MakePacket(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); - std::vector confidence_masks = - output_packets[kConfidenceMasksStreamName].Get>(); + std::optional> confidence_masks; + if (output_confidence_masks_) { + confidence_masks = + output_packets[kConfidenceMasksStreamName].Get>(); + } std::optional category_mask; if (output_category_mask_) { category_mask = output_packets[kCategoryMaskStreamName].Get(); diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index 1d18e3903..352d6b273 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -53,6 +53,9 @@ struct ImageSegmenterOptions { // Metadata, if any. Defaults to English. std::string display_names_locale = "en"; + // Whether to output confidence masks. + bool output_confidence_masks = true; + // Whether to output category mask. bool output_category_mask = false; @@ -77,8 +80,10 @@ struct ImageSegmenterOptions { // - if type is kTfLiteFloat32, NormalizationOptions are required to be // attached to the metadata for input normalization. // Output ImageSegmenterResult: -// Provides confidence masks and an optional category mask if -// `output_category_mask` is set true. +// Provides optional confidence masks if `output_confidence_masks` is set +// true, and an optional category mask if `output_category_mask` is set +// true. At least one of `output_confidence_masks` and `output_category_mask` +// must be set to true. // An example of such model can be found at: // https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2 class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { @@ -167,6 +172,7 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { private: std::vector labels_; + bool output_confidence_masks_; bool output_category_mask_; }; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index 4b9e7618b..840e7933a 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -326,8 +326,10 @@ absl::StatusOr ConvertImageToTensors( } // An "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph" performs -// semantic segmentation. The graph always output confidence masks, and an -// optional category mask if CATEGORY_MASK is connected. +// semantic segmentation. The graph can output optional confidence masks if +// CONFIDENCE_MASKS is connected, and an optional category mask if CATEGORY_MASK +// is connected. At least one of CONFIDENCE_MASK, CONFIDENCE_MASKS and +// CATEGORY_MASK must be connected. // // Two kinds of outputs for confidence mask are provided: CONFIDENCE_MASK and // CONFIDENCE_MASKS. Users can retrieve segmented mask of only particular @@ -347,7 +349,7 @@ absl::StatusOr ConvertImageToTensors( // CONFIDENCE_MASK - mediapipe::Image @Multiple // Confidence masks for individual category. Confidence mask of single // category can be accessed by index based output stream. -// CONFIDENCE_MASKS - std::vector +// CONFIDENCE_MASKS - std::vector @Optional // The output confidence masks grouped in a vector. 
// CATEGORY_MASK - mediapipe::Image @Optional // Optional Category mask. @@ -356,7 +358,7 @@ absl::StatusOr ConvertImageToTensors( // // Example: // node { -// calculator: "mediapipe.tasks.vision.ImageSegmenterGraph" +// calculator: "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph" // input_stream: "IMAGE:image" // output_stream: "SEGMENTATION:segmented_masks" // options { @@ -382,17 +384,20 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { CreateModelResources(sc)); Graph graph; const auto& options = sc->Options(); + // TODO: remove deprecated output type support. + if (!options.segmenter_options().has_output_type()) { + MP_RETURN_IF_ERROR(SanityCheck(sc)); + } ASSIGN_OR_RETURN( auto output_streams, BuildSegmentationTask( options, *model_resources, graph[Input(kImageTag)], - graph[Input::Optional(kNormRectTag)], - HasOutput(sc->OriginalNode(), kCategoryMaskTag), graph)); + graph[Input::Optional(kNormRectTag)], graph)); - auto& merge_images_to_vector = - graph.AddNode("MergeImagesToVectorCalculator"); // TODO: remove deprecated output type support. if (options.segmenter_options().has_output_type()) { + auto& merge_images_to_vector = + graph.AddNode("MergeImagesToVectorCalculator"); for (int i = 0; i < output_streams.segmented_masks->size(); ++i) { output_streams.segmented_masks->at(i) >> merge_images_to_vector[Input::Multiple("")][i]; @@ -402,14 +407,18 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { merge_images_to_vector.Out("") >> graph[Output>(kGroupedSegmentationTag)]; } else { - for (int i = 0; i < output_streams.confidence_masks->size(); ++i) { - output_streams.confidence_masks->at(i) >> - merge_images_to_vector[Input::Multiple("")][i]; - output_streams.confidence_masks->at(i) >> - graph[Output::Multiple(kConfidenceMaskTag)][i]; + if (output_streams.confidence_masks) { + auto& merge_images_to_vector = + graph.AddNode("MergeImagesToVectorCalculator"); + for (int i = 0; i < output_streams.confidence_masks->size(); ++i) { + output_streams.confidence_masks->at(i) >> + merge_images_to_vector[Input::Multiple("")][i]; + output_streams.confidence_masks->at(i) >> + graph[Output::Multiple(kConfidenceMaskTag)][i]; + } + merge_images_to_vector.Out("") >> + graph[Output>::Optional(kConfidenceMasksTag)]; } - merge_images_to_vector.Out("") >> - graph[Output>(kConfidenceMasksTag)]; if (output_streams.category_mask) { *output_streams.category_mask >> graph[Output(kCategoryMaskTag)]; } @@ -419,6 +428,19 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { } private: + absl::Status SanityCheck(mediapipe::SubgraphContext* sc) { + const auto& node = sc->OriginalNode(); + output_confidence_masks_ = HasOutput(node, kConfidenceMaskTag) || + HasOutput(node, kConfidenceMasksTag); + output_category_mask_ = HasOutput(node, kCategoryMaskTag); + if (!output_confidence_masks_ && !output_category_mask_) { + return absl::InvalidArgumentError( + "At least one of CONFIDENCE_MASK, CONFIDENCE_MASKS and CATEGORY_MASK " + "must be connected."); + } + return absl::OkStatus(); + } + // Adds a mediapipe image segmentation task pipeline graph into the provided // builder::Graph instance. The segmentation pipeline takes images // (mediapipe::Image) as the input and returns segmented image mask as output. 
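For graph authors embedding the subgraph directly, SanityCheck() means any non-empty subset of the mask outputs may be connected. A rough api2::builder sketch, assuming the usual builder includes; the stream names are illustrative and the options plumbing is only indicated in a comment.

  mediapipe::api2::builder::Graph graph;
  auto& segmenter = graph.AddNode(
      "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph");
  // Model asset and other task options would be filled in on
  // segmenter.GetOptions<ImageSegmenterGraphOptions>() here.
  graph.In("IMAGE").SetName("image") >> segmenter.In("IMAGE");
  // CONFIDENCE_MASK / CONFIDENCE_MASKS are left unconnected, so the graph skips
  // the MergeImagesToVectorCalculator branch entirely.
  segmenter.Out("CATEGORY_MASK").SetName("category_mask") >>
      graph.Out("CATEGORY_MASK");
  mediapipe::CalculatorGraphConfig config = graph.GetConfig();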
@@ -431,8 +453,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { absl::StatusOr BuildSegmentationTask( const ImageSegmenterGraphOptions& task_options, const core::ModelResources& model_resources, Source image_in, - Source norm_rect_in, bool output_category_mask, - Graph& graph) { + Source norm_rect_in, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); // Adds preprocessing calculators and connects them to the graph input image @@ -485,26 +506,32 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { /*category_mask=*/std::nullopt, /*image=*/image_and_tensors.image}; } else { - ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor, - GetOutputTensor(model_resources)); - int segmentation_streams_num = *output_tensor->shape()->rbegin(); - std::vector> confidence_masks; - confidence_masks.reserve(segmentation_streams_num); - for (int i = 0; i < segmentation_streams_num; ++i) { - confidence_masks.push_back(Source( - tensor_to_images[Output::Multiple(kConfidenceMaskTag)][i])); + std::optional>> confidence_masks; + if (output_confidence_masks_) { + ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor, + GetOutputTensor(model_resources)); + int segmentation_streams_num = *output_tensor->shape()->rbegin(); + confidence_masks = std::vector>(); + confidence_masks->reserve(segmentation_streams_num); + for (int i = 0; i < segmentation_streams_num; ++i) { + confidence_masks->push_back(Source( + tensor_to_images[Output::Multiple(kConfidenceMaskTag)] + [i])); + } } - return ImageSegmenterOutputs{ - /*segmented_masks=*/std::nullopt, - /*confidence_masks=*/confidence_masks, - /*category_mask=*/ - output_category_mask - ? std::make_optional( - tensor_to_images[Output(kCategoryMaskTag)]) - : std::nullopt, - /*image=*/image_and_tensors.image}; + std::optional> category_mask; + if (output_category_mask_) { + category_mask = tensor_to_images[Output(kCategoryMaskTag)]; + } + return ImageSegmenterOutputs{/*segmented_masks=*/std::nullopt, + /*confidence_masks=*/confidence_masks, + /*category_mask=*/category_mask, + /*image=*/image_and_tensors.image}; } } + + bool output_confidence_masks_ = false; + bool output_category_mask_ = false; }; REGISTER_MEDIAPIPE_GRAPH( diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h index fb2ec05f1..f14ee7a90 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h @@ -29,7 +29,7 @@ namespace image_segmenter { struct ImageSegmenterResult { // Multiple masks of float image in VEC32F1 format where, for each mask, each // pixel represents the prediction confidence, usually in the [0, 1] range. - std::vector confidence_masks; + std::optional> confidence_masks; // A category mask of uint8 image in GRAY8 format where each pixel represents // the class which the pixel in the original image was predicted to belong to. 
std::optional category_mask; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index 1e4387491..0c5a61486 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -278,6 +278,7 @@ TEST_F(ImageModeTest, SucceedsWithCategoryMask) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); + options->output_confidence_masks = false; options->output_category_mask = true; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -306,7 +307,7 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); - EXPECT_EQ(result.confidence_masks.size(), 21); + EXPECT_EQ(result.confidence_masks->size(), 21); cv::Mat expected_mask = cv::imread( JoinPath("./", kTestDataDirectory, "cat_mask.jpg"), cv::IMREAD_GRAYSCALE); @@ -315,7 +316,7 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { // Cat category index 8. cv::Mat cat_mask = mediapipe::formats::MatView( - result.confidence_masks[8].GetImageFrameSharedPtr().get()); + result.confidence_masks->at(8).GetImageFrameSharedPtr().get()); EXPECT_THAT(cat_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -336,7 +337,7 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { image_processing_options.rotation_degrees = -90; MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image, image_processing_options)); - EXPECT_EQ(result.confidence_masks.size(), 21); + EXPECT_EQ(result.confidence_masks->size(), 21); cv::Mat expected_mask = cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"), @@ -346,7 +347,7 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { // Cat category index 8. cv::Mat cat_mask = mediapipe::formats::MatView( - result.confidence_masks[8].GetImageFrameSharedPtr().get()); + result.confidence_masks->at(8).GetImageFrameSharedPtr().get()); EXPECT_THAT(cat_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -384,7 +385,7 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); - EXPECT_EQ(result.confidence_masks.size(), 2); + EXPECT_EQ(result.confidence_masks->size(), 2); cv::Mat expected_mask = cv::imread(JoinPath("./", kTestDataDirectory, @@ -395,7 +396,7 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) { // Selfie category index 1. 
cv::Mat selfie_mask = mediapipe::formats::MatView( - result.confidence_masks[1].GetImageFrameSharedPtr().get()); + result.confidence_masks->at(1).GetImageFrameSharedPtr().get()); EXPECT_THAT(selfie_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -409,7 +410,7 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); - EXPECT_EQ(result.confidence_masks.size(), 1); + EXPECT_EQ(result.confidence_masks->size(), 1); cv::Mat expected_mask = cv::imread(JoinPath("./", kTestDataDirectory, @@ -419,7 +420,7 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); cv::Mat selfie_mask = mediapipe::formats::MatView( - result.confidence_masks[0].GetImageFrameSharedPtr().get()); + result.confidence_masks->at(0).GetImageFrameSharedPtr().get()); EXPECT_THAT(selfie_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -434,7 +435,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); - EXPECT_EQ(result.confidence_masks.size(), 1); + EXPECT_EQ(result.confidence_masks->size(), 1); MP_ASSERT_OK(segmenter->Close()); cv::Mat expected_mask = cv::imread( @@ -445,7 +446,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) { expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); cv::Mat selfie_mask = mediapipe::formats::MatView( - result.confidence_masks[0].GetImageFrameSharedPtr().get()); + result.confidence_masks->at(0).GetImageFrameSharedPtr().get()); EXPECT_THAT(selfie_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -506,10 +507,10 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); - EXPECT_EQ(result.confidence_masks.size(), 2); + EXPECT_EQ(result.confidence_masks->size(), 2); cv::Mat hair_mask = mediapipe::formats::MatView( - result.confidence_masks[1].GetImageFrameSharedPtr().get()); + result.confidence_masks->at(1).GetImageFrameSharedPtr().get()); MP_ASSERT_OK(segmenter->Close()); cv::Mat expected_mask = cv::imread( JoinPath("./", kTestDataDirectory, "portrait_hair_expected_mask.jpg"),
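A companion test for the new Create()-time validation (a sketch in the style of the existing ImageModeTest cases, not part of this patch; the matcher plumbing is assumed) would request neither output and expect the InvalidArgument error introduced above:

  TEST_F(ImageModeTest, FailsWhenNoOutputIsRequested) {
    auto options = std::make_unique<ImageSegmenterOptions>();
    options->base_options.model_asset_path =
        JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
    options->output_confidence_masks = false;
    options->output_category_mask = false;
    auto segmenter_or = ImageSegmenter::Create(std::move(options));
    EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInvalidArgument);
    EXPECT_THAT(std::string(segmenter_or.status().message()),
                ::testing::HasSubstr(
                    "At least one of `output_confidence_masks` and "
                    "`output_category_mask` must be set."));
  }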