Update one-class segmentation category mask behavior on CPU to match latest API

PiperOrigin-RevId: 529917830
This commit is contained in:
MediaPipe Team 2023-05-06 00:53:36 -07:00 committed by Copybara-Service
parent fb7f06b509
commit cc8847def5
3 changed files with 46 additions and 13 deletions

View File

@ -61,6 +61,8 @@ using ::mediapipe::tasks::vision::GetImageLikeTensorShape;
using ::mediapipe::tasks::vision::Shape; using ::mediapipe::tasks::vision::Shape;
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
constexpr uint8_t kUnLabeledPixelValue = 255;
void StableSoftmax(absl::Span<const float> values, void StableSoftmax(absl::Span<const float> values,
absl::Span<float> activated_values) { absl::Span<float> activated_values) {
float max_value = *std::max_element(values.begin(), values.end()); float max_value = *std::max_element(values.begin(), values.end());
@ -153,9 +155,11 @@ Image ProcessForCategoryMaskCpu(const Shape& input_shape,
} }
if (input_channels == 1) { if (input_channels == 1) {
// if the input tensor is a single mask, it is assumed to be a binary // if the input tensor is a single mask, it is assumed to be a binary
// foreground segmentation mask. For such a mask, we make foreground // foreground segmentation mask. For such a mask, instead of a true
// category 1, and background category 0. // argmax, we simply use 0.5 as the cutoff, assigning 0 (foreground) or
pixel = static_cast<uint8_t>(confidence_scores[0] > 0.5f); // 255 (background) based on whether the confidence value reaches this
// cutoff or not, respectively.
pixel = confidence_scores[0] > 0.5f ? 0 : kUnLabeledPixelValue;
} else { } else {
const int maximum_category_idx = const int maximum_category_idx =
std::max_element(confidence_scores.begin(), confidence_scores.end()) - std::max_element(confidence_scores.begin(), confidence_scores.end()) -

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
@ -425,6 +426,28 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
} }
TEST_F(ImageModeTest, SucceedsSelfieSegmentationSingleLabel) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfieSegmentation);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
ASSERT_EQ(segmenter->GetLabels().size(), 1);
EXPECT_EQ(segmenter->GetLabels()[0], "selfie");
MP_ASSERT_OK(segmenter->Close());
}
TEST_F(ImageModeTest, SucceedsSelfieSegmentationLandscapeSingleLabel) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfieSegmentationLandscape);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
ASSERT_EQ(segmenter->GetLabels().size(), 1);
EXPECT_EQ(segmenter->GetLabels()[0], "selfie");
MP_ASSERT_OK(segmenter->Close());
}
TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) { TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) {
Image image = Image image =
GetSRGBImage(JoinPath("./", kTestDataDirectory, "portrait.jpg")); GetSRGBImage(JoinPath("./", kTestDataDirectory, "portrait.jpg"));
@ -464,6 +487,9 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationCategoryMask) {
EXPECT_TRUE(result.category_mask.has_value()); EXPECT_TRUE(result.category_mask.has_value());
MP_ASSERT_OK(segmenter->Close()); MP_ASSERT_OK(segmenter->Close());
MP_EXPECT_OK(
SavePngTestOutput(*result.category_mask->GetImageFrameSharedPtr(),
"portrait_selfie_segmentation_expected_category_mask"));
cv::Mat selfie_mask = mediapipe::formats::MatView( cv::Mat selfie_mask = mediapipe::formats::MatView(
result.category_mask->GetImageFrameSharedPtr().get()); result.category_mask->GetImageFrameSharedPtr().get());
cv::Mat expected_mask = cv::imread( cv::Mat expected_mask = cv::imread(
@ -471,7 +497,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationCategoryMask) {
"portrait_selfie_segmentation_expected_category_mask.jpg"), "portrait_selfie_segmentation_expected_category_mask.jpg"),
cv::IMREAD_GRAYSCALE); cv::IMREAD_GRAYSCALE);
EXPECT_THAT(selfie_mask, EXPECT_THAT(selfie_mask,
SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 255)); SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 1));
} }
TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) { TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) {
@ -487,6 +513,9 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) {
EXPECT_TRUE(result.category_mask.has_value()); EXPECT_TRUE(result.category_mask.has_value());
MP_ASSERT_OK(segmenter->Close()); MP_ASSERT_OK(segmenter->Close());
MP_EXPECT_OK(SavePngTestOutput(
*result.category_mask->GetImageFrameSharedPtr(),
"portrait_selfie_segmentation_landscape_expected_category_mask"));
cv::Mat selfie_mask = mediapipe::formats::MatView( cv::Mat selfie_mask = mediapipe::formats::MatView(
result.category_mask->GetImageFrameSharedPtr().get()); result.category_mask->GetImageFrameSharedPtr().get());
cv::Mat expected_mask = cv::imread( cv::Mat expected_mask = cv::imread(
@ -495,7 +524,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) {
"portrait_selfie_segmentation_landscape_expected_category_mask.jpg"), "portrait_selfie_segmentation_landscape_expected_category_mask.jpg"),
cv::IMREAD_GRAYSCALE); cv::IMREAD_GRAYSCALE);
EXPECT_THAT(selfie_mask, EXPECT_THAT(selfie_mask,
SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 255)); SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 1));
} }
TEST_F(ImageModeTest, SucceedsHairSegmentation) { TEST_F(ImageModeTest, SucceedsHairSegmentation) {

View File

@ -960,8 +960,8 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_portrait_selfie_segmentation_expected_category_mask_jpg", name = "com_google_mediapipe_portrait_selfie_segmentation_expected_category_mask_jpg",
sha256 = "d8f20fa746e14067f668dd293f21bbc50ec81196d186386a6ded1278c3ec8f46", sha256 = "1400c6fccf3805bfd1644d7ed9be98dfa4f900e1720838c566963f8d9f10f5d0",
urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_expected_category_mask.jpg?generation=1678606935088873"], urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_expected_category_mask.jpg?generation=1683332555306471"],
) )
http_file( http_file(
@ -972,8 +972,8 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_portrait_selfie_segmentation_landscape_expected_category_mask_jpg", name = "com_google_mediapipe_portrait_selfie_segmentation_landscape_expected_category_mask_jpg",
sha256 = "f5c3fa3d93f8e7289b69b8a89c2519276dfa5014dcc50ed6e86e8cd4d4ae7f27", sha256 = "a208aeeeb615fd40046d883e2c7982458e1b12edd6526e88c305c4053b0a9399",
urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_landscape_expected_category_mask.jpg?generation=1678606939469429"], urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_landscape_expected_category_mask.jpg?generation=1683332557473435"],
) )
http_file( http_file(
@ -1158,14 +1158,14 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_selfie_segmentation_landscape_tflite", name = "com_google_mediapipe_selfie_segmentation_landscape_tflite",
sha256 = "28fb4c287d6295a2dba6c1f43b43315a37f927ddcd6693d635d625d176eef162", sha256 = "a77d03f4659b9f6b6c1f5106947bf40e99d7655094b6527f214ea7d451106edd",
urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation_landscape.tflite?generation=1678775102234495"], urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation_landscape.tflite?generation=1683332561312022"],
) )
http_file( http_file(
name = "com_google_mediapipe_selfie_segmentation_tflite", name = "com_google_mediapipe_selfie_segmentation_tflite",
sha256 = "b0e2ec6f95107795b952b27f3d92806b45f0bc069dac76dcd264cd1b90d61c6c", sha256 = "9ee168ec7c8f2a16c56fe8e1cfbc514974cbbb7e434051b455635f1bd1462f5c",
urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation.tflite?generation=1678775104900954"], urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation.tflite?generation=1683332563830600"],
) )
http_file( http_file(