Annotate in model input scale for InteractiveSegmenter

PiperOrigin-RevId: 539245617
MediaPipe Team 2023-06-09 21:10:56 -07:00 committed by Copybara-Service
parent 1d4a205c2e
commit ac4f60a793
5 changed files with 136 additions and 13 deletions


@@ -52,6 +52,7 @@ cc_library(
     name = "interactive_segmenter_graph",
     srcs = ["interactive_segmenter_graph.cc"],
     deps = [
+        "//mediapipe/calculators/image:image_transformation_calculator",
         "//mediapipe/calculators/image:set_alpha_calculator",
         "//mediapipe/calculators/util:annotation_overlay_calculator",
         "//mediapipe/calculators/util:flat_color_image_calculator",
@@ -60,6 +61,7 @@ cc_library(
         "//mediapipe/calculators/util:to_image_calculator",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/api2:builder",
+        "//mediapipe/framework/api2:node",
         "//mediapipe/framework/api2:port",
         "//mediapipe/framework/formats:image",
         "//mediapipe/framework/formats:rect_cc_proto",


@@ -13,12 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <memory>
 #include <vector>
 
 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
 #include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
 #include "mediapipe/framework/api2/builder.h"
+#include "mediapipe/framework/api2/node.h"
 #include "mediapipe/framework/api2/port.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/formats/image.h"
@@ -35,6 +37,51 @@ namespace mediapipe {
 namespace tasks {
 namespace vision {
 namespace interactive_segmenter {
+namespace internal {
+
+// A calculator to add thickness to the render data according to the image
+// size, so that the render data is scale invariant to the image size. If the
+// render data already has thickness, it will be kept as is.
+class AddThicknessToRenderDataCalculator : public api2::Node {
+ public:
+  static constexpr api2::Input<Image> kImageIn{"IMAGE"};
+  static constexpr api2::Input<mediapipe::RenderData> kRenderDataIn{
+      "RENDER_DATA"};
+  static constexpr api2::Output<mediapipe::RenderData> kRenderDataOut{
+      "RENDER_DATA"};
+
+  static constexpr int kModelInputTensorWidth = 512;
+  static constexpr int kModelInputTensorHeight = 512;
+
+  MEDIAPIPE_NODE_CONTRACT(kImageIn, kRenderDataIn, kRenderDataOut);
+
+  absl::Status Process(CalculatorContext* cc) final {
+    mediapipe::RenderData render_data = kRenderDataIn(cc).Get();
+    Image image = kImageIn(cc).Get();
+    double thickness = std::max(
+        std::max(image.width() / static_cast<double>(kModelInputTensorWidth),
+                 image.height() / static_cast<double>(kModelInputTensorHeight)),
+        1.0);
+
+    for (auto& annotation : *render_data.mutable_render_annotations()) {
+      if (!annotation.has_thickness()) {
+        annotation.set_thickness(thickness);
+      }
+    }
+    kRenderDataOut(cc).Send(render_data);
+    return absl::OkStatus();
+  }
+};
+
+// NOLINTBEGIN: Node registration doesn't work when part of calculator name is
+// moved to next line.
+// clang-format off
+MEDIAPIPE_REGISTER_NODE(
+    ::mediapipe::tasks::vision::interactive_segmenter::internal::AddThicknessToRenderDataCalculator);
+// clang-format on
+// NOLINTEND
+
+}  // namespace internal
+
 namespace {
@@ -59,6 +106,7 @@ constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"};
 constexpr absl::string_view kNormRectTag{"NORM_RECT"};
 constexpr absl::string_view kRoiTag{"ROI"};
 constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
+constexpr absl::string_view kRenderDataTag{"RENDER_DATA"};
 
 // Updates the graph to return `roi` stream which has same dimension as
 // `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is
@@ -69,14 +117,23 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
   const absl::string_view image_tag_with_suffix =
       use_gpu ? kImageGpuTag : kImageCpuTag;
 
+  // Adds thickness to the render data so that the render data is scale
+  // invariant to the input image size.
+  auto& add_thickness = graph.AddNode(
+      "mediapipe::tasks::vision::interactive_segmenter::internal::"
+      "AddThicknessToRenderDataCalculator");
+  image >> add_thickness.In(kImageTag);
+  roi >> add_thickness.In(kRenderDataTag);
+  auto roi_with_thickness = add_thickness.Out(kRenderDataTag);
+
   // Generates a blank canvas with same size as input image.
   auto& flat_color = graph.AddNode("FlatColorImageCalculator");
   auto& flat_color_options =
       flat_color.GetOptions<FlatColorImageCalculatorOptions>();
   // SetAlphaCalculator only takes 1st channel.
   flat_color_options.mutable_color()->set_r(0);
-  image >> flat_color.In(kImageTag)[0];
-  auto blank_canvas = flat_color.Out(kImageTag)[0];
+  image >> flat_color.In(kImageTag);
+  auto blank_canvas = flat_color.Out(kImageTag);
 
   auto& from_mp_image = graph.AddNode("FromImageCalculator");
   blank_canvas >> from_mp_image.In(kImageTag);
@@ -85,7 +142,7 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
   auto& roi_to_alpha = graph.AddNode("AnnotationOverlayCalculator");
   blank_canvas_in_cpu_or_gpu >>
       roi_to_alpha.In(use_gpu ? kImageGpuTag : kImageTag);
-  roi >> roi_to_alpha.In(0);
+  roi_with_thickness >> roi_to_alpha.In(0);
   auto alpha = roi_to_alpha.Out(use_gpu ? kImageGpuTag : kImageTag);
 
   return alpha;
@@ -163,6 +220,7 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph {
     image >> from_mp_image.In(kImageTag);
     auto image_in_cpu_or_gpu = from_mp_image.Out(image_tag_with_suffix);
 
+    // Creates an RGBA image with model input tensor size.
     auto alpha_in_cpu_or_gpu = RoiToAlpha(image, roi, use_gpu, graph);
 
     auto& set_alpha = graph.AddNode("SetAlphaCalculator");
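Note on the graph change above: AddThicknessToRenderDataCalculator is what makes the ROI annotation scale invariant; the remaining edits simply route the thickened render data into AnnotationOverlayCalculator. A minimal standalone sketch of the same computation is below. ThicknessForImage and the example image sizes are illustrative only; the 512x512 constants and the max(..., 1.0) formula come from the calculator added in this commit.

#include <algorithm>

// Mirrors the constants added to AddThicknessToRenderDataCalculator.
constexpr int kModelInputTensorWidth = 512;
constexpr int kModelInputTensorHeight = 512;

// Thickness grows with the input image so that, after the image is resized to
// the model input tensor, the rendered keypoint or scribble still covers
// roughly the same number of tensor pixels. It never drops below 1 pixel.
double ThicknessForImage(int image_width, int image_height) {
  return std::max(
      std::max(image_width / static_cast<double>(kModelInputTensorWidth),
               image_height / static_cast<double>(kModelInputTensorHeight)),
      1.0);
}

// Example: ThicknessForImage(2048, 1536) == 4.0 and ThicknessForImage(400, 300)
// == 1.0, which is the behavior the new PenguinsSmall/PenguinsLarge test cases
// exercise with the same normalized keypoint on two image resolutions.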


@@ -34,6 +34,7 @@ limitations under the License.
 #include "mediapipe/framework/port/opencv_core_inc.h"
 #include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
 #include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/framework/tool/test_util.h"
 #include "mediapipe/tasks/cc/components/containers/keypoint.h"
 #include "mediapipe/tasks/cc/components/containers/rect.h"
 #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
@@ -70,6 +71,10 @@ constexpr absl::string_view kCatsAndDogsJpg{"cats_and_dogs.jpg"};
 // Golden mask for the dogs in cats_and_dogs.jpg.
 constexpr absl::string_view kCatsAndDogsMaskDog1{"cats_and_dogs_mask_dog1.png"};
 constexpr absl::string_view kCatsAndDogsMaskDog2{"cats_and_dogs_mask_dog2.png"};
+constexpr absl::string_view kPenguinsLarge{"penguins_large.jpg"};
+constexpr absl::string_view kPenguinsSmall{"penguins_small.jpg"};
+constexpr absl::string_view kPenguinsSmallMask{"penguins_small_mask.png"};
+constexpr absl::string_view kPenguinsLargeMask{"penguins_large_mask.png"};
 
 constexpr float kGoldenMaskSimilarity = 0.97;
@@ -183,6 +188,7 @@ struct InteractiveSegmenterTestParams {
   std::string test_name;
   RegionOfInterest::Format format;
   std::variant<NormalizedKeypoint, std::vector<NormalizedKeypoint>> roi;
+  absl::string_view input_image_file;
   absl::string_view golden_mask_file;
   float similarity_threshold;
 };
@@ -220,8 +226,8 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
   const InteractiveSegmenterTestParams& params = GetParam();
 
   MP_ASSERT_OK_AND_ASSIGN(
-      Image image,
-      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
+      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
+                                                params.input_image_file)));
   auto options = std::make_unique<InteractiveSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kPtmModel);
@@ -244,6 +250,15 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
   EXPECT_THAT(actual_mask,
               SimilarToUint8Mask(expected_mask, params.similarity_threshold,
                                  kGoldenMaskMagnificationFactor));
+
+  cv::Mat visualized_mask;
+  actual_mask.convertTo(visualized_mask, CV_8UC1, /*alpha=*/255);
+  ImageFrame visualized_image(mediapipe::ImageFormat::GRAY8,
+                              visualized_mask.cols, visualized_mask.rows,
+                              visualized_mask.step, visualized_mask.data,
+                              [visualized_mask](uint8_t[]) {});
+  MP_EXPECT_OK(SavePngTestOutput(
+      visualized_image, absl::StrFormat("%s_category_mask", params.test_name)));
 }
 
 TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
@@ -252,8 +267,8 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
   const InteractiveSegmenterTestParams& params = GetParam();
 
   MP_ASSERT_OK_AND_ASSIGN(
-      Image image,
-      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
+      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
+                                                params.input_image_file)));
   auto options = std::make_unique<InteractiveSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kPtmModel);
@@ -275,6 +290,15 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
       result.confidence_masks->at(1).GetImageFrameSharedPtr().get());
   EXPECT_THAT(actual_mask, SimilarToFloatMask(expected_mask_float,
                                               params.similarity_threshold));
+
+  cv::Mat visualized_mask;
+  actual_mask.convertTo(visualized_mask, CV_8UC1, /*alpha=*/255);
+  ImageFrame visualized_image(mediapipe::ImageFormat::GRAY8,
+                              visualized_mask.cols, visualized_mask.rows,
+                              visualized_mask.step, visualized_mask.data,
+                              [visualized_mask](uint8_t[]) {});
+  MP_EXPECT_OK(SavePngTestOutput(
+      visualized_image,
+      absl::StrFormat("%s_confidence_mask", params.test_name)));
 }
 
 INSTANTIATE_TEST_SUITE_P(
@@ -282,21 +306,28 @@ INSTANTIATE_TEST_SUITE_P(
     ::testing::ValuesIn<InteractiveSegmenterTestParams>(
         {// Keypoint input.
         {"PointToDog1", RegionOfInterest::Format::kKeyPoint,
-          NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f},
+          NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsJpg, kCatsAndDogsMaskDog1,
+          0.84f},
         {"PointToDog2", RegionOfInterest::Format::kKeyPoint,
-          NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2,
+          NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsJpg, kCatsAndDogsMaskDog2,
           kGoldenMaskSimilarity},
+        {"PenguinsSmall", RegionOfInterest::Format::kKeyPoint,
+          NormalizedKeypoint{0.329, 0.545}, kPenguinsSmall, kPenguinsSmallMask,
+          0.9f},
+        {"PenguinsLarge", RegionOfInterest::Format::kKeyPoint,
+          NormalizedKeypoint{0.329, 0.545}, kPenguinsLarge, kPenguinsLargeMask,
+          0.9f},
         // Scribble input.
         {"ScribbleToDog1", RegionOfInterest::Format::kScribble,
          std::vector{NormalizedKeypoint{0.44, 0.70},
                      NormalizedKeypoint{0.44, 0.71},
                      NormalizedKeypoint{0.44, 0.72}},
-         kCatsAndDogsMaskDog1, 0.84f},
+         kCatsAndDogsJpg, kCatsAndDogsMaskDog1, 0.84f},
        {"ScribbleToDog2", RegionOfInterest::Format::kScribble,
          std::vector{NormalizedKeypoint{0.66, 0.66},
                      NormalizedKeypoint{0.66, 0.67},
                      NormalizedKeypoint{0.66, 0.68}},
-         kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}),
+         kCatsAndDogsJpg, kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}),
    [](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
           info) { return info.param.test_name; });
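The visualization added to both tests relies on wrapping the OpenCV mask buffer in a mediapipe::ImageFrame without copying, then saving it with SavePngTestOutput. A condensed sketch of that pattern follows; VisualizeAndSave and output_name are hypothetical helpers for illustration, while the ImageFrame constructor, MP_EXPECT_OK, and SavePngTestOutput are the existing MediaPipe test APIs used above.

#include <cstdint>
#include <string>

#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/test_util.h"

// Scales a mask into [0, 255], wraps it as a GRAY8 ImageFrame without copying
// the pixels, and writes it as a PNG next to the other test outputs.
void VisualizeAndSave(const cv::Mat& mask, const std::string& output_name) {
  cv::Mat visualized;
  mask.convertTo(visualized, CV_8UC1, /*alpha=*/255);
  // The lambda deleter captures `visualized` by value, so the cv::Mat (and the
  // pixel buffer it reference-counts) stays alive as long as the ImageFrame.
  mediapipe::ImageFrame frame(mediapipe::ImageFormat::GRAY8, visualized.cols,
                              visualized.rows, visualized.step, visualized.data,
                              [visualized](uint8_t[]) {});
  MP_EXPECT_OK(mediapipe::SavePngTestOutput(frame, output_name));
}

Capturing the cv::Mat in the deleter is what makes the zero-copy wrap safe: without it, the ImageFrame could outlive the OpenCV buffer it points into.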


@@ -69,6 +69,10 @@ mediapipe_files(srcs = [
     "multi_objects.jpg",
     "multi_objects_rotated.jpg",
     "palm_detection_full.tflite",
+    "penguins_large.jpg",
+    "penguins_large_mask.png",
+    "penguins_small.jpg",
+    "penguins_small_mask.png",
     "pointing_up.jpg",
     "pointing_up_rotated.jpg",
     "portrait.jpg",
@@ -135,6 +139,10 @@ filegroup(
         "mozart_square.jpg",
         "multi_objects.jpg",
         "multi_objects_rotated.jpg",
+        "penguins_large.jpg",
+        "penguins_large_mask.png",
+        "penguins_small.jpg",
+        "penguins_small_mask.png",
         "pointing_up.jpg",
         "pointing_up_rotated.jpg",
         "portrait.jpg",


@@ -66,8 +66,8 @@ def external_files():
     http_file(
         name = "com_google_mediapipe_BUILD",
-        sha256 = "d2b2a8346202691d7f831887c84e9642e974f64ed67851d9a58cf15c94b1f6b3",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=1661875663693976167832357639365316787374795996401679955080207504"],
+        sha256 = "cfbc1404ba18ee9eb0f08e9ee66d5b51f3fac47f683a5fa0cc23b46f30e05a1f",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=1686332366306166"],
     )
 
     http_file(
@@ -904,6 +904,30 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/palm_detection_lite.tflite?generation=1661875885885770"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_penguins_large_jpg",
+        sha256 = "3a7a74bf946b3e2b53a3953516a552df854b2854c91b3372d2d6343497ca2160",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/penguins_large.jpg?generation=1686332378707665"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_penguins_large_mask_png",
+        sha256 = "8f78486266dabb1a3f28bf52750c0d005f96233fe505d5e8dcba02c6ee3a13cb",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/penguins_large_mask.png?generation=1686332381154669"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_penguins_small_jpg",
+        sha256 = "708ca356d8be4fbf5b76d4f2fcd094e97122cc24934cfcca22ac3ab0f13c4632",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/penguins_small.jpg?generation=1686332383656645"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_penguins_small_mask_png",
+        sha256 = "65523dd7ed468ee4be3cd0cfed5badcfa41eaa5cd06444c9ab9b71b2d5951abe",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/penguins_small_mask.png?generation=1686332385707707"],
+    )
+
     http_file(
         name = "com_google_mediapipe_pointing_up_jpg",
         sha256 = "ecf8ca2611d08fa25948a4fc10710af9120e88243a54da6356bacea17ff3e36e",