From 09662749ea433f1d5b8b4b9a9b86c341a1574658 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 3 May 2023 11:58:26 -0700 Subject: [PATCH] Support scribble input for Interactive Segmenter PiperOrigin-RevId: 529156049 --- .../interactive_segmenter.cc | 18 +++++- .../interactive_segmenter.h | 8 ++- .../interactive_segmenter_test.cc | 63 +++++++++++++++---- mediapipe/util/annotation_renderer.cc | 15 ++++- mediapipe/util/annotation_renderer.h | 5 ++ mediapipe/util/render_data.proto | 5 ++ 6 files changed, 99 insertions(+), 15 deletions(-) diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc index af2a3f50c..c0d89c87d 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/keypoint.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" @@ -60,6 +61,8 @@ constexpr absl::string_view kNormRectTag{"NORM_RECT"}; constexpr absl::string_view kSubgraphTypeName{ "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"}; +using components::containers::NormalizedKeypoint; + using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::Image; using ::mediapipe::NormalizedRect; @@ -115,7 +118,7 @@ absl::StatusOr ConvertRoiToRenderData(const RegionOfInterest& roi) { case RegionOfInterest::Format::kUnspecified: return absl::InvalidArgumentError( "RegionOfInterest format not specified"); - case RegionOfInterest::Format::kKeyPoint: + case RegionOfInterest::Format::kKeyPoint: { RET_CHECK(roi.keypoint.has_value()); auto* annotation = result.add_render_annotations(); annotation->mutable_color()->set_r(255); @@ -124,6 +127,19 @@ absl::StatusOr ConvertRoiToRenderData(const RegionOfInterest& roi) { point->set_x(roi.keypoint->x); point->set_y(roi.keypoint->y); return result; + } + case RegionOfInterest::Format::kScribble: { + RET_CHECK(roi.scribble.has_value()); + auto* annotation = result.add_render_annotations(); + annotation->mutable_color()->set_r(255); + for (const NormalizedKeypoint& keypoint : *(roi.scribble)) { + auto* point = annotation->mutable_scribble()->add_point(); + point->set_normalized(true); + point->set_x(keypoint.x); + point->set_y(keypoint.y); + } + return result; + } } return absl::UnimplementedError("Unrecognized format"); } diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h index ad4a238df..ad8a558df 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h @@ -53,6 +53,7 @@ struct RegionOfInterest { enum class Format { kUnspecified = 0, // Format not specified. kKeyPoint = 1, // Using keypoint to represent ROI. + kScribble = 2, // Using scribble to represent ROI. }; // Specifies the format used to specify the region-of-interest. Note that @@ -61,8 +62,13 @@ struct RegionOfInterest { Format format = Format::kUnspecified; // Represents the ROI in keypoint format, this should be non-nullopt if - // `format` is `KEYPOINT`. + // `format` is `kKeyPoint`. std::optional keypoint; + + // Represents the ROI in scribble format, this should be non-nullopt if + // `format` is `kScribble`. + std::optional> + scribble; }; // Performs interactive segmentation on images. diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc index 443247aea..16d065f61 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc @@ -18,9 +18,12 @@ limitations under the License. #include #include #include +#include +#include #include "absl/flags/flag.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/image.h" @@ -179,22 +182,46 @@ TEST_F(CreateFromOptionsTest, FailsWithNeitherOutputSet) { struct InteractiveSegmenterTestParams { std::string test_name; RegionOfInterest::Format format; - NormalizedKeypoint roi; + std::variant> roi; absl::string_view golden_mask_file; float similarity_threshold; }; -using SucceedSegmentationWithRoi = - ::testing::TestWithParam; +class SucceedSegmentationWithRoi + : public ::testing::TestWithParam { + public: + absl::StatusOr TestParamsToTaskOptions() { + const InteractiveSegmenterTestParams& params = GetParam(); + + RegionOfInterest interaction_roi; + interaction_roi.format = params.format; + switch (params.format) { + case (RegionOfInterest::Format::kKeyPoint): { + interaction_roi.keypoint = std::get(params.roi); + break; + } + case (RegionOfInterest::Format::kScribble): { + interaction_roi.scribble = + std::get>(params.roi); + break; + } + default: { + return absl::InvalidArgumentError("Unknown ROI format"); + } + } + + return interaction_roi; + } +}; TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) { + MP_ASSERT_OK_AND_ASSIGN(RegionOfInterest interaction_roi, + TestParamsToTaskOptions()); const InteractiveSegmenterTestParams& params = GetParam(); + MP_ASSERT_OK_AND_ASSIGN( Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); - RegionOfInterest interaction_roi; - interaction_roi.format = params.format; - interaction_roi.keypoint = params.roi; auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kPtmModel); @@ -220,13 +247,13 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) { } TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) { - const auto& params = GetParam(); + MP_ASSERT_OK_AND_ASSIGN(RegionOfInterest interaction_roi, + TestParamsToTaskOptions()); + const InteractiveSegmenterTestParams& params = GetParam(); + MP_ASSERT_OK_AND_ASSIGN( Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); - RegionOfInterest interaction_roi; - interaction_roi.format = params.format; - interaction_roi.keypoint = params.roi; auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kPtmModel); @@ -253,11 +280,23 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) { INSTANTIATE_TEST_SUITE_P( SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi, ::testing::ValuesIn( - {{"PointToDog1", RegionOfInterest::Format::kKeyPoint, + {// Keypoint input. + {"PointToDog1", RegionOfInterest::Format::kKeyPoint, NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f}, {"PointToDog2", RegionOfInterest::Format::kKeyPoint, NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2, - kGoldenMaskSimilarity}}), + kGoldenMaskSimilarity}, + // Scribble input. + {"ScribbleToDog1", RegionOfInterest::Format::kScribble, + std::vector{NormalizedKeypoint{0.44, 0.70}, + NormalizedKeypoint{0.44, 0.71}, + NormalizedKeypoint{0.44, 0.72}}, + kCatsAndDogsMaskDog1, 0.84f}, + {"ScribbleToDog2", RegionOfInterest::Format::kScribble, + std::vector{NormalizedKeypoint{0.66, 0.66}, + NormalizedKeypoint{0.66, 0.67}, + NormalizedKeypoint{0.66, 0.68}}, + kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}), [](const ::testing::TestParamInfo& info) { return info.param.test_name; }); diff --git a/mediapipe/util/annotation_renderer.cc b/mediapipe/util/annotation_renderer.cc index 5188da896..d8516f9bc 100644 --- a/mediapipe/util/annotation_renderer.cc +++ b/mediapipe/util/annotation_renderer.cc @@ -22,6 +22,7 @@ #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/vector.h" #include "mediapipe/util/color.pb.h" +#include "mediapipe/util/render_data.pb.h" namespace mediapipe { namespace { @@ -112,6 +113,8 @@ void AnnotationRenderer::RenderDataOnImage(const RenderData& render_data) { DrawGradientLine(annotation); } else if (annotation.data_case() == RenderAnnotation::kArrow) { DrawArrow(annotation); + } else if (annotation.data_case() == RenderAnnotation::kScribble) { + DrawScribble(annotation); } else { LOG(FATAL) << "Unknown annotation type: " << annotation.data_case(); } @@ -442,7 +445,11 @@ void AnnotationRenderer::DrawArrow(const RenderAnnotation& annotation) { } void AnnotationRenderer::DrawPoint(const RenderAnnotation& annotation) { - const auto& point = annotation.point(); + DrawPoint(annotation.point(), annotation); +} + +void AnnotationRenderer::DrawPoint(const RenderAnnotation::Point& point, + const RenderAnnotation& annotation) { int x = -1; int y = -1; if (point.normalized()) { @@ -460,6 +467,12 @@ void AnnotationRenderer::DrawPoint(const RenderAnnotation& annotation) { cv::circle(mat_image_, point_to_draw, thickness, color, -1); } +void AnnotationRenderer::DrawScribble(const RenderAnnotation& annotation) { + for (const RenderAnnotation::Point& point : annotation.scribble().point()) { + DrawPoint(point, annotation); + } +} + void AnnotationRenderer::DrawLine(const RenderAnnotation& annotation) { int x_start = -1; int y_start = -1; diff --git a/mediapipe/util/annotation_renderer.h b/mediapipe/util/annotation_renderer.h index 380bc3614..ae0cf976e 100644 --- a/mediapipe/util/annotation_renderer.h +++ b/mediapipe/util/annotation_renderer.h @@ -96,6 +96,11 @@ class AnnotationRenderer { // Draws a point on the image as described in the annotation. void DrawPoint(const RenderAnnotation& annotation); + void DrawPoint(const RenderAnnotation::Point& point, + const RenderAnnotation& annotation); + + // Draws scribbles on the image as described in the annotation. + void DrawScribble(const RenderAnnotation& annotation); // Draws a line segment on the image as described in the annotation. void DrawLine(const RenderAnnotation& annotation); diff --git a/mediapipe/util/render_data.proto b/mediapipe/util/render_data.proto index fee02fff3..897d5fa37 100644 --- a/mediapipe/util/render_data.proto +++ b/mediapipe/util/render_data.proto @@ -131,6 +131,10 @@ message RenderAnnotation { optional Color color2 = 7; } + message Scribble { + repeated Point point = 1; + } + message Arrow { // The arrow head will be drawn at (x_end, y_end). optional double x_start = 1; @@ -192,6 +196,7 @@ message RenderAnnotation { RoundedRectangle rounded_rectangle = 9; FilledRoundedRectangle filled_rounded_rectangle = 10; GradientLine gradient_line = 14; + Scribble scribble = 15; } // Thickness for drawing the annotation.