From cc8847def5c13263d361340a354d3987c0dd276e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Sat, 6 May 2023 00:53:36 -0700 Subject: [PATCH] Update one-class segmentation category mask behavior on CPU to match latest API PiperOrigin-RevId: 529917830 --- .../tensors_to_segmentation_calculator.cc | 10 ++++-- .../image_segmenter/image_segmenter_test.cc | 33 +++++++++++++++++-- third_party/external_files.bzl | 16 ++++----- 3 files changed, 46 insertions(+), 13 deletions(-) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc index c2d1520dd..660dc59b7 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc @@ -61,6 +61,8 @@ using ::mediapipe::tasks::vision::GetImageLikeTensorShape; using ::mediapipe::tasks::vision::Shape; using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; +constexpr uint8_t kUnLabeledPixelValue = 255; + void StableSoftmax(absl::Span<const float> values, absl::Span<float> activated_values) { float max_value = *std::max_element(values.begin(), values.end()); @@ -153,9 +155,11 @@ Image ProcessForCategoryMaskCpu(const Shape& input_shape, } if (input_channels == 1) { // if the input tensor is a single mask, it is assumed to be a binary - // foreground segmentation mask. For such a mask, we make foreground - // category 1, and background category 0. - pixel = static_cast<uint8_t>(confidence_scores[0] > 0.5f); + // foreground segmentation mask. For such a mask, instead of a true + // argmax, we simply use 0.5 as the cutoff, assigning 0 (foreground) or + // 255 (background) based on whether the confidence value reaches this + // cutoff or not, respectively. + pixel = confidence_scores[0] > 0.5f ? 
0 : kUnLabeledPixelValue; } else { const int maximum_category_idx = std::max_element(confidence_scores.begin(), confidence_scores.end()) - diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index 339ec1424..656ed0715 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/test_util.h" #include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" @@ -425,6 +426,28 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } +TEST_F(ImageModeTest, SucceedsSelfieSegmentationSingleLabel) { + auto options = std::make_unique<ImageSegmenterOptions>(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kSelfieSegmentation); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, + ImageSegmenter::Create(std::move(options))); + ASSERT_EQ(segmenter->GetLabels().size(), 1); + EXPECT_EQ(segmenter->GetLabels()[0], "selfie"); + MP_ASSERT_OK(segmenter->Close()); +} + +TEST_F(ImageModeTest, SucceedsSelfieSegmentationLandscapeSingleLabel) { + auto options = std::make_unique<ImageSegmenterOptions>(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kSelfieSegmentationLandscape); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, + ImageSegmenter::Create(std::move(options))); + ASSERT_EQ(segmenter->GetLabels().size(), 1); + EXPECT_EQ(segmenter->GetLabels()[0], "selfie"); + MP_ASSERT_OK(segmenter->Close()); +} + TEST_F(ImageModeTest, 
SucceedsPortraitSelfieSegmentationConfidenceMask) { Image image = GetSRGBImage(JoinPath("./", kTestDataDirectory, "portrait.jpg")); @@ -464,6 +487,9 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationCategoryMask) { EXPECT_TRUE(result.category_mask.has_value()); MP_ASSERT_OK(segmenter->Close()); + MP_EXPECT_OK( + SavePngTestOutput(*result.category_mask->GetImageFrameSharedPtr(), + "portrait_selfie_segmentation_expected_category_mask")); cv::Mat selfie_mask = mediapipe::formats::MatView( result.category_mask->GetImageFrameSharedPtr().get()); cv::Mat expected_mask = cv::imread( @@ -471,7 +497,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationCategoryMask) { "portrait_selfie_segmentation_expected_category_mask.jpg"), cv::IMREAD_GRAYSCALE); EXPECT_THAT(selfie_mask, - SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 255)); + SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 1)); } TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) { @@ -487,6 +513,9 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) { EXPECT_TRUE(result.category_mask.has_value()); MP_ASSERT_OK(segmenter->Close()); + MP_EXPECT_OK(SavePngTestOutput( + *result.category_mask->GetImageFrameSharedPtr(), + "portrait_selfie_segmentation_landscape_expected_category_mask")); cv::Mat selfie_mask = mediapipe::formats::MatView( result.category_mask->GetImageFrameSharedPtr().get()); cv::Mat expected_mask = cv::imread( @@ -495,7 +524,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) { "portrait_selfie_segmentation_landscape_expected_category_mask.jpg"), cv::IMREAD_GRAYSCALE); EXPECT_THAT(selfie_mask, - SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 255)); + SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 1)); } TEST_F(ImageModeTest, SucceedsHairSegmentation) { diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 599248f48..652a2947f 100644 --- 
a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -960,8 +960,8 @@ def external_files(): http_file( name = "com_google_mediapipe_portrait_selfie_segmentation_expected_category_mask_jpg", - sha256 = "d8f20fa746e14067f668dd293f21bbc50ec81196d186386a6ded1278c3ec8f46", - urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_expected_category_mask.jpg?generation=1678606935088873"], + sha256 = "1400c6fccf3805bfd1644d7ed9be98dfa4f900e1720838c566963f8d9f10f5d0", + urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_expected_category_mask.jpg?generation=1683332555306471"], ) http_file( @@ -972,8 +972,8 @@ def external_files(): http_file( name = "com_google_mediapipe_portrait_selfie_segmentation_landscape_expected_category_mask_jpg", - sha256 = "f5c3fa3d93f8e7289b69b8a89c2519276dfa5014dcc50ed6e86e8cd4d4ae7f27", - urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_landscape_expected_category_mask.jpg?generation=1678606939469429"], + sha256 = "a208aeeeb615fd40046d883e2c7982458e1b12edd6526e88c305c4053b0a9399", + urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_landscape_expected_category_mask.jpg?generation=1683332557473435"], ) http_file( @@ -1158,14 +1158,14 @@ def external_files(): http_file( name = "com_google_mediapipe_selfie_segmentation_landscape_tflite", - sha256 = "28fb4c287d6295a2dba6c1f43b43315a37f927ddcd6693d635d625d176eef162", - urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation_landscape.tflite?generation=1678775102234495"], + sha256 = "a77d03f4659b9f6b6c1f5106947bf40e99d7655094b6527f214ea7d451106edd", + urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation_landscape.tflite?generation=1683332561312022"], ) http_file( name = "com_google_mediapipe_selfie_segmentation_tflite", - sha256 = "b0e2ec6f95107795b952b27f3d92806b45f0bc069dac76dcd264cd1b90d61c6c", - urls 
= ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation.tflite?generation=1678775104900954"], + sha256 = "9ee168ec7c8f2a16c56fe8e1cfbc514974cbbb7e434051b455635f1bd1462f5c", + urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation.tflite?generation=1683332563830600"], ) http_file(