From ac4f60a79385a78ef5af2b8505ee653862bd1c8c Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Fri, 9 Jun 2023 21:10:56 -0700
Subject: [PATCH] Annotate in model input scale for InteractiveSegmenter

PiperOrigin-RevId: 539245617
---
 .../cc/vision/interactive_segmenter/BUILD     |  2 +
 .../interactive_segmenter_graph.cc            | 64 ++++++++++++++++++-
 .../interactive_segmenter_test.cc             | 47 +++++++++++---
 mediapipe/tasks/testdata/vision/BUILD         |  8 +++
 third_party/external_files.bzl                | 28 +++++++-
 5 files changed, 136 insertions(+), 13 deletions(-)

diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/BUILD b/mediapipe/tasks/cc/vision/interactive_segmenter/BUILD
index d02b5db36..177cbf43a 100644
--- a/mediapipe/tasks/cc/vision/interactive_segmenter/BUILD
+++ b/mediapipe/tasks/cc/vision/interactive_segmenter/BUILD
@@ -52,6 +52,7 @@ cc_library(
     name = "interactive_segmenter_graph",
     srcs = ["interactive_segmenter_graph.cc"],
     deps = [
+        "//mediapipe/calculators/image:image_transformation_calculator",
         "//mediapipe/calculators/image:set_alpha_calculator",
         "//mediapipe/calculators/util:annotation_overlay_calculator",
         "//mediapipe/calculators/util:flat_color_image_calculator",
@@ -60,6 +61,7 @@ cc_library(
         "//mediapipe/calculators/util:to_image_calculator",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/api2:builder",
+        "//mediapipe/framework/api2:node",
         "//mediapipe/framework/api2:port",
         "//mediapipe/framework/formats:image",
         "//mediapipe/framework/formats:rect_cc_proto",
diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc
index 5bb3e8ece..5ae2792fe 100644
--- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc
+++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc
@@ -13,12 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <algorithm>
 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
 #include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
 #include "mediapipe/framework/api2/builder.h"
+#include "mediapipe/framework/api2/node.h"
 #include "mediapipe/framework/api2/port.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/formats/image.h"
@@ -35,6 +37,51 @@ namespace mediapipe {
 namespace tasks {
 namespace vision {
 namespace interactive_segmenter {
+namespace internal {
+
+// A calculator that sets the thickness of the incoming render data according
+// to the image size, so that rendering is scale invariant to the image size.
+// Annotations that already have a thickness are kept as is.
+class AddThicknessToRenderDataCalculator : public api2::Node {
+ public:
+  static constexpr api2::Input<Image> kImageIn{"IMAGE"};
+  static constexpr api2::Input<mediapipe::RenderData> kRenderDataIn{
+      "RENDER_DATA"};
+  static constexpr api2::Output<mediapipe::RenderData> kRenderDataOut{
+      "RENDER_DATA"};
+
+  static constexpr int kModelInputTensorWidth = 512;
+  static constexpr int kModelInputTensorHeight = 512;
+
+  MEDIAPIPE_NODE_CONTRACT(kImageIn, kRenderDataIn, kRenderDataOut);
+
+  absl::Status Process(CalculatorContext* cc) final {
+    mediapipe::RenderData render_data = kRenderDataIn(cc).Get();
+    Image image = kImageIn(cc).Get();
+    double thickness = std::max(
+        std::max(image.width() / static_cast<double>(kModelInputTensorWidth),
+                 image.height() / static_cast<double>(kModelInputTensorHeight)),
+        1.0);
+
+    for (auto& annotation : *render_data.mutable_render_annotations()) {
+      if (!annotation.has_thickness()) {
+        annotation.set_thickness(thickness);
+      }
+    }
+    kRenderDataOut(cc).Send(render_data);
+    return absl::OkStatus();
+  }
+};
+
+// NOLINTBEGIN: Node registration doesn't work when part of calculator name is
+// moved to next line.
+// clang-format off
+MEDIAPIPE_REGISTER_NODE(
+    ::mediapipe::tasks::vision::interactive_segmenter::internal::AddThicknessToRenderDataCalculator);
+// clang-format on
+// NOLINTEND
+
+}  // namespace internal
 
 namespace {
 
@@ -59,6 +106,7 @@ constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"};
 constexpr absl::string_view kNormRectTag{"NORM_RECT"};
 constexpr absl::string_view kRoiTag{"ROI"};
 constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
+constexpr absl::string_view kRenderDataTag{"RENDER_DATA"};
 
 // Updates the graph to return `roi` stream which has same dimension as
 // `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is
@@ -69,14 +117,23 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
   const absl::string_view image_tag_with_suffix =
       use_gpu ? kImageGpuTag : kImageCpuTag;
 
+  // Adds thickness to the render data so that the render data is scale
+  // invariant to the input image size.
+  auto& add_thickness = graph.AddNode(
+      "mediapipe::tasks::vision::interactive_segmenter::internal::"
+      "AddThicknessToRenderDataCalculator");
+  image >> add_thickness.In(kImageTag);
+  roi >> add_thickness.In(kRenderDataTag);
+  auto roi_with_thickness = add_thickness.Out(kRenderDataTag);
+
   // Generates a blank canvas with same size as input image.
   auto& flat_color = graph.AddNode("FlatColorImageCalculator");
   auto& flat_color_options =
       flat_color.GetOptions<FlatColorImageCalculatorOptions>();
   // SetAlphaCalculator only takes 1st channel.
   flat_color_options.mutable_color()->set_r(0);
-  image >> flat_color.In(kImageTag)[0];
-  auto blank_canvas = flat_color.Out(kImageTag)[0];
+  image >> flat_color.In(kImageTag);
+  auto blank_canvas = flat_color.Out(kImageTag);
 
   auto& from_mp_image = graph.AddNode("FromImageCalculator");
   blank_canvas >> from_mp_image.In(kImageTag);
@@ -85,7 +142,7 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
   auto& roi_to_alpha = graph.AddNode("AnnotationOverlayCalculator");
   blank_canvas_in_cpu_or_gpu >>
       roi_to_alpha.In(use_gpu ? kImageGpuTag : kImageTag);
-  roi >> roi_to_alpha.In(0);
+  roi_with_thickness >> roi_to_alpha.In(0);
   auto alpha = roi_to_alpha.Out(use_gpu ? kImageGpuTag : kImageTag);
 
   return alpha;
@@ -163,6 +220,7 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph {
     image >> from_mp_image.In(kImageTag);
     auto image_in_cpu_or_gpu = from_mp_image.Out(image_tag_with_suffix);
 
+    // Creates an RGBA image with model input tensor size.
     auto alpha_in_cpu_or_gpu = RoiToAlpha(image, roi, use_gpu, graph);
 
     auto& set_alpha = graph.AddNode("SetAlphaCalculator");
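The thickness rule above can be read as a standalone computation: the calculator takes the larger of the two image-to-model-input downscale factors and never lets it drop below 1.0, so an annotation drawn with that thickness still covers roughly one pixel once the image is resized to the 512x512 model input tensor. Below is a minimal sketch of that rule, not part of the patch; the helper name RenderThicknessFor is hypothetical, while the 512x512 constants come from AddThicknessToRenderDataCalculator above.

    #include <algorithm>
    #include <cstdio>

    // Illustrative sketch only; mirrors the scaling rule used by
    // AddThicknessToRenderDataCalculator in the patch above.
    double RenderThicknessFor(int image_width, int image_height,
                              int model_width = 512, int model_height = 512) {
      // Larger of the two downscale factors, clamped at 1.0 so small images
      // are not drawn with sub-pixel strokes.
      return std::max({image_width / static_cast<double>(model_width),
                       image_height / static_cast<double>(model_height), 1.0});
    }

    int main() {
      std::printf("%.1f\n", RenderThicknessFor(2048, 1536));  // prints 4.0
      std::printf("%.1f\n", RenderThicknessFor(256, 256));    // prints 1.0
      return 0;
    }

For a 2048x1536 photo the ROI render data is therefore drawn 4 px thick, which is the behavior the PenguinsSmall and PenguinsLarge cases in the test changes below exercise at two different image scales.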
diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc
index 16d065f61..2bb06428e 100644
--- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc
+++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc
@@ -34,6 +34,7 @@ limitations under the License.
 #include "mediapipe/framework/port/opencv_core_inc.h"
 #include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
 #include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/framework/tool/test_util.h"
 #include "mediapipe/tasks/cc/components/containers/keypoint.h"
 #include "mediapipe/tasks/cc/components/containers/rect.h"
 #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
@@ -70,6 +71,10 @@ constexpr absl::string_view kCatsAndDogsJpg{"cats_and_dogs.jpg"};
 // Golden mask for the dogs in cats_and_dogs.jpg.
 constexpr absl::string_view kCatsAndDogsMaskDog1{"cats_and_dogs_mask_dog1.png"};
 constexpr absl::string_view kCatsAndDogsMaskDog2{"cats_and_dogs_mask_dog2.png"};
+constexpr absl::string_view kPenguinsLarge{"penguins_large.jpg"};
+constexpr absl::string_view kPenguinsSmall{"penguins_small.jpg"};
+constexpr absl::string_view kPenguinsSmallMask{"penguins_small_mask.png"};
+constexpr absl::string_view kPenguinsLargeMask{"penguins_large_mask.png"};
 
 constexpr float kGoldenMaskSimilarity = 0.97;
 
@@ -183,6 +188,7 @@ struct InteractiveSegmenterTestParams {
   std::string test_name;
   RegionOfInterest::Format format;
   std::variant<NormalizedKeypoint, std::vector<NormalizedKeypoint>> roi;
+  absl::string_view input_image_file;
   absl::string_view golden_mask_file;
   float similarity_threshold;
 };
@@ -220,8 +226,8 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
   const InteractiveSegmenterTestParams& params = GetParam();
 
   MP_ASSERT_OK_AND_ASSIGN(
-      Image image,
-      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
+      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
+                                                params.input_image_file)));
   auto options = std::make_unique<InteractiveSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kPtmModel);
@@ -244,6 +250,15 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
   EXPECT_THAT(actual_mask,
               SimilarToUint8Mask(expected_mask, params.similarity_threshold,
                                  kGoldenMaskMagnificationFactor));
+
+  cv::Mat visualized_mask;
+  actual_mask.convertTo(visualized_mask, CV_8UC1, /*alpha=*/255);
+  ImageFrame visualized_image(mediapipe::ImageFormat::GRAY8,
+                              visualized_mask.cols, visualized_mask.rows,
+                              visualized_mask.step, visualized_mask.data,
+                              [visualized_mask](uint8_t[]) {});
+  MP_EXPECT_OK(SavePngTestOutput(
+      visualized_image, absl::StrFormat("%s_category_mask", params.test_name)));
 }
 
 TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
@@ -252,8 +267,8 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
   const InteractiveSegmenterTestParams& params = GetParam();
 
   MP_ASSERT_OK_AND_ASSIGN(
-      Image image,
-      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
+      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
+                                                params.input_image_file)));
   auto options = std::make_unique<InteractiveSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kPtmModel);
@@ -275,6 +290,15 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
           result.confidence_masks->at(1).GetImageFrameSharedPtr().get());
   EXPECT_THAT(actual_mask, SimilarToFloatMask(expected_mask_float,
                                               params.similarity_threshold));
+  cv::Mat visualized_mask;
+  actual_mask.convertTo(visualized_mask, CV_8UC1, /*alpha=*/255);
+  ImageFrame visualized_image(mediapipe::ImageFormat::GRAY8,
+                              visualized_mask.cols, visualized_mask.rows,
+                              visualized_mask.step, visualized_mask.data,
+                              [visualized_mask](uint8_t[]) {});
+  MP_EXPECT_OK(SavePngTestOutput(
+      visualized_image,
+      absl::StrFormat("%s_confidence_mask", params.test_name)));
 }
 
 INSTANTIATE_TEST_SUITE_P(
@@ -282,21 +306,28 @@ INSTANTIATE_TEST_SUITE_P(
     ::testing::ValuesIn<InteractiveSegmenterTestParams>(
         {// Keypoint input.
         {"PointToDog1", RegionOfInterest::Format::kKeyPoint,
-          NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f},
+          NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsJpg, kCatsAndDogsMaskDog1,
+          0.84f},
         {"PointToDog2", RegionOfInterest::Format::kKeyPoint,
-          NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2,
+          NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsJpg, kCatsAndDogsMaskDog2,
           kGoldenMaskSimilarity},
+        {"PenguinsSmall", RegionOfInterest::Format::kKeyPoint,
+          NormalizedKeypoint{0.329, 0.545}, kPenguinsSmall, kPenguinsSmallMask,
+          0.9f},
+        {"PenguinsLarge", RegionOfInterest::Format::kKeyPoint,
+          NormalizedKeypoint{0.329, 0.545}, kPenguinsLarge, kPenguinsLargeMask,
+          0.9f},
         // Scribble input.
        {"ScribbleToDog1", RegionOfInterest::Format::kScribble,
          std::vector<NormalizedKeypoint>{NormalizedKeypoint{0.44, 0.70},
                                          NormalizedKeypoint{0.44, 0.71},
                                          NormalizedKeypoint{0.44, 0.72}},
-          kCatsAndDogsMaskDog1, 0.84f},
+          kCatsAndDogsJpg, kCatsAndDogsMaskDog1, 0.84f},
        {"ScribbleToDog2", RegionOfInterest::Format::kScribble,
          std::vector<NormalizedKeypoint>{NormalizedKeypoint{0.66, 0.66},
                                          NormalizedKeypoint{0.66, 0.67},
                                          NormalizedKeypoint{0.66, 0.68}},
-          kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}),
+          kCatsAndDogsJpg, kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}),
     [](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
            info) { return info.param.test_name; });
 
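One subtle point in the new test code above is how it wraps OpenCV memory for the debug PNGs: the cv::Mat is captured by value in the ImageFrame deleter, so its reference-counted pixel buffer stays alive as long as the ImageFrame does, while the deleter body is a no-op because cv::Mat frees its own storage. A short sketch of that pattern, not part of the patch; the helper name WrapMaskForPng is hypothetical, the ImageFrame constructor and headers are the same ones the test uses.

    #include <cstdint>

    #include "mediapipe/framework/formats/image_frame.h"
    #include "mediapipe/framework/port/opencv_core_inc.h"

    // Illustrative sketch only; mirrors the visualization code added to
    // interactive_segmenter_test.cc above.
    mediapipe::ImageFrame WrapMaskForPng(const cv::Mat& float_mask) {
      cv::Mat mask;
      float_mask.convertTo(mask, CV_8UC1, /*alpha=*/255);
      // Capturing `mask` by value keeps its buffer alive inside the deleter;
      // the empty body means ImageFrame never frees memory it does not own.
      return mediapipe::ImageFrame(mediapipe::ImageFormat::GRAY8, mask.cols,
                                   mask.rows, mask.step, mask.data,
                                   [mask](uint8_t[]) {});
    }

The resulting ImageFrame can then be handed to SavePngTestOutput exactly as the tests do, without copying the mask pixels.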
diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD
index e2622a3c8..4fde58e02 100644
--- a/mediapipe/tasks/testdata/vision/BUILD
+++ b/mediapipe/tasks/testdata/vision/BUILD
@@ -69,6 +69,10 @@ mediapipe_files(srcs = [
     "multi_objects.jpg",
     "multi_objects_rotated.jpg",
     "palm_detection_full.tflite",
+    "penguins_large.jpg",
+    "penguins_large_mask.png",
+    "penguins_small.jpg",
+    "penguins_small_mask.png",
     "pointing_up.jpg",
     "pointing_up_rotated.jpg",
     "portrait.jpg",
@@ -135,6 +139,10 @@ filegroup(
         "mozart_square.jpg",
         "multi_objects.jpg",
         "multi_objects_rotated.jpg",
+        "penguins_large.jpg",
+        "penguins_large_mask.png",
+        "penguins_small.jpg",
+        "penguins_small_mask.png",
         "pointing_up.jpg",
         "pointing_up_rotated.jpg",
         "portrait.jpg",
diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl
index 722ec3426..4b51d9de0 100644
--- a/third_party/external_files.bzl
+++ b/third_party/external_files.bzl
@@ -66,8 +66,8 @@ def external_files():
 
     http_file(
         name = "com_google_mediapipe_BUILD",
-        sha256 = "d2b2a8346202691d7f831887c84e9642e974f64ed67851d9a58cf15c94b1f6b3",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=1661875663693976"],
+        sha256 = "cfbc1404ba18ee9eb0f08e9ee66d5b51f3fac47f683a5fa0cc23b46f30e05a1f",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=1686332366306166"],
     )
 
     http_file(
@@ -904,6 +904,30 @@ def external_files():
         urls = 
["https://storage.googleapis.com/mediapipe-assets/palm_detection_lite.tflite?generation=1661875885885770"], ) + http_file( + name = "com_google_mediapipe_penguins_large_jpg", + sha256 = "3a7a74bf946b3e2b53a3953516a552df854b2854c91b3372d2d6343497ca2160", + urls = ["https://storage.googleapis.com/mediapipe-assets/penguins_large.jpg?generation=1686332378707665"], + ) + + http_file( + name = "com_google_mediapipe_penguins_large_mask_png", + sha256 = "8f78486266dabb1a3f28bf52750c0d005f96233fe505d5e8dcba02c6ee3a13cb", + urls = ["https://storage.googleapis.com/mediapipe-assets/penguins_large_mask.png?generation=1686332381154669"], + ) + + http_file( + name = "com_google_mediapipe_penguins_small_jpg", + sha256 = "708ca356d8be4fbf5b76d4f2fcd094e97122cc24934cfcca22ac3ab0f13c4632", + urls = ["https://storage.googleapis.com/mediapipe-assets/penguins_small.jpg?generation=1686332383656645"], + ) + + http_file( + name = "com_google_mediapipe_penguins_small_mask_png", + sha256 = "65523dd7ed468ee4be3cd0cfed5badcfa41eaa5cd06444c9ab9b71b2d5951abe", + urls = ["https://storage.googleapis.com/mediapipe-assets/penguins_small_mask.png?generation=1686332385707707"], + ) + http_file( name = "com_google_mediapipe_pointing_up_jpg", sha256 = "ecf8ca2611d08fa25948a4fc10710af9120e88243a54da6356bacea17ff3e36e",