Support scribble input for Interactive Segmenter
PiperOrigin-RevId: 529156049
This commit is contained in:
parent
baa8fc68a1
commit
09662749ea
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
|
||||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||||
|
@ -60,6 +61,8 @@ constexpr absl::string_view kNormRectTag{"NORM_RECT"};
|
||||||
constexpr absl::string_view kSubgraphTypeName{
|
constexpr absl::string_view kSubgraphTypeName{
|
||||||
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};
|
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};
|
||||||
|
|
||||||
|
using components::containers::NormalizedKeypoint;
|
||||||
|
|
||||||
using ::mediapipe::CalculatorGraphConfig;
|
using ::mediapipe::CalculatorGraphConfig;
|
||||||
using ::mediapipe::Image;
|
using ::mediapipe::Image;
|
||||||
using ::mediapipe::NormalizedRect;
|
using ::mediapipe::NormalizedRect;
|
||||||
|
@ -115,7 +118,7 @@ absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
|
||||||
case RegionOfInterest::Format::kUnspecified:
|
case RegionOfInterest::Format::kUnspecified:
|
||||||
return absl::InvalidArgumentError(
|
return absl::InvalidArgumentError(
|
||||||
"RegionOfInterest format not specified");
|
"RegionOfInterest format not specified");
|
||||||
case RegionOfInterest::Format::kKeyPoint:
|
case RegionOfInterest::Format::kKeyPoint: {
|
||||||
RET_CHECK(roi.keypoint.has_value());
|
RET_CHECK(roi.keypoint.has_value());
|
||||||
auto* annotation = result.add_render_annotations();
|
auto* annotation = result.add_render_annotations();
|
||||||
annotation->mutable_color()->set_r(255);
|
annotation->mutable_color()->set_r(255);
|
||||||
|
@ -124,6 +127,19 @@ absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
|
||||||
point->set_x(roi.keypoint->x);
|
point->set_x(roi.keypoint->x);
|
||||||
point->set_y(roi.keypoint->y);
|
point->set_y(roi.keypoint->y);
|
||||||
return result;
|
return result;
|
||||||
|
}
|
||||||
|
case RegionOfInterest::Format::kScribble: {  // Scribble ROI: one annotation that collects every scribble point.
|
||||||
|
RET_CHECK(roi.scribble.has_value());  // `scribble` must be populated when the format is kScribble.
|
||||||
|
auto* annotation = result.add_render_annotations();
|
||||||
|
annotation->mutable_color()->set_r(255);  // Same red value (255) as the kKeyPoint case.
|
||||||
|
for (const NormalizedKeypoint& keypoint : *(roi.scribble)) {  // NOTE(review): mutable_scribble() is only reached inside this loop, so an empty scribble appears to leave the annotation with no data case set; AnnotationRenderer::RenderDataOnImage logs FATAL on unknown annotation types — confirm, and consider rejecting empty scribbles here.
|
||||||
|
auto* point = annotation->mutable_scribble()->add_point();  // Appends one point to this annotation's scribble.
|
||||||
|
point->set_normalized(true);  // x/y are taken from a NormalizedKeypoint, so flag them as normalized.
|
||||||
|
point->set_x(keypoint.x);
|
||||||
|
point->set_y(keypoint.y);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return absl::UnimplementedError("Unrecognized format");
|
return absl::UnimplementedError("Unrecognized format");
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,6 +53,7 @@ struct RegionOfInterest {
|
||||||
enum class Format {
|
enum class Format {
|
||||||
kUnspecified = 0, // Format not specified.
|
kUnspecified = 0, // Format not specified.
|
||||||
kKeyPoint = 1, // Using keypoint to represent ROI.
|
kKeyPoint = 1, // Using keypoint to represent ROI.
|
||||||
|
kScribble = 2, // Using scribble to represent ROI.
|
||||||
};
|
};
|
||||||
|
|
||||||
// Specifies the format used to specify the region-of-interest. Note that
|
// Specifies the format used to specify the region-of-interest. Note that
|
||||||
|
@ -61,8 +62,13 @@ struct RegionOfInterest {
|
||||||
Format format = Format::kUnspecified;
|
Format format = Format::kUnspecified;
|
||||||
|
|
||||||
// Represents the ROI in keypoint format, this should be non-nullopt if
|
// Represents the ROI in keypoint format, this should be non-nullopt if
|
||||||
// `format` is `KEYPOINT`.
|
// `format` is `kKeyPoint`.
|
||||||
std::optional<components::containers::NormalizedKeypoint> keypoint;
|
std::optional<components::containers::NormalizedKeypoint> keypoint;
|
||||||
|
|
||||||
|
// Represents the ROI in scribble format, this should be non-nullopt if
|
||||||
|
// `format` is `kScribble`.
|
||||||
|
std::optional<std::vector<components::containers::NormalizedKeypoint>>
|
||||||
|
scribble;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Performs interactive segmentation on images.
|
// Performs interactive segmentation on images.
|
||||||
|
|
|
@ -18,9 +18,12 @@ limitations under the License.
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include <variant>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/flags/flag.h"
|
#include "absl/flags/flag.h"
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "mediapipe/framework/deps/file_path.h"
|
#include "mediapipe/framework/deps/file_path.h"
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
|
@ -179,22 +182,46 @@ TEST_F(CreateFromOptionsTest, FailsWithNeitherOutputSet) {
|
||||||
struct InteractiveSegmenterTestParams {
|
struct InteractiveSegmenterTestParams {
|
||||||
std::string test_name;
|
std::string test_name;
|
||||||
RegionOfInterest::Format format;
|
RegionOfInterest::Format format;
|
||||||
NormalizedKeypoint roi;
|
std::variant<NormalizedKeypoint, std::vector<NormalizedKeypoint>> roi;
|
||||||
absl::string_view golden_mask_file;
|
absl::string_view golden_mask_file;
|
||||||
float similarity_threshold;
|
float similarity_threshold;
|
||||||
};
|
};
|
||||||
|
|
||||||
using SucceedSegmentationWithRoi =
|
class SucceedSegmentationWithRoi  // Parameterized fixture that can turn the current test param into a RegionOfInterest.
|
||||||
::testing::TestWithParam<InteractiveSegmenterTestParams>;
|
: public ::testing::TestWithParam<InteractiveSegmenterTestParams> {
|
||||||
|
public:
|
||||||
|
absl::StatusOr<RegionOfInterest> TestParamsToTaskOptions() {  // Builds the ROI for GetParam(); returns InvalidArgumentError for any format other than kKeyPoint/kScribble.
|
||||||
|
const InteractiveSegmenterTestParams& params = GetParam();
|
||||||
|
|
||||||
|
RegionOfInterest interaction_roi;
|
||||||
|
interaction_roi.format = params.format;
|
||||||
|
switch (params.format) {  // Copies the variant alternative that corresponds to the declared format.
|
||||||
|
case (RegionOfInterest::Format::kKeyPoint): {
|
||||||
|
interaction_roi.keypoint = std::get<NormalizedKeypoint>(params.roi);  // std::get throws std::bad_variant_access if the param holds the vector alternative instead.
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case (RegionOfInterest::Format::kScribble): {
|
||||||
|
interaction_roi.scribble =
|
||||||
|
std::get<std::vector<NormalizedKeypoint>>(params.roi);  // Likewise throws if the param holds a single keypoint.
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default: {  // Reached for kUnspecified or any other format value.
|
||||||
|
return absl::InvalidArgumentError("Unknown ROI format");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return interaction_roi;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
|
TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(RegionOfInterest interaction_roi,
|
||||||
|
TestParamsToTaskOptions());
|
||||||
const InteractiveSegmenterTestParams& params = GetParam();
|
const InteractiveSegmenterTestParams& params = GetParam();
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
Image image,
|
Image image,
|
||||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
|
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
|
||||||
RegionOfInterest interaction_roi;
|
|
||||||
interaction_roi.format = params.format;
|
|
||||||
interaction_roi.keypoint = params.roi;
|
|
||||||
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
||||||
options->base_options.model_asset_path =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kPtmModel);
|
JoinPath("./", kTestDataDirectory, kPtmModel);
|
||||||
|
@ -220,13 +247,13 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
|
TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
|
||||||
const auto& params = GetParam();
|
MP_ASSERT_OK_AND_ASSIGN(RegionOfInterest interaction_roi,
|
||||||
|
TestParamsToTaskOptions());
|
||||||
|
const InteractiveSegmenterTestParams& params = GetParam();
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
Image image,
|
Image image,
|
||||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
|
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
|
||||||
RegionOfInterest interaction_roi;
|
|
||||||
interaction_roi.format = params.format;
|
|
||||||
interaction_roi.keypoint = params.roi;
|
|
||||||
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
||||||
options->base_options.model_asset_path =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kPtmModel);
|
JoinPath("./", kTestDataDirectory, kPtmModel);
|
||||||
|
@ -253,11 +280,23 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi,
|
SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi,
|
||||||
::testing::ValuesIn<InteractiveSegmenterTestParams>(
|
::testing::ValuesIn<InteractiveSegmenterTestParams>(
|
||||||
{{"PointToDog1", RegionOfInterest::Format::kKeyPoint,
|
{// Keypoint input.
|
||||||
|
{"PointToDog1", RegionOfInterest::Format::kKeyPoint,
|
||||||
NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f},
|
NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f},
|
||||||
{"PointToDog2", RegionOfInterest::Format::kKeyPoint,
|
{"PointToDog2", RegionOfInterest::Format::kKeyPoint,
|
||||||
NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2,
|
NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2,
|
||||||
kGoldenMaskSimilarity}}),
|
kGoldenMaskSimilarity},
|
||||||
|
// Scribble input.
|
||||||
|
{"ScribbleToDog1", RegionOfInterest::Format::kScribble,
|
||||||
|
std::vector{NormalizedKeypoint{0.44, 0.70},
|
||||||
|
NormalizedKeypoint{0.44, 0.71},
|
||||||
|
NormalizedKeypoint{0.44, 0.72}},
|
||||||
|
kCatsAndDogsMaskDog1, 0.84f},
|
||||||
|
{"ScribbleToDog2", RegionOfInterest::Format::kScribble,
|
||||||
|
std::vector{NormalizedKeypoint{0.66, 0.66},
|
||||||
|
NormalizedKeypoint{0.66, 0.67},
|
||||||
|
NormalizedKeypoint{0.66, 0.68}},
|
||||||
|
kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}),
|
||||||
[](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
|
[](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
|
||||||
info) { return info.param.test_name; });
|
info) { return info.param.test_name; });
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "mediapipe/framework/port/logging.h"
|
#include "mediapipe/framework/port/logging.h"
|
||||||
#include "mediapipe/framework/port/vector.h"
|
#include "mediapipe/framework/port/vector.h"
|
||||||
#include "mediapipe/util/color.pb.h"
|
#include "mediapipe/util/color.pb.h"
|
||||||
|
#include "mediapipe/util/render_data.pb.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -112,6 +113,8 @@ void AnnotationRenderer::RenderDataOnImage(const RenderData& render_data) {
|
||||||
DrawGradientLine(annotation);
|
DrawGradientLine(annotation);
|
||||||
} else if (annotation.data_case() == RenderAnnotation::kArrow) {
|
} else if (annotation.data_case() == RenderAnnotation::kArrow) {
|
||||||
DrawArrow(annotation);
|
DrawArrow(annotation);
|
||||||
|
} else if (annotation.data_case() == RenderAnnotation::kScribble) {
|
||||||
|
DrawScribble(annotation);
|
||||||
} else {
|
} else {
|
||||||
LOG(FATAL) << "Unknown annotation type: " << annotation.data_case();
|
LOG(FATAL) << "Unknown annotation type: " << annotation.data_case();
|
||||||
}
|
}
|
||||||
|
@ -442,7 +445,11 @@ void AnnotationRenderer::DrawArrow(const RenderAnnotation& annotation) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void AnnotationRenderer::DrawPoint(const RenderAnnotation& annotation) {
|
void AnnotationRenderer::DrawPoint(const RenderAnnotation& annotation) {
|
||||||
const auto& point = annotation.point();
|
DrawPoint(annotation.point(), annotation);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AnnotationRenderer::DrawPoint(const RenderAnnotation::Point& point,
|
||||||
|
const RenderAnnotation& annotation) {
|
||||||
int x = -1;
|
int x = -1;
|
||||||
int y = -1;
|
int y = -1;
|
||||||
if (point.normalized()) {
|
if (point.normalized()) {
|
||||||
|
@ -460,6 +467,12 @@ void AnnotationRenderer::DrawPoint(const RenderAnnotation& annotation) {
|
||||||
cv::circle(mat_image_, point_to_draw, thickness, color, -1);
|
cv::circle(mat_image_, point_to_draw, thickness, color, -1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AnnotationRenderer::DrawScribble(const RenderAnnotation& annotation) {  // Renders a scribble annotation by drawing each of its points individually.
|
||||||
|
for (const RenderAnnotation::Point& point : annotation.scribble().point()) {  // Visits every point stored in the annotation's scribble.
|
||||||
|
DrawPoint(point, annotation);  // Two-argument overload: the point supplies the coordinates and the annotation is passed alongside it.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void AnnotationRenderer::DrawLine(const RenderAnnotation& annotation) {
|
void AnnotationRenderer::DrawLine(const RenderAnnotation& annotation) {
|
||||||
int x_start = -1;
|
int x_start = -1;
|
||||||
int y_start = -1;
|
int y_start = -1;
|
||||||
|
|
|
@ -96,6 +96,11 @@ class AnnotationRenderer {
|
||||||
|
|
||||||
// Draws a point on the image as described in the annotation.
|
// Draws a point on the image as described in the annotation.
|
||||||
void DrawPoint(const RenderAnnotation& annotation);
|
void DrawPoint(const RenderAnnotation& annotation);
|
||||||
|
void DrawPoint(const RenderAnnotation::Point& point,
|
||||||
|
const RenderAnnotation& annotation);
|
||||||
|
|
||||||
|
// Draws scribbles on the image as described in the annotation.
|
||||||
|
void DrawScribble(const RenderAnnotation& annotation);
|
||||||
|
|
||||||
// Draws a line segment on the image as described in the annotation.
|
// Draws a line segment on the image as described in the annotation.
|
||||||
void DrawLine(const RenderAnnotation& annotation);
|
void DrawLine(const RenderAnnotation& annotation);
|
||||||
|
|
|
@ -131,6 +131,10 @@ message RenderAnnotation {
|
||||||
optional Color color2 = 7;
|
optional Color color2 = 7;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message Scribble {  // A scribble, stored as a list of points.
|
||||||
|
repeated Point point = 1;  // The points making up the scribble; AnnotationRenderer::DrawScribble draws each one.
|
||||||
|
}
|
||||||
|
|
||||||
message Arrow {
|
message Arrow {
|
||||||
// The arrow head will be drawn at (x_end, y_end).
|
// The arrow head will be drawn at (x_end, y_end).
|
||||||
optional double x_start = 1;
|
optional double x_start = 1;
|
||||||
|
@ -192,6 +196,7 @@ message RenderAnnotation {
|
||||||
RoundedRectangle rounded_rectangle = 9;
|
RoundedRectangle rounded_rectangle = 9;
|
||||||
FilledRoundedRectangle filled_rounded_rectangle = 10;
|
FilledRoundedRectangle filled_rounded_rectangle = 10;
|
||||||
GradientLine gradient_line = 14;
|
GradientLine gradient_line = 14;
|
||||||
|
Scribble scribble = 15;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Thickness for drawing the annotation.
|
// Thickness for drawing the annotation.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user