From a1ce19f68ee4cab9b7e36a03e103267b97b8325d Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 7 Apr 2023 10:43:23 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 522631851 --- .../cc/vision/interactive_segmenter/BUILD | 8 ++ .../interactive_segmenter.cc | 92 +++++++++++++------ .../interactive_segmenter.h | 47 ++++------ .../interactive_segmenter_graph.cc | 74 +++++++++------ .../interactive_segmenter_test.cc | 79 +++++++++------- 5 files changed, 184 insertions(+), 116 deletions(-) diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/BUILD b/mediapipe/tasks/cc/vision/interactive_segmenter/BUILD index 13de87491..8552383ac 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/BUILD @@ -24,21 +24,26 @@ cc_library( hdrs = ["interactive_segmenter.h"], deps = [ ":interactive_segmenter_graph", + "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components/containers:keypoint", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/vision/core:base_vision_task_api", "//mediapipe/tasks/cc/vision/core:image_processing_options", "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_result", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", "//mediapipe/util:color_cc_proto", "//mediapipe/util:render_data_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -61,9 +66,12 @@ cc_library( "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", "//mediapipe/util:color_cc_proto", + "//mediapipe/util:graph_builder_utils", "//mediapipe/util:label_map_cc_proto", "//mediapipe/util:render_data_cc_proto", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", ] + select({ "//mediapipe/gpu:disable_gpu": [], diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc index 853baec29..9d7111e75 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc @@ -15,16 +15,24 @@ limitations under the License. #include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h" +#include #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" #include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" #include "mediapipe/util/color.pb.h" @@ -36,23 +44,26 @@ namespace vision { namespace interactive_segmenter { namespace { -constexpr char kSegmentationStreamName[] = "segmented_mask_out"; +constexpr char kConfidenceMasksStreamName[] = "confidence_masks"; +constexpr char kCategoryMaskStreamName[] = "category_mask"; constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageOutStreamName[] = "image_out"; constexpr char kRoiStreamName[] = "roi_in"; constexpr char kNormRectStreamName[] = "norm_rect_in"; -constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; -constexpr char kImageTag[] = "IMAGE"; -constexpr char kRoiTag[] = "ROI"; -constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"}; +constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"}; +constexpr absl::string_view kImageTag{"IMAGE"}; +constexpr absl::string_view kRoiTag{"ROI"}; +constexpr absl::string_view kNormRectTag{"NORM_RECT"}; -constexpr char kSubgraphTypeName[] = - "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"; +constexpr absl::string_view kSubgraphTypeName{ + "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"}; using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::Image; using ::mediapipe::NormalizedRect; +using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenterResult; using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: image_segmenter::proto::ImageSegmenterGraphOptions; @@ -60,7 +71,8 @@ using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: // Creates a MediaPipe graph config that only contains a single subgraph node of // "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph". CalculatorGraphConfig CreateGraphConfig( - std::unique_ptr options) { + std::unique_ptr options, + bool output_confidence_masks, bool output_category_mask) { api2::builder::Graph graph; auto& task_subgraph = graph.AddNode(kSubgraphTypeName); task_subgraph.GetOptions().Swap( @@ -68,8 +80,15 @@ CalculatorGraphConfig CreateGraphConfig( graph.In(kImageTag).SetName(kImageInStreamName); graph.In(kRoiTag).SetName(kRoiStreamName); graph.In(kNormRectTag).SetName(kNormRectStreamName); - task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >> - graph.Out(kGroupedSegmentationTag); + if (output_confidence_masks) { + task_subgraph.Out(kConfidenceMasksTag) + .SetName(kConfidenceMasksStreamName) >> + graph.Out(kConfidenceMasksTag); + } + if (output_category_mask) { + task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >> + graph.Out(kCategoryMaskTag); + } task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); graph.In(kImageTag) >> task_subgraph.In(kImageTag); @@ -86,16 +105,6 @@ ConvertImageSegmenterOptionsToProto(InteractiveSegmenterOptions* options) { auto base_options_proto = std::make_unique( tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); options_proto->mutable_base_options()->Swap(base_options_proto.get()); - switch (options->output_type) { - case InteractiveSegmenterOptions::OutputType::CATEGORY_MASK: - options_proto->mutable_segmenter_options()->set_output_type( - SegmenterOptions::CATEGORY_MASK); - break; - case InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK: - options_proto->mutable_segmenter_options()->set_output_type( - SegmenterOptions::CONFIDENCE_MASK); - break; - } return options_proto; } @@ -104,10 +113,10 @@ ConvertImageSegmenterOptionsToProto(InteractiveSegmenterOptions* options) { absl::StatusOr ConvertRoiToRenderData(const RegionOfInterest& roi) { RenderData result; switch (roi.format) { - case RegionOfInterest::UNSPECIFIED: + case RegionOfInterest::Format::kUnspecified: return absl::InvalidArgumentError( "RegionOfInterest format not specified"); - case RegionOfInterest::KEYPOINT: + case RegionOfInterest::Format::kKeyPoint: RET_CHECK(roi.keypoint.has_value()); auto* annotation = result.add_render_annotations(); annotation->mutable_color()->set_r(255); @@ -125,15 +134,29 @@ absl::StatusOr ConvertRoiToRenderData(const RegionOfInterest& roi) { absl::StatusOr> InteractiveSegmenter::Create( std::unique_ptr options) { - auto options_proto = ConvertImageSegmenterOptionsToProto(options.get()); - return core::VisionTaskApiFactory::Create( - CreateGraphConfig(std::move(options_proto)), - std::move(options->base_options.op_resolver), core::RunningMode::IMAGE, - /*packets_callback=*/nullptr); + if (!options->output_confidence_masks && !options->output_category_mask) { + return absl::InvalidArgumentError( + "At least one of `output_confidence_masks` and `output_category_mask` " + "must be set."); + } + std::unique_ptr options_proto = + ConvertImageSegmenterOptionsToProto(options.get()); + ASSIGN_OR_RETURN( + std::unique_ptr segmenter, + (core::VisionTaskApiFactory::Create( + CreateGraphConfig(std::move(options_proto), + options->output_confidence_masks, + options->output_category_mask), + std::move(options->base_options.op_resolver), + core::RunningMode::IMAGE, + /*packets_callback=*/nullptr))); + segmenter->output_category_mask_ = options->output_category_mask; + segmenter->output_confidence_masks_ = options->output_confidence_masks; + return segmenter; } -absl::StatusOr> InteractiveSegmenter::Segment( +absl::StatusOr InteractiveSegmenter::Segment( mediapipe::Image image, const RegionOfInterest& roi, std::optional image_processing_options) { if (image.UsesGpu()) { @@ -154,7 +177,16 @@ absl::StatusOr> InteractiveSegmenter::Segment( mediapipe::MakePacket(std::move(roi_as_render_data))}, {kNormRectStreamName, MakePacket(std::move(norm_rect))}})); - return output_packets[kSegmentationStreamName].Get>(); + std::optional> confidence_masks; + if (output_confidence_masks_) { + confidence_masks = + output_packets[kConfidenceMasksStreamName].Get>(); + } + std::optional category_mask; + if (output_category_mask_) { + category_mask = output_packets[kCategoryMaskStreamName].Get(); + } + return {{confidence_masks, category_mask}}; } } // namespace interactive_segmenter diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h index 420b22462..350777f31 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h @@ -21,12 +21,14 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/tasks/cc/components/containers/keypoint.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h" namespace mediapipe { namespace tasks { @@ -39,30 +41,24 @@ struct InteractiveSegmenterOptions { // file with metadata, accelerator options, op resolver, etc. tasks::core::BaseOptions base_options; - // The output type of segmentation results. - enum OutputType { - // Gives a single output mask where each pixel represents the class which - // the pixel in the original image was predicted to belong to. - CATEGORY_MASK = 0, - // Gives a list of output masks where, for each mask, each pixel represents - // the prediction confidence, usually in the [0, 1] range. - CONFIDENCE_MASK = 1, - }; + // Whether to output confidence masks. + bool output_confidence_masks = true; - OutputType output_type = OutputType::CATEGORY_MASK; + // Whether to output category mask. + bool output_category_mask = false; }; // The Region-Of-Interest (ROI) to interact with. struct RegionOfInterest { - enum Format { - UNSPECIFIED = 0, // Format not specified. - KEYPOINT = 1, // Using keypoint to represent ROI. + enum class Format { + kUnspecified = 0, // Format not specified. + kKeyPoint = 1, // Using keypoint to represent ROI. }; // Specifies the format used to specify the region-of-interest. Note that // using `UNSPECIFIED` is invalid and will lead to an `InvalidArgument` status // being returned. - Format format = Format::UNSPECIFIED; + Format format = Format::kUnspecified; // Represents the ROI in keypoint format, this should be non-nullopt if // `format` is `KEYPOINT`. @@ -84,13 +80,11 @@ struct RegionOfInterest { // - RGB inputs is supported (`channels` is required to be 3). // - if type is kTfLiteFloat32, NormalizationOptions are required to be // attached to the metadata for input normalization. -// Output tensors: -// (kTfLiteUInt8/kTfLiteFloat32) -// - list of segmented masks. -// - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. -// - if `output_type` is CONFIDENCE_MASK, float32 Image list of size -// `channels`. -// - batch is always 1 +// Output ImageSegmenterResult: +// Provides optional confidence masks if `output_confidence_masks` is set +// true, and an optional category mask if `output_category_mask` is set +// true. At least one of `output_confidence_masks` and `output_category_mask` +// must be set to true. class InteractiveSegmenter : tasks::vision::core::BaseVisionTaskApi { public: using BaseVisionTaskApi::BaseVisionTaskApi; @@ -114,18 +108,17 @@ class InteractiveSegmenter : tasks::vision::core::BaseVisionTaskApi { // setting its 'rotation_degrees' field. Note that specifying a // region-of-interest using the 'region_of_interest' field is NOT supported // and will result in an invalid argument error being returned. - // - // If the output_type is CATEGORY_MASK, the returned vector of images is - // per-category segmented image mask. - // If the output_type is CONFIDENCE_MASK, the returned vector of images - // contains only one confidence image mask. - absl::StatusOr> Segment( + absl::StatusOr Segment( mediapipe::Image image, const RegionOfInterest& roi, std::optional image_processing_options = std::nullopt); // Shuts down the InteractiveSegmenter when all works are done. absl::Status Close() { return runner_->Close(); } + + private: + bool output_confidence_masks_; + bool output_category_mask_; }; } // namespace interactive_segmenter diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc index 4c0cd2a88..b907e2156 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "mediapipe/calculators/util/flat_color_image_calculator.pb.h" #include "mediapipe/framework/api2/builder.h" @@ -23,8 +26,9 @@ limitations under the License. #include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" #include "mediapipe/util/color.pb.h" -#include "mediapipe/util/label_map.pb.h" +#include "mediapipe/util/graph_builder_utils.h" #include "mediapipe/util/render_data.pb.h" namespace mediapipe { @@ -42,16 +46,18 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; -constexpr char kSegmentationTag[] = "SEGMENTATION"; -constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; -constexpr char kImageTag[] = "IMAGE"; -constexpr char kImageCpuTag[] = "IMAGE_CPU"; -constexpr char kImageGpuTag[] = "IMAGE_GPU"; -constexpr char kAlphaTag[] = "ALPHA"; -constexpr char kAlphaGpuTag[] = "ALPHA_GPU"; -constexpr char kNormRectTag[] = "NORM_RECT"; -constexpr char kRoiTag[] = "ROI"; -constexpr char kVideoTag[] = "VIDEO"; +constexpr absl::string_view kSegmentationTag{"SEGMENTATION"}; +constexpr absl::string_view kGroupedSegmentationTag{"GROUPED_SEGMENTATION"}; +constexpr absl::string_view kConfidenceMaskTag{"CONFIDENCE_MASK"}; +constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"}; +constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"}; +constexpr absl::string_view kImageTag{"IMAGE"}; +constexpr absl::string_view kImageCpuTag{"IMAGE_CPU"}; +constexpr absl::string_view kImageGpuTag{"IMAGE_GPU"}; +constexpr absl::string_view kAlphaTag{"ALPHA"}; +constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"}; +constexpr absl::string_view kNormRectTag{"NORM_RECT"}; +constexpr absl::string_view kRoiTag{"ROI"}; // Updates the graph to return `roi` stream which has same dimension as // `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is @@ -87,11 +93,10 @@ Source<> RoiToAlpha(Source image, Source roi, bool use_gpu, } // namespace // An "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph" -// performs semantic segmentation given user's region-of-interest. Two kinds of -// outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION. Users can -// retrieve segmented mask of only particular category/channel from -// SEGMENTATION, and users can also get all segmented masks from -// GROUPED_SEGMENTATION. +// performs semantic segmentation given the user's region-of-interest. The graph +// can output optional confidence masks if CONFIDENCE_MASKS is connected, and an +// optional category mask if CATEGORY_MASK is connected. At least one of +// CONFIDENCE_MASK, CONFIDENCE_MASKS and CATEGORY_MASK must be connected. // - Accepts CPU input images and outputs segmented masks on CPU. // // Inputs: @@ -106,11 +111,13 @@ Source<> RoiToAlpha(Source image, Source roi, bool use_gpu, // @Optional: rect covering the whole image is used if not specified. // // Outputs: -// SEGMENTATION - mediapipe::Image @Multiple -// Segmented masks for individual category. Segmented mask of single +// CONFIDENCE_MASK - mediapipe::Image @Multiple +// Confidence masks for individual category. Confidence mask of single // category can be accessed by index based output stream. -// GROUPED_SEGMENTATION - std::vector -// The output segmented masks grouped in a vector. +// CONFIDENCE_MASKS - std::vector @Optional +// The output confidence masks grouped in a vector. +// CATEGORY_MASK - mediapipe::Image @Optional +// Optional Category mask. // IMAGE - mediapipe::Image // The image that image segmenter runs on. // @@ -129,9 +136,6 @@ Source<> RoiToAlpha(Source image, Source roi, bool use_gpu, // file_name: "/path/to/model.tflite" // } // } -// segmenter_options { -// output_type: CONFIDENCE_MASK -// } // } // } // } @@ -176,10 +180,26 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph { image_with_set_alpha >> image_segmenter.In(kImageTag); norm_rect >> image_segmenter.In(kNormRectTag); - image_segmenter.Out(kSegmentationTag) >> - graph[Output(kSegmentationTag)]; - image_segmenter.Out(kGroupedSegmentationTag) >> - graph[Output>(kGroupedSegmentationTag)]; + // TODO: remove deprecated output type support. + if (task_options.segmenter_options().has_output_type()) { + image_segmenter.Out(kSegmentationTag) >> + graph[Output(kSegmentationTag)]; + image_segmenter.Out(kGroupedSegmentationTag) >> + graph[Output>(kGroupedSegmentationTag)]; + } else { + if (HasOutput(sc->OriginalNode(), kConfidenceMaskTag)) { + image_segmenter.Out(kConfidenceMaskTag) >> + graph[Output(kConfidenceMaskTag)]; + } + if (HasOutput(sc->OriginalNode(), kConfidenceMasksTag)) { + image_segmenter.Out(kConfidenceMasksTag) >> + graph[Output(kConfidenceMasksTag)]; + } + if (HasOutput(sc->OriginalNode(), kCategoryMaskTag)) { + image_segmenter.Out(kCategoryMaskTag) >> + graph[Output(kCategoryMaskTag)]; + } + } image_segmenter.Out(kImageTag) >> graph[Output(kImageTag)]; return graph.GetConfig(); diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc index dbc3bbe4c..40c2bb342 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc @@ -17,8 +17,11 @@ limitations under the License. #include #include +#include #include "absl/flags/flag.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_frame.h" @@ -39,6 +42,7 @@ limitations under the License. #include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/kernels/builtin_op_kernels.h" #include "tensorflow/lite/mutable_op_resolver.h" +#include "testing/base/public/gmock.h" namespace mediapipe { namespace tasks { @@ -53,13 +57,16 @@ using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; +using ::testing::SizeIs; +using ::testing::status::StatusIs; -constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; -constexpr char kPtmModel[] = "ptm_512_hdt_ptm_woid.tflite"; -constexpr char kCatsAndDogsJpg[] = "cats_and_dogs.jpg"; +constexpr absl::string_view kTestDataDirectory{ + "/mediapipe/tasks/testdata/vision/"}; +constexpr absl::string_view kPtmModel{"ptm_512_hdt_ptm_woid.tflite"}; +constexpr absl::string_view kCatsAndDogsJpg{"cats_and_dogs.jpg"}; // Golden mask for the dogs in cats_and_dogs.jpg. -constexpr char kCatsAndDogsMaskDog1[] = "cats_and_dogs_mask_dog1.png"; -constexpr char kCatsAndDogsMaskDog2[] = "cats_and_dogs_mask_dog2.png"; +constexpr absl::string_view kCatsAndDogsMaskDog1{"cats_and_dogs_mask_dog1.png"}; +constexpr absl::string_view kCatsAndDogsMaskDog2{"cats_and_dogs_mask_dog2.png"}; constexpr float kGoldenMaskSimilarity = 0.97; @@ -135,35 +142,45 @@ TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) { JoinPath("./", kTestDataDirectory, kPtmModel); options->base_options.op_resolver = absl::make_unique(); - auto segmenter_or = InteractiveSegmenter::Create(std::move(options)); + auto segmenter = InteractiveSegmenter::Create(std::move(options)); // TODO: Make MediaPipe InferenceCalculator report the detailed // interpreter errors (e.g., "Encountered unresolved custom op"). - EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInternal); + EXPECT_EQ(segmenter.status().code(), absl::StatusCode::kInternal); EXPECT_THAT( - segmenter_or.status().message(), + segmenter.status().message(), testing::HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk")); } TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { - absl::StatusOr> segmenter_or = + absl::StatusOr> segmenter = InteractiveSegmenter::Create( std::make_unique()); - EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(segmenter.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT( - segmenter_or.status().message(), + segmenter.status().message(), HasSubstr("ExternalFile must specify at least one of 'file_content', " "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.")); - EXPECT_THAT(segmenter_or.status().GetPayload(kMediaPipeTasksPayload), + EXPECT_THAT(segmenter.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( MediaPipeTasksStatus::kRunnerInitializationError)))); } +TEST_F(CreateFromOptionsTest, FailsWithNeitherOutputSet) { + auto options = std::make_unique(); + options->output_category_mask = false; + options->output_confidence_masks = false; + + EXPECT_THAT(InteractiveSegmenter::Create(std::move(options)), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("At least one of"))); +} + struct InteractiveSegmenterTestParams { std::string test_name; RegionOfInterest::Format format; NormalizedKeypoint roi; - std::string golden_mask_file; + absl::string_view golden_mask_file; float similarity_threshold; }; @@ -181,16 +198,18 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kPtmModel); - options->output_type = InteractiveSegmenterOptions::OutputType::CATEGORY_MASK; + options->output_confidence_masks = false; + options->output_category_mask = true; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, InteractiveSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto category_masks, + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image, interaction_roi)); - EXPECT_EQ(category_masks.size(), 1); + EXPECT_TRUE(result.category_mask.has_value()); + EXPECT_FALSE(result.confidence_masks.has_value()); cv::Mat actual_mask = mediapipe::formats::MatView( - category_masks[0].GetImageFrameSharedPtr().get()); + result.category_mask->GetImageFrameSharedPtr().get()); cv::Mat expected_mask = cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file), @@ -211,14 +230,13 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kPtmModel); - options->output_type = - InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, InteractiveSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image, interaction_roi)); - EXPECT_EQ(confidence_masks.size(), 2); + EXPECT_FALSE(result.category_mask.has_value()); + EXPECT_THAT(result.confidence_masks, Optional(SizeIs(2))); cv::Mat expected_mask = cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file), @@ -227,7 +245,7 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) { expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); cv::Mat actual_mask = mediapipe::formats::MatView( - confidence_masks[1].GetImageFrameSharedPtr().get()); + result.confidence_masks->at(1).GetImageFrameSharedPtr().get()); EXPECT_THAT(actual_mask, SimilarToFloatMask(expected_mask_float, params.similarity_threshold)); } @@ -235,9 +253,9 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) { INSTANTIATE_TEST_SUITE_P( SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi, ::testing::ValuesIn( - {{"PointToDog1", RegionOfInterest::KEYPOINT, + {{"PointToDog1", RegionOfInterest::Format::kKeyPoint, NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f}, - {"PointToDog2", RegionOfInterest::KEYPOINT, + {"PointToDog2", RegionOfInterest::Format::kKeyPoint, NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}), [](const ::testing::TestParamInfo& @@ -252,22 +270,21 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); RegionOfInterest interaction_roi; - interaction_roi.format = RegionOfInterest::KEYPOINT; + interaction_roi.format = RegionOfInterest::Format::kKeyPoint; interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66}; auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kPtmModel); - options->output_type = - InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, InteractiveSegmenter::Create(std::move(options))); ImageProcessingOptions image_processing_options; image_processing_options.rotation_degrees = -90; MP_ASSERT_OK_AND_ASSIGN( - auto confidence_masks, + auto result, segmenter->Segment(image, interaction_roi, image_processing_options)); - EXPECT_EQ(confidence_masks.size(), 2); + EXPECT_FALSE(result.category_mask.has_value()); + EXPECT_EQ(result.confidence_masks->size(), 2); } TEST_F(ImageModeTest, FailsWithRegionOfInterest) { @@ -275,13 +292,11 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); RegionOfInterest interaction_roi; - interaction_roi.format = RegionOfInterest::KEYPOINT; + interaction_roi.format = RegionOfInterest::Format::kKeyPoint; interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66}; auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kPtmModel); - options->output_type = - InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, InteractiveSegmenter::Create(std::move(options)));