Update one-class segmentation category mask behavior on CPU to match latest API
PiperOrigin-RevId: 529917830
This commit is contained in:
parent
fb7f06b509
commit
cc8847def5
|
@ -61,6 +61,8 @@ using ::mediapipe::tasks::vision::GetImageLikeTensorShape;
|
||||||
using ::mediapipe::tasks::vision::Shape;
|
using ::mediapipe::tasks::vision::Shape;
|
||||||
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
|
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
|
||||||
|
|
||||||
|
constexpr uint8_t kUnLabeledPixelValue = 255;
|
||||||
|
|
||||||
void StableSoftmax(absl::Span<const float> values,
|
void StableSoftmax(absl::Span<const float> values,
|
||||||
absl::Span<float> activated_values) {
|
absl::Span<float> activated_values) {
|
||||||
float max_value = *std::max_element(values.begin(), values.end());
|
float max_value = *std::max_element(values.begin(), values.end());
|
||||||
|
@ -153,9 +155,11 @@ Image ProcessForCategoryMaskCpu(const Shape& input_shape,
|
||||||
}
|
}
|
||||||
if (input_channels == 1) {
|
if (input_channels == 1) {
|
||||||
// if the input tensor is a single mask, it is assumed to be a binary
|
// if the input tensor is a single mask, it is assumed to be a binary
|
||||||
// foreground segmentation mask. For such a mask, we make foreground
|
// foreground segmentation mask. For such a mask, instead of a true
|
||||||
// category 1, and background category 0.
|
// argmax, we simply use 0.5 as the cutoff, assigning 0 (foreground) or
|
||||||
pixel = static_cast<uint8_t>(confidence_scores[0] > 0.5f);
|
// 255 (background) based on whether the confidence value reaches this
|
||||||
|
// cutoff or not, respectively.
|
||||||
|
pixel = confidence_scores[0] > 0.5f ? 0 : kUnLabeledPixelValue;
|
||||||
} else {
|
} else {
|
||||||
const int maximum_category_idx =
|
const int maximum_category_idx =
|
||||||
std::max_element(confidence_scores.begin(), confidence_scores.end()) -
|
std::max_element(confidence_scores.begin(), confidence_scores.end()) -
|
||||||
|
|
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
|
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
|
||||||
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
|
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
#include "mediapipe/framework/tool/test_util.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||||
|
@ -425,6 +426,28 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
|
||||||
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
|
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ImageModeTest, SucceedsSelfieSegmentationSingleLabel) {
|
||||||
|
auto options = std::make_unique<ImageSegmenterOptions>();
|
||||||
|
options->base_options.model_asset_path =
|
||||||
|
JoinPath("./", kTestDataDirectory, kSelfieSegmentation);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
||||||
|
ImageSegmenter::Create(std::move(options)));
|
||||||
|
ASSERT_EQ(segmenter->GetLabels().size(), 1);
|
||||||
|
EXPECT_EQ(segmenter->GetLabels()[0], "selfie");
|
||||||
|
MP_ASSERT_OK(segmenter->Close());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ImageModeTest, SucceedsSelfieSegmentationLandscapeSingleLabel) {
|
||||||
|
auto options = std::make_unique<ImageSegmenterOptions>();
|
||||||
|
options->base_options.model_asset_path =
|
||||||
|
JoinPath("./", kTestDataDirectory, kSelfieSegmentationLandscape);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
||||||
|
ImageSegmenter::Create(std::move(options)));
|
||||||
|
ASSERT_EQ(segmenter->GetLabels().size(), 1);
|
||||||
|
EXPECT_EQ(segmenter->GetLabels()[0], "selfie");
|
||||||
|
MP_ASSERT_OK(segmenter->Close());
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) {
|
TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) {
|
||||||
Image image =
|
Image image =
|
||||||
GetSRGBImage(JoinPath("./", kTestDataDirectory, "portrait.jpg"));
|
GetSRGBImage(JoinPath("./", kTestDataDirectory, "portrait.jpg"));
|
||||||
|
@ -464,6 +487,9 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationCategoryMask) {
|
||||||
EXPECT_TRUE(result.category_mask.has_value());
|
EXPECT_TRUE(result.category_mask.has_value());
|
||||||
MP_ASSERT_OK(segmenter->Close());
|
MP_ASSERT_OK(segmenter->Close());
|
||||||
|
|
||||||
|
MP_EXPECT_OK(
|
||||||
|
SavePngTestOutput(*result.category_mask->GetImageFrameSharedPtr(),
|
||||||
|
"portrait_selfie_segmentation_expected_category_mask"));
|
||||||
cv::Mat selfie_mask = mediapipe::formats::MatView(
|
cv::Mat selfie_mask = mediapipe::formats::MatView(
|
||||||
result.category_mask->GetImageFrameSharedPtr().get());
|
result.category_mask->GetImageFrameSharedPtr().get());
|
||||||
cv::Mat expected_mask = cv::imread(
|
cv::Mat expected_mask = cv::imread(
|
||||||
|
@ -471,7 +497,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationCategoryMask) {
|
||||||
"portrait_selfie_segmentation_expected_category_mask.jpg"),
|
"portrait_selfie_segmentation_expected_category_mask.jpg"),
|
||||||
cv::IMREAD_GRAYSCALE);
|
cv::IMREAD_GRAYSCALE);
|
||||||
EXPECT_THAT(selfie_mask,
|
EXPECT_THAT(selfie_mask,
|
||||||
SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 255));
|
SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) {
|
TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) {
|
||||||
|
@ -487,6 +513,9 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) {
|
||||||
EXPECT_TRUE(result.category_mask.has_value());
|
EXPECT_TRUE(result.category_mask.has_value());
|
||||||
MP_ASSERT_OK(segmenter->Close());
|
MP_ASSERT_OK(segmenter->Close());
|
||||||
|
|
||||||
|
MP_EXPECT_OK(SavePngTestOutput(
|
||||||
|
*result.category_mask->GetImageFrameSharedPtr(),
|
||||||
|
"portrait_selfie_segmentation_landscape_expected_category_mask"));
|
||||||
cv::Mat selfie_mask = mediapipe::formats::MatView(
|
cv::Mat selfie_mask = mediapipe::formats::MatView(
|
||||||
result.category_mask->GetImageFrameSharedPtr().get());
|
result.category_mask->GetImageFrameSharedPtr().get());
|
||||||
cv::Mat expected_mask = cv::imread(
|
cv::Mat expected_mask = cv::imread(
|
||||||
|
@ -495,7 +524,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) {
|
||||||
"portrait_selfie_segmentation_landscape_expected_category_mask.jpg"),
|
"portrait_selfie_segmentation_landscape_expected_category_mask.jpg"),
|
||||||
cv::IMREAD_GRAYSCALE);
|
cv::IMREAD_GRAYSCALE);
|
||||||
EXPECT_THAT(selfie_mask,
|
EXPECT_THAT(selfie_mask,
|
||||||
SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 255));
|
SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsHairSegmentation) {
|
TEST_F(ImageModeTest, SucceedsHairSegmentation) {
|
||||||
|
|
16
third_party/external_files.bzl
vendored
16
third_party/external_files.bzl
vendored
|
@ -960,8 +960,8 @@ def external_files():
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_portrait_selfie_segmentation_expected_category_mask_jpg",
|
name = "com_google_mediapipe_portrait_selfie_segmentation_expected_category_mask_jpg",
|
||||||
sha256 = "d8f20fa746e14067f668dd293f21bbc50ec81196d186386a6ded1278c3ec8f46",
|
sha256 = "1400c6fccf3805bfd1644d7ed9be98dfa4f900e1720838c566963f8d9f10f5d0",
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_expected_category_mask.jpg?generation=1678606935088873"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_expected_category_mask.jpg?generation=1683332555306471"],
|
||||||
)
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
|
@ -972,8 +972,8 @@ def external_files():
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_portrait_selfie_segmentation_landscape_expected_category_mask_jpg",
|
name = "com_google_mediapipe_portrait_selfie_segmentation_landscape_expected_category_mask_jpg",
|
||||||
sha256 = "f5c3fa3d93f8e7289b69b8a89c2519276dfa5014dcc50ed6e86e8cd4d4ae7f27",
|
sha256 = "a208aeeeb615fd40046d883e2c7982458e1b12edd6526e88c305c4053b0a9399",
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_landscape_expected_category_mask.jpg?generation=1678606939469429"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_landscape_expected_category_mask.jpg?generation=1683332557473435"],
|
||||||
)
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
|
@ -1158,14 +1158,14 @@ def external_files():
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_selfie_segmentation_landscape_tflite",
|
name = "com_google_mediapipe_selfie_segmentation_landscape_tflite",
|
||||||
sha256 = "28fb4c287d6295a2dba6c1f43b43315a37f927ddcd6693d635d625d176eef162",
|
sha256 = "a77d03f4659b9f6b6c1f5106947bf40e99d7655094b6527f214ea7d451106edd",
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation_landscape.tflite?generation=1678775102234495"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation_landscape.tflite?generation=1683332561312022"],
|
||||||
)
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_selfie_segmentation_tflite",
|
name = "com_google_mediapipe_selfie_segmentation_tflite",
|
||||||
sha256 = "b0e2ec6f95107795b952b27f3d92806b45f0bc069dac76dcd264cd1b90d61c6c",
|
sha256 = "9ee168ec7c8f2a16c56fe8e1cfbc514974cbbb7e434051b455635f1bd1462f5c",
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation.tflite?generation=1678775104900954"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation.tflite?generation=1683332563830600"],
|
||||||
)
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user