Support scribble input for Interactive Segmenter

PiperOrigin-RevId: 529156049
This commit is contained in:
MediaPipe Team 2023-05-03 11:58:26 -07:00 committed by Copybara-Service
parent baa8fc68a1
commit 09662749ea
6 changed files with 99 additions and 15 deletions

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
@ -60,6 +61,8 @@ constexpr absl::string_view kNormRectTag{"NORM_RECT"};
constexpr absl::string_view kSubgraphTypeName{ constexpr absl::string_view kSubgraphTypeName{
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"}; "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};
using components::containers::NormalizedKeypoint;
using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::Image; using ::mediapipe::Image;
using ::mediapipe::NormalizedRect; using ::mediapipe::NormalizedRect;
@ -115,7 +118,7 @@ absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
case RegionOfInterest::Format::kUnspecified: case RegionOfInterest::Format::kUnspecified:
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
"RegionOfInterest format not specified"); "RegionOfInterest format not specified");
case RegionOfInterest::Format::kKeyPoint: case RegionOfInterest::Format::kKeyPoint: {
RET_CHECK(roi.keypoint.has_value()); RET_CHECK(roi.keypoint.has_value());
auto* annotation = result.add_render_annotations(); auto* annotation = result.add_render_annotations();
annotation->mutable_color()->set_r(255); annotation->mutable_color()->set_r(255);
@ -124,6 +127,19 @@ absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
point->set_x(roi.keypoint->x); point->set_x(roi.keypoint->x);
point->set_y(roi.keypoint->y); point->set_y(roi.keypoint->y);
return result; return result;
}
case RegionOfInterest::Format::kScribble: {
  // Scribble ROI: emit one render annotation (color r=255) whose scribble
  // holds every keypoint of the ROI as a normalized point.
  RET_CHECK(roi.scribble.has_value());
  auto* annotation = result.add_render_annotations();
  annotation->mutable_color()->set_r(255);
  // Materialize the scribble field before iterating so the annotation's data
  // case is set even when the scribble contains no points. If it were only
  // touched inside the loop, an empty scribble would leave the annotation
  // without any data, and AnnotationRenderer::RenderDataOnImage would hit
  // its LOG(FATAL) "Unknown annotation type" branch.
  auto* scribble = annotation->mutable_scribble();
  for (const NormalizedKeypoint& keypoint : *(roi.scribble)) {
    auto* point = scribble->add_point();
    point->set_normalized(true);
    point->set_x(keypoint.x);
    point->set_y(keypoint.y);
  }
  return result;
}
} }
return absl::UnimplementedError("Unrecognized format"); return absl::UnimplementedError("Unrecognized format");
} }

View File

@ -53,6 +53,7 @@ struct RegionOfInterest {
enum class Format { enum class Format {
kUnspecified = 0, // Format not specified. kUnspecified = 0, // Format not specified.
kKeyPoint = 1, // Using keypoint to represent ROI. kKeyPoint = 1, // Using keypoint to represent ROI.
kScribble = 2, // Using scribble to represent ROI.
}; };
// Specifies the format used to specify the region-of-interest. Note that // Specifies the format used to specify the region-of-interest. Note that
@ -61,8 +62,13 @@ struct RegionOfInterest {
Format format = Format::kUnspecified; Format format = Format::kUnspecified;
// Represents the ROI in keypoint format, this should be non-nullopt if // Represents the ROI in keypoint format, this should be non-nullopt if
// `format` is `KEYPOINT`. // `format` is `kKeyPoint`.
std::optional<components::containers::NormalizedKeypoint> keypoint; std::optional<components::containers::NormalizedKeypoint> keypoint;
// Represents the ROI in scribble format, this should be non-nullopt if
// `format` is `kScribble`.
std::optional<std::vector<components::containers::NormalizedKeypoint>>
scribble;
}; };
// Performs interactive segmentation on images. // Performs interactive segmentation on images.

View File

@ -18,9 +18,12 @@ limitations under the License.
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
#include <variant>
#include <vector>
#include "absl/flags/flag.h" #include "absl/flags/flag.h"
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
@ -179,22 +182,46 @@ TEST_F(CreateFromOptionsTest, FailsWithNeitherOutputSet) {
struct InteractiveSegmenterTestParams { struct InteractiveSegmenterTestParams {
std::string test_name; std::string test_name;
RegionOfInterest::Format format; RegionOfInterest::Format format;
NormalizedKeypoint roi; std::variant<NormalizedKeypoint, std::vector<NormalizedKeypoint>> roi;
absl::string_view golden_mask_file; absl::string_view golden_mask_file;
float similarity_threshold; float similarity_threshold;
}; };
using SucceedSegmentationWithRoi = class SucceedSegmentationWithRoi
::testing::TestWithParam<InteractiveSegmenterTestParams>; : public ::testing::TestWithParam<InteractiveSegmenterTestParams> {
public:
// Builds the RegionOfInterest described by the current test parameter.
//
// The ROI's `format` is copied from the parameter, and the matching field
// (`keypoint` or `scribble`) is filled from the `roi` variant.
//
// Returns InvalidArgumentError when the format is neither kKeyPoint nor
// kScribble, or when the `roi` variant does not hold the alternative that the
// format requires. The mismatch is reported as a status rather than letting
// std::get throw std::bad_variant_access, so a bad entry in the parameter
// table fails the test through MP_ASSERT_OK_AND_ASSIGN.
absl::StatusOr<RegionOfInterest> TestParamsToTaskOptions() {
  const InteractiveSegmenterTestParams& params = GetParam();
  RegionOfInterest interaction_roi;
  interaction_roi.format = params.format;
  switch (params.format) {
    case (RegionOfInterest::Format::kKeyPoint): {
      const auto* keypoint = std::get_if<NormalizedKeypoint>(&params.roi);
      if (keypoint == nullptr) {
        return absl::InvalidArgumentError(
            "ROI format is kKeyPoint but the test ROI is not a keypoint");
      }
      interaction_roi.keypoint = *keypoint;
      break;
    }
    case (RegionOfInterest::Format::kScribble): {
      const auto* scribble =
          std::get_if<std::vector<NormalizedKeypoint>>(&params.roi);
      if (scribble == nullptr) {
        return absl::InvalidArgumentError(
            "ROI format is kScribble but the test ROI is not a scribble");
      }
      interaction_roi.scribble = *scribble;
      break;
    }
    default: {
      return absl::InvalidArgumentError("Unknown ROI format");
    }
  }
  return interaction_roi;
}
};
TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) { TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
MP_ASSERT_OK_AND_ASSIGN(RegionOfInterest interaction_roi,
TestParamsToTaskOptions());
const InteractiveSegmenterTestParams& params = GetParam(); const InteractiveSegmenterTestParams& params = GetParam();
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
Image image, Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
RegionOfInterest interaction_roi;
interaction_roi.format = params.format;
interaction_roi.keypoint = params.roi;
auto options = std::make_unique<InteractiveSegmenterOptions>(); auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel); JoinPath("./", kTestDataDirectory, kPtmModel);
@ -220,13 +247,13 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
} }
TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) { TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
const auto& params = GetParam(); MP_ASSERT_OK_AND_ASSIGN(RegionOfInterest interaction_roi,
TestParamsToTaskOptions());
const InteractiveSegmenterTestParams& params = GetParam();
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
Image image, Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
RegionOfInterest interaction_roi;
interaction_roi.format = params.format;
interaction_roi.keypoint = params.roi;
auto options = std::make_unique<InteractiveSegmenterOptions>(); auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel); JoinPath("./", kTestDataDirectory, kPtmModel);
@ -253,11 +280,23 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi, SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi,
::testing::ValuesIn<InteractiveSegmenterTestParams>( ::testing::ValuesIn<InteractiveSegmenterTestParams>(
{{"PointToDog1", RegionOfInterest::Format::kKeyPoint, {// Keypoint input.
{"PointToDog1", RegionOfInterest::Format::kKeyPoint,
NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f}, NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f},
{"PointToDog2", RegionOfInterest::Format::kKeyPoint, {"PointToDog2", RegionOfInterest::Format::kKeyPoint,
NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2, NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2,
kGoldenMaskSimilarity}}), kGoldenMaskSimilarity},
// Scribble input.
{"ScribbleToDog1", RegionOfInterest::Format::kScribble,
std::vector{NormalizedKeypoint{0.44, 0.70},
NormalizedKeypoint{0.44, 0.71},
NormalizedKeypoint{0.44, 0.72}},
kCatsAndDogsMaskDog1, 0.84f},
{"ScribbleToDog2", RegionOfInterest::Format::kScribble,
std::vector{NormalizedKeypoint{0.66, 0.66},
NormalizedKeypoint{0.66, 0.67},
NormalizedKeypoint{0.66, 0.68}},
kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}),
[](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>& [](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
info) { return info.param.test_name; }); info) { return info.param.test_name; });

View File

@ -22,6 +22,7 @@
#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/vector.h" #include "mediapipe/framework/port/vector.h"
#include "mediapipe/util/color.pb.h" #include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"
namespace mediapipe { namespace mediapipe {
namespace { namespace {
@ -112,6 +113,8 @@ void AnnotationRenderer::RenderDataOnImage(const RenderData& render_data) {
DrawGradientLine(annotation); DrawGradientLine(annotation);
} else if (annotation.data_case() == RenderAnnotation::kArrow) { } else if (annotation.data_case() == RenderAnnotation::kArrow) {
DrawArrow(annotation); DrawArrow(annotation);
} else if (annotation.data_case() == RenderAnnotation::kScribble) {
DrawScribble(annotation);
} else { } else {
LOG(FATAL) << "Unknown annotation type: " << annotation.data_case(); LOG(FATAL) << "Unknown annotation type: " << annotation.data_case();
} }
@ -442,7 +445,11 @@ void AnnotationRenderer::DrawArrow(const RenderAnnotation& annotation) {
} }
void AnnotationRenderer::DrawPoint(const RenderAnnotation& annotation) { void AnnotationRenderer::DrawPoint(const RenderAnnotation& annotation) {
const auto& point = annotation.point(); DrawPoint(annotation.point(), annotation);
}
void AnnotationRenderer::DrawPoint(const RenderAnnotation::Point& point,
const RenderAnnotation& annotation) {
int x = -1; int x = -1;
int y = -1; int y = -1;
if (point.normalized()) { if (point.normalized()) {
@ -460,6 +467,12 @@ void AnnotationRenderer::DrawPoint(const RenderAnnotation& annotation) {
cv::circle(mat_image_, point_to_draw, thickness, color, -1); cv::circle(mat_image_, point_to_draw, thickness, color, -1);
} }
void AnnotationRenderer::DrawScribble(const RenderAnnotation& annotation) {
for (const RenderAnnotation::Point& point : annotation.scribble().point()) {
DrawPoint(point, annotation);
}
}
void AnnotationRenderer::DrawLine(const RenderAnnotation& annotation) { void AnnotationRenderer::DrawLine(const RenderAnnotation& annotation) {
int x_start = -1; int x_start = -1;
int y_start = -1; int y_start = -1;

View File

@ -96,6 +96,11 @@ class AnnotationRenderer {
// Draws a point on the image as described in the annotation. // Draws a point on the image as described in the annotation.
void DrawPoint(const RenderAnnotation& annotation); void DrawPoint(const RenderAnnotation& annotation);
// Draws the given `point` on the image. Used by DrawScribble to draw each
// point of a scribble, and by the single-argument DrawPoint for
// `annotation.point()`.
// NOTE(review): color and thickness are presumably taken from `annotation` —
// confirm against the definition.
void DrawPoint(const RenderAnnotation::Point& point,
const RenderAnnotation& annotation);
// Draws scribbles on the image as described in the annotation.
void DrawScribble(const RenderAnnotation& annotation);
// Draws a line segment on the image as described in the annotation. // Draws a line segment on the image as described in the annotation.
void DrawLine(const RenderAnnotation& annotation); void DrawLine(const RenderAnnotation& annotation);

View File

@ -131,6 +131,10 @@ message RenderAnnotation {
optional Color color2 = 7; optional Color color2 = 7;
} }
// A scribble, represented as a list of points.
message Scribble {
// The points that make up the scribble.
repeated Point point = 1;
}
message Arrow { message Arrow {
// The arrow head will be drawn at (x_end, y_end). // The arrow head will be drawn at (x_end, y_end).
optional double x_start = 1; optional double x_start = 1;
@ -192,6 +196,7 @@ message RenderAnnotation {
RoundedRectangle rounded_rectangle = 9; RoundedRectangle rounded_rectangle = 9;
FilledRoundedRectangle filled_rounded_rectangle = 10; FilledRoundedRectangle filled_rounded_rectangle = 10;
GradientLine gradient_line = 14; GradientLine gradient_line = 14;
Scribble scribble = 15;
} }
// Thickness for drawing the annotation. // Thickness for drawing the annotation.