Support scribble input for Interactive Segmenter
PiperOrigin-RevId: 529156049
This commit is contained in:
parent
baa8fc68a1
commit
09662749ea
|
@ -28,6 +28,7 @@ limitations under the License.
|
|||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||
|
@ -60,6 +61,8 @@ constexpr absl::string_view kNormRectTag{"NORM_RECT"};
|
|||
constexpr absl::string_view kSubgraphTypeName{
|
||||
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};
|
||||
|
||||
using components::containers::NormalizedKeypoint;
|
||||
|
||||
using ::mediapipe::CalculatorGraphConfig;
|
||||
using ::mediapipe::Image;
|
||||
using ::mediapipe::NormalizedRect;
|
||||
|
@ -115,7 +118,7 @@ absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
|
|||
case RegionOfInterest::Format::kUnspecified:
|
||||
return absl::InvalidArgumentError(
|
||||
"RegionOfInterest format not specified");
|
||||
case RegionOfInterest::Format::kKeyPoint:
|
||||
case RegionOfInterest::Format::kKeyPoint: {
|
||||
RET_CHECK(roi.keypoint.has_value());
|
||||
auto* annotation = result.add_render_annotations();
|
||||
annotation->mutable_color()->set_r(255);
|
||||
|
@ -124,6 +127,19 @@ absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
|
|||
point->set_x(roi.keypoint->x);
|
||||
point->set_y(roi.keypoint->y);
|
||||
return result;
|
||||
}
|
||||
case RegionOfInterest::Format::kScribble: {
|
||||
RET_CHECK(roi.scribble.has_value());
|
||||
auto* annotation = result.add_render_annotations();
|
||||
annotation->mutable_color()->set_r(255);
|
||||
for (const NormalizedKeypoint& keypoint : *(roi.scribble)) {
|
||||
auto* point = annotation->mutable_scribble()->add_point();
|
||||
point->set_normalized(true);
|
||||
point->set_x(keypoint.x);
|
||||
point->set_y(keypoint.y);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
return absl::UnimplementedError("Unrecognized format");
|
||||
}
|
||||
|
|
|
@ -53,6 +53,7 @@ struct RegionOfInterest {
|
|||
enum class Format {
|
||||
kUnspecified = 0, // Format not specified.
|
||||
kKeyPoint = 1, // Using keypoint to represent ROI.
|
||||
kScribble = 2, // Using scribble to represent ROI.
|
||||
};
|
||||
|
||||
// Specifies the format used to specify the region-of-interest. Note that
|
||||
|
@ -61,8 +62,13 @@ struct RegionOfInterest {
|
|||
Format format = Format::kUnspecified;
|
||||
|
||||
// Represents the ROI in keypoint format, this should be non-nullopt if
|
||||
// `format` is `KEYPOINT`.
|
||||
// `format` is `kKeyPoint`.
|
||||
std::optional<components::containers::NormalizedKeypoint> keypoint;
|
||||
|
||||
// Represents the ROI in scribble format, this should be non-nullopt if
|
||||
// `format` is `kScribble`.
|
||||
std::optional<std::vector<components::containers::NormalizedKeypoint>>
|
||||
scribble;
|
||||
};
|
||||
|
||||
// Performs interactive segmentation on images.
|
||||
|
|
|
@ -18,9 +18,12 @@ limitations under the License.
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
|
@ -179,22 +182,46 @@ TEST_F(CreateFromOptionsTest, FailsWithNeitherOutputSet) {
|
|||
struct InteractiveSegmenterTestParams {
|
||||
std::string test_name;
|
||||
RegionOfInterest::Format format;
|
||||
NormalizedKeypoint roi;
|
||||
std::variant<NormalizedKeypoint, std::vector<NormalizedKeypoint>> roi;
|
||||
absl::string_view golden_mask_file;
|
||||
float similarity_threshold;
|
||||
};
|
||||
|
||||
using SucceedSegmentationWithRoi =
|
||||
::testing::TestWithParam<InteractiveSegmenterTestParams>;
|
||||
class SucceedSegmentationWithRoi
|
||||
: public ::testing::TestWithParam<InteractiveSegmenterTestParams> {
|
||||
public:
|
||||
absl::StatusOr<RegionOfInterest> TestParamsToTaskOptions() {
|
||||
const InteractiveSegmenterTestParams& params = GetParam();
|
||||
|
||||
RegionOfInterest interaction_roi;
|
||||
interaction_roi.format = params.format;
|
||||
switch (params.format) {
|
||||
case (RegionOfInterest::Format::kKeyPoint): {
|
||||
interaction_roi.keypoint = std::get<NormalizedKeypoint>(params.roi);
|
||||
break;
|
||||
}
|
||||
case (RegionOfInterest::Format::kScribble): {
|
||||
interaction_roi.scribble =
|
||||
std::get<std::vector<NormalizedKeypoint>>(params.roi);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
return absl::InvalidArgumentError("Unknown ROI format");
|
||||
}
|
||||
}
|
||||
|
||||
return interaction_roi;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(RegionOfInterest interaction_roi,
|
||||
TestParamsToTaskOptions());
|
||||
const InteractiveSegmenterTestParams& params = GetParam();
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
Image image,
|
||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
|
||||
RegionOfInterest interaction_roi;
|
||||
interaction_roi.format = params.format;
|
||||
interaction_roi.keypoint = params.roi;
|
||||
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kPtmModel);
|
||||
|
@ -220,13 +247,13 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
|
|||
}
|
||||
|
||||
TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
|
||||
const auto& params = GetParam();
|
||||
MP_ASSERT_OK_AND_ASSIGN(RegionOfInterest interaction_roi,
|
||||
TestParamsToTaskOptions());
|
||||
const InteractiveSegmenterTestParams& params = GetParam();
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
Image image,
|
||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
|
||||
RegionOfInterest interaction_roi;
|
||||
interaction_roi.format = params.format;
|
||||
interaction_roi.keypoint = params.roi;
|
||||
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kPtmModel);
|
||||
|
@ -253,11 +280,23 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
|
|||
INSTANTIATE_TEST_SUITE_P(
|
||||
SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi,
|
||||
::testing::ValuesIn<InteractiveSegmenterTestParams>(
|
||||
{{"PointToDog1", RegionOfInterest::Format::kKeyPoint,
|
||||
{// Keypoint input.
|
||||
{"PointToDog1", RegionOfInterest::Format::kKeyPoint,
|
||||
NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f},
|
||||
{"PointToDog2", RegionOfInterest::Format::kKeyPoint,
|
||||
NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2,
|
||||
kGoldenMaskSimilarity}}),
|
||||
kGoldenMaskSimilarity},
|
||||
// Scribble input.
|
||||
{"ScribbleToDog1", RegionOfInterest::Format::kScribble,
|
||||
std::vector{NormalizedKeypoint{0.44, 0.70},
|
||||
NormalizedKeypoint{0.44, 0.71},
|
||||
NormalizedKeypoint{0.44, 0.72}},
|
||||
kCatsAndDogsMaskDog1, 0.84f},
|
||||
{"ScribbleToDog2", RegionOfInterest::Format::kScribble,
|
||||
std::vector{NormalizedKeypoint{0.66, 0.66},
|
||||
NormalizedKeypoint{0.66, 0.67},
|
||||
NormalizedKeypoint{0.66, 0.68}},
|
||||
kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}),
|
||||
[](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
|
||||
info) { return info.param.test_name; });
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/vector.h"
|
||||
#include "mediapipe/util/color.pb.h"
|
||||
#include "mediapipe/util/render_data.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
@ -112,6 +113,8 @@ void AnnotationRenderer::RenderDataOnImage(const RenderData& render_data) {
|
|||
DrawGradientLine(annotation);
|
||||
} else if (annotation.data_case() == RenderAnnotation::kArrow) {
|
||||
DrawArrow(annotation);
|
||||
} else if (annotation.data_case() == RenderAnnotation::kScribble) {
|
||||
DrawScribble(annotation);
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown annotation type: " << annotation.data_case();
|
||||
}
|
||||
|
@ -442,7 +445,11 @@ void AnnotationRenderer::DrawArrow(const RenderAnnotation& annotation) {
|
|||
}
|
||||
|
||||
void AnnotationRenderer::DrawPoint(const RenderAnnotation& annotation) {
|
||||
const auto& point = annotation.point();
|
||||
DrawPoint(annotation.point(), annotation);
|
||||
}
|
||||
|
||||
void AnnotationRenderer::DrawPoint(const RenderAnnotation::Point& point,
|
||||
const RenderAnnotation& annotation) {
|
||||
int x = -1;
|
||||
int y = -1;
|
||||
if (point.normalized()) {
|
||||
|
@ -460,6 +467,12 @@ void AnnotationRenderer::DrawPoint(const RenderAnnotation& annotation) {
|
|||
cv::circle(mat_image_, point_to_draw, thickness, color, -1);
|
||||
}
|
||||
|
||||
void AnnotationRenderer::DrawScribble(const RenderAnnotation& annotation) {
|
||||
for (const RenderAnnotation::Point& point : annotation.scribble().point()) {
|
||||
DrawPoint(point, annotation);
|
||||
}
|
||||
}
|
||||
|
||||
void AnnotationRenderer::DrawLine(const RenderAnnotation& annotation) {
|
||||
int x_start = -1;
|
||||
int y_start = -1;
|
||||
|
|
|
@ -96,6 +96,11 @@ class AnnotationRenderer {
|
|||
|
||||
// Draws a point on the image as described in the annotation.
|
||||
void DrawPoint(const RenderAnnotation& annotation);
|
||||
void DrawPoint(const RenderAnnotation::Point& point,
|
||||
const RenderAnnotation& annotation);
|
||||
|
||||
// Draws scribbles on the image as described in the annotation.
|
||||
void DrawScribble(const RenderAnnotation& annotation);
|
||||
|
||||
// Draws a line segment on the image as described in the annotation.
|
||||
void DrawLine(const RenderAnnotation& annotation);
|
||||
|
|
|
@ -131,6 +131,10 @@ message RenderAnnotation {
|
|||
optional Color color2 = 7;
|
||||
}
|
||||
|
||||
message Scribble {
|
||||
repeated Point point = 1;
|
||||
}
|
||||
|
||||
message Arrow {
|
||||
// The arrow head will be drawn at (x_end, y_end).
|
||||
optional double x_start = 1;
|
||||
|
@ -192,6 +196,7 @@ message RenderAnnotation {
|
|||
RoundedRectangle rounded_rectangle = 9;
|
||||
FilledRoundedRectangle filled_rounded_rectangle = 10;
|
||||
GradientLine gradient_line = 14;
|
||||
Scribble scribble = 15;
|
||||
}
|
||||
|
||||
// Thickness for drawing the annotation.
|
||||
|
|
Loading…
Reference in New Issue
Block a user