Add quality test for InteractiveSegmenter

PiperOrigin-RevId: 516968294
This commit is contained in:
MediaPipe Team 2023-03-15 17:00:50 -07:00 committed by Copybara-Service
parent 61bcddc671
commit 8f1ce5fef6
3 changed files with 81 additions and 26 deletions

View File

@ -15,8 +15,8 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h" #include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h"
#include <cstdint>
#include <memory> #include <memory>
#include <string>
#include "absl/flags/flag.h" #include "absl/flags/flag.h"
#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/file_path.h"
@ -28,6 +28,7 @@ limitations under the License.
#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
#include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
@ -47,6 +48,7 @@ namespace {
using ::mediapipe::Image; using ::mediapipe::Image;
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::NormalizedKeypoint;
using ::mediapipe::tasks::components::containers::RectF; using ::mediapipe::tasks::components::containers::RectF;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr; using ::testing::HasSubstr;
@ -55,14 +57,16 @@ using ::testing::Optional;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kPtmModel[] = "ptm_512_hdt_ptm_woid.tflite"; constexpr char kPtmModel[] = "ptm_512_hdt_ptm_woid.tflite";
constexpr char kCatsAndDogsJpg[] = "cats_and_dogs.jpg"; constexpr char kCatsAndDogsJpg[] = "cats_and_dogs.jpg";
// Golden masks for the two dogs in cats_and_dogs.jpg.
constexpr char kCatsAndDogsMaskDog1[] = "cats_and_dogs_mask_dog1.png";
constexpr char kCatsAndDogsMaskDog2[] = "cats_and_dogs_mask_dog2.png";
constexpr float kGoldenMaskSimilarity = 0.98; constexpr float kGoldenMaskSimilarity = 0.97;
// Magnification factor used when creating the golden category masks to make // Magnification factor used when creating the golden category masks to make
// them more human-friendly. Each pixel in the golden masks has its value // them more human-friendly. Since interactive segmenter has only 2 categories,
// multiplied by this factor, i.e. a value of 10 means class index 1, a value of // the golden mask uses 0 or 255 for each pixel.
// 20 means class index 2, etc. constexpr int kGoldenMaskMagnificationFactor = 255;
constexpr int kGoldenMaskMagnificationFactor = 10;
// Intentionally converting output into CV_8UC1 and then again into CV_32FC1 // Intentionally converting output into CV_8UC1 and then again into CV_32FC1
// as expected outputs are stored in CV_8UC1, so this conversion allows to do // as expected outputs are stored in CV_8UC1, so this conversion allows to do
@ -155,16 +159,25 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
MediaPipeTasksStatus::kRunnerInitializationError)))); MediaPipeTasksStatus::kRunnerInitializationError))));
} }
class ImageModeTest : public tflite_shims::testing::Test {}; struct InteractiveSegmenterTestParams {
// One parameterized case for the SucceedSegmentationWithRoi tests.
// Name given to the generated test for this case.
std::string test_name;
// Format assigned to the region of interest passed to Segment().
RegionOfInterest::Format format;
// Keypoint assigned to the region of interest passed to Segment().
NormalizedKeypoint roi;
// File name of the golden mask, resolved under kTestDataDirectory.
std::string golden_mask_file;
// Threshold handed to SimilarToUint8Mask / SimilarToFloatMask when comparing
// the segmenter output against the golden mask.
float similarity_threshold;
};
TEST_F(ImageModeTest, SucceedsWithCategoryMask) { using SucceedSegmentationWithRoi =
::testing::TestWithParam<InteractiveSegmenterTestParams>;
TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
const InteractiveSegmenterTestParams& params = GetParam();
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
Image image, Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
RegionOfInterest interaction_roi; RegionOfInterest interaction_roi;
interaction_roi.format = RegionOfInterest::KEYPOINT; interaction_roi.format = params.format;
interaction_roi.keypoint = interaction_roi.keypoint = params.roi;
components::containers::NormalizedKeypoint{0.25, 0.9};
auto options = std::make_unique<InteractiveSegmenterOptions>(); auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel); JoinPath("./", kTestDataDirectory, kPtmModel);
@ -175,16 +188,26 @@ TEST_F(ImageModeTest, SucceedsWithCategoryMask) {
MP_ASSERT_OK_AND_ASSIGN(auto category_masks, MP_ASSERT_OK_AND_ASSIGN(auto category_masks,
segmenter->Segment(image, interaction_roi)); segmenter->Segment(image, interaction_roi));
EXPECT_EQ(category_masks.size(), 1); EXPECT_EQ(category_masks.size(), 1);
cv::Mat actual_mask = mediapipe::formats::MatView(
category_masks[0].GetImageFrameSharedPtr().get());
cv::Mat expected_mask =
cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file),
cv::IMREAD_GRAYSCALE);
EXPECT_THAT(actual_mask,
SimilarToUint8Mask(expected_mask, params.similarity_threshold,
kGoldenMaskMagnificationFactor));
} }
TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
const auto& params = GetParam();
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
Image image, Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
RegionOfInterest interaction_roi; RegionOfInterest interaction_roi;
interaction_roi.format = RegionOfInterest::KEYPOINT; interaction_roi.format = params.format;
interaction_roi.keypoint = interaction_roi.keypoint = params.roi;
components::containers::NormalizedKeypoint{0.25, 0.9};
auto options = std::make_unique<InteractiveSegmenterOptions>(); auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel); JoinPath("./", kTestDataDirectory, kPtmModel);
@ -196,8 +219,32 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks,
segmenter->Segment(image, interaction_roi)); segmenter->Segment(image, interaction_roi));
EXPECT_EQ(confidence_masks.size(), 2); EXPECT_EQ(confidence_masks.size(), 2);
cv::Mat expected_mask =
cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file),
cv::IMREAD_GRAYSCALE);
cv::Mat expected_mask_float;
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
cv::Mat actual_mask = mediapipe::formats::MatView(
confidence_masks[1].GetImageFrameSharedPtr().get());
EXPECT_THAT(actual_mask, SimilarToFloatMask(expected_mask_float,
params.similarity_threshold));
} }
// Instantiates the SucceedSegmentationWithRoi tests once per case listed
// below; each case pairs a keypoint with the golden mask file it is compared
// against.
INSTANTIATE_TEST_SUITE_P(
SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi,
::testing::ValuesIn<InteractiveSegmenterTestParams>(
// This case uses its own threshold (0.84f) rather than
// kGoldenMaskSimilarity.
{{"PointToDog1", RegionOfInterest::KEYPOINT,
NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f},
{"PointToDog2", RegionOfInterest::KEYPOINT,
NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2,
kGoldenMaskSimilarity}}),
// Name each generated test after the case's test_name field.
[](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
info) { return info.param.test_name; });
class ImageModeTest : public tflite_shims::testing::Test {};
// TODO: fix this unit test after image segmenter handled post // TODO: fix this unit test after image segmenter handled post
// processing correctly with rotated image. // processing correctly with rotated image.
TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
@ -206,8 +253,7 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
RegionOfInterest interaction_roi; RegionOfInterest interaction_roi;
interaction_roi.format = RegionOfInterest::KEYPOINT; interaction_roi.format = RegionOfInterest::KEYPOINT;
interaction_roi.keypoint = interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66};
components::containers::NormalizedKeypoint{0.25, 0.9};
auto options = std::make_unique<InteractiveSegmenterOptions>(); auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel); JoinPath("./", kTestDataDirectory, kPtmModel);
@ -230,8 +276,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
RegionOfInterest interaction_roi; RegionOfInterest interaction_roi;
interaction_roi.format = RegionOfInterest::KEYPOINT; interaction_roi.format = RegionOfInterest::KEYPOINT;
interaction_roi.keypoint = interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66};
components::containers::NormalizedKeypoint{0.25, 0.9};
auto options = std::make_unique<InteractiveSegmenterOptions>(); auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel); JoinPath("./", kTestDataDirectory, kPtmModel);

View File

@ -31,6 +31,8 @@ mediapipe_files(srcs = [
"cat_rotated.jpg", "cat_rotated.jpg",
"cat_rotated_mask.jpg", "cat_rotated_mask.jpg",
"cats_and_dogs.jpg", "cats_and_dogs.jpg",
"cats_and_dogs_mask_dog1.png",
"cats_and_dogs_mask_dog2.png",
"cats_and_dogs_no_resizing.jpg", "cats_and_dogs_no_resizing.jpg",
"cats_and_dogs_rotated.jpg", "cats_and_dogs_rotated.jpg",
"coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite", "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite",
@ -116,6 +118,8 @@ filegroup(
"cat_rotated.jpg", "cat_rotated.jpg",
"cat_rotated_mask.jpg", "cat_rotated_mask.jpg",
"cats_and_dogs.jpg", "cats_and_dogs.jpg",
"cats_and_dogs_mask_dog1.png",
"cats_and_dogs_mask_dog2.png",
"cats_and_dogs_no_resizing.jpg", "cats_and_dogs_no_resizing.jpg",
"cats_and_dogs_rotated.jpg", "cats_and_dogs_rotated.jpg",
"fist.jpg", "fist.jpg",

View File

@ -67,13 +67,7 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_BUILD", name = "com_google_mediapipe_BUILD",
sha256 = "d2b2a8346202691d7f831887c84e9642e974f64ed67851d9a58cf15c94b1f6b3", sha256 = "d2b2a8346202691d7f831887c84e9642e974f64ed67851d9a58cf15c94b1f6b3",
urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=16618756636939761678323576393653"], urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=166187566369397616783235763936531678737479599640"],
)
http_file(
name = "com_google_mediapipe_BUILD_orig",
sha256 = "d86b98b82e00dd87cd46bd1429bf5eaa007b500c1a24d9316b73309f2e6c8df8",
urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD.orig?generation=1678737479599640"],
) )
http_file( http_file(
@ -136,6 +130,18 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs.jpg?generation=1661875684064150"], urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs.jpg?generation=1661875684064150"],
) )
http_file(
name = "com_google_mediapipe_cats_and_dogs_mask_dog1_png",
sha256 = "2ab37d56ba1e46e70b3ddbfe35dac51b18b597b76904c68d7d34c7c74c677d4c",
urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs_mask_dog1.png?generation=1678840350058498"],
)
http_file(
name = "com_google_mediapipe_cats_and_dogs_mask_dog2_png",
sha256 = "2010850e2dd7f520fe53b9086d70913b6fb53b178cae15a373e5ee7ffb46824a",
urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs_mask_dog2.png?generation=1678840352961684"],
)
http_file( http_file(
name = "com_google_mediapipe_cats_and_dogs_no_resizing_jpg", name = "com_google_mediapipe_cats_and_dogs_no_resizing_jpg",
sha256 = "9d55933ed66bcdc63cd6509ee2518d7eed75d12db609238387ee4cc50b173e58", sha256 = "9d55933ed66bcdc63cd6509ee2518d7eed75d12db609238387ee4cc50b173e58",