From a45d1f5e90a6a02381e5d2bfd9c53fb000c76dba Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 26 Apr 2023 14:25:35 -0700 Subject: [PATCH] Internal change. PiperOrigin-RevId: 527374728 --- .../pose_landmarker_graph_test.cc | 72 +++++++++++++++++++ mediapipe/tasks/testdata/vision/BUILD | 2 + third_party/external_files.bzl | 6 ++ 3 files changed, 80 insertions(+) diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph_test.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph_test.cc index 92bc80c5a..6f7488a5f 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph_test.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include "absl/flags/flag.h" #include "absl/status/statusor.h" @@ -24,12 +25,15 @@ limitations under the License. #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_format.pb.h" +#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/tool/test_util.h" #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" @@ -76,8 +80,11 @@ constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kNormRectName[] = "norm_rect"; constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS"; constexpr char kNormLandmarksName[] = "norm_landmarks"; +constexpr char kSegmentationMaskTag[] = "SEGMENTATION_MASK"; +constexpr char kSegmentationMaskName[] = "segmentation_mask"; constexpr float kLiteModelFractionDiff = 0.05; // percentage +constexpr float kGoldenMaskSimilarity = 1.0; template ProtoT GetExpectedProto(absl::string_view filename) { @@ -125,6 +132,9 @@ absl::StatusOr> CreatePoseLandmarkerGraphTaskRunner( pose_landmarker.Out(kNormLandmarksTag).SetName(kNormLandmarksName) >> graph[Output>(kNormLandmarksTag)]; + pose_landmarker.Out(kSegmentationMaskTag).SetName(kSegmentationMaskName) >> + graph[Output>(kSegmentationMaskTag)]; + return TaskRunner::Create( graph.GetConfig(), absl::make_unique()); @@ -145,6 +155,21 @@ NormalizedRect MakeNormRect(float x_center, float y_center, float width, class PoseLandmarkerGraphTest : public testing::TestWithParam {}; +// Convert pixels from float range [0,1] to uint8 range [0,255]. +ImageFrame CreateUint8ImageFrame(const Image& image) { + auto* image_frame_ptr = image.GetImageFrameSharedPtr().get(); + ImageFrame output_image_frame(ImageFormat::GRAY8, image_frame_ptr->Width(), + image_frame_ptr->Height(), 1); + float* pixelData = + reinterpret_cast(image_frame_ptr->MutablePixelData()); + uint8_t* uint8PixelData = output_image_frame.MutablePixelData(); + const int total_pixels = image_frame_ptr->Width() * image_frame_ptr->Height(); + for (int i = 0; i < total_pixels; ++i) { + uint8PixelData[i] = static_cast(pixelData[i] * 255.0f); + } + return output_image_frame; +} + TEST_P(PoseLandmarkerGraphTest, Succeeds) { MP_ASSERT_OK_AND_ASSIGN( Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, @@ -167,6 +192,53 @@ TEST_P(PoseLandmarkerGraphTest, Succeeds) { GetParam().landmarks_diff_threshold), *GetParam().expected_landmarks_list)); } + + const std::vector& segmentation_masks = + (*output_packets)[kSegmentationMaskName].Get>(); + + EXPECT_EQ(segmentation_masks.size(), 1); + + const Image& segmentation_mask = segmentation_masks[0]; + const ImageFrame segmentation_mask_image_frame = + CreateUint8ImageFrame(segmentation_mask); + + auto expected_image_frame = LoadTestPng( + JoinPath("./", kTestDataDirectory, "pose_segmentation_mask_golden.png"), + ImageFormat::GRAY8); + + ASSERT_EQ(segmentation_mask_image_frame.Width(), + expected_image_frame->Width()); + ASSERT_EQ(segmentation_mask_image_frame.Height(), + expected_image_frame->Height()); + ASSERT_EQ(segmentation_mask_image_frame.Format(), + expected_image_frame->Format()); + ASSERT_EQ(segmentation_mask_image_frame.NumberOfChannels(), + expected_image_frame->NumberOfChannels()); + ASSERT_EQ(segmentation_mask_image_frame.ByteDepth(), + expected_image_frame->ByteDepth()); + ASSERT_EQ(segmentation_mask_image_frame.NumberOfChannels(), 1); + ASSERT_EQ(segmentation_mask_image_frame.ByteDepth(), 1); + int consistent_pixels = 0; + int num_pixels = segmentation_mask_image_frame.Width() * + segmentation_mask_image_frame.Height(); + for (int i = 0; i < segmentation_mask_image_frame.Height(); ++i) { + for (int j = 0; j < segmentation_mask_image_frame.Width(); ++j) { + consistent_pixels += + (segmentation_mask_image_frame + .PixelData()[segmentation_mask_image_frame.WidthStep() * i + + j] == + expected_image_frame + ->PixelData()[expected_image_frame->WidthStep() * i + j]); + } + } + + EXPECT_GE(static_cast(consistent_pixels) / num_pixels, + kGoldenMaskSimilarity); + + // For visual comparison of segmentation mask output. + MP_ASSERT_OK_AND_ASSIGN(auto output_path, + SavePngTestOutput(segmentation_mask_image_frame, + "segmentation_mask_output")); } INSTANTIATE_TEST_SUITE_P( diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index a8704123c..632e8aa4e 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -80,6 +80,7 @@ mediapipe_files(srcs = [ "pose_detection.tflite", "pose_landmark_lite.tflite", "pose_landmarker.task", + "pose_segmentation_mask_golden.png", "right_hands.jpg", "right_hands_rotated.jpg", "segmentation_golden_rotation0.png", @@ -143,6 +144,7 @@ filegroup( "portrait_selfie_segmentation_expected_confidence_mask.jpg", "portrait_selfie_segmentation_landscape_expected_category_mask.jpg", "pose.jpg", + "pose_segmentation_mask_golden.png", "right_hands.jpg", "right_hands_rotated.jpg", "segmentation_golden_rotation0.png", diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 6a1582cc7..e8ff7819c 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -1018,6 +1018,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmarks.pbtxt?generation=1681425322701589"], ) + http_file( + name = "com_google_mediapipe_pose_segmentation_mask_golden_png", + sha256 = "62ee418e18f317327572da5fcc988af703eb31e6f0b9e0bf3d55e6f4797d6953", + urls = ["https://storage.googleapis.com/mediapipe-assets/pose_segmentation_mask_golden.png?generation=1682541414235372"], + ) + http_file( name = "com_google_mediapipe_ptm_512_hdt_ptm_woid_tflite", sha256 = "2baa1c9783d03dd26f91e3c49efbcab11dd1361ff80e40e7209e81f84f281b6a",