Internal change.

PiperOrigin-RevId: 527374728
This commit is contained in:
MediaPipe Team 2023-04-26 14:25:35 -07:00 committed by Copybara-Service
parent c44cc30ece
commit a45d1f5e90
3 changed files with 80 additions and 0 deletions

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <optional>
#include <vector>
#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
@ -24,12 +25,15 @@ limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_format.pb.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
@ -76,8 +80,11 @@ constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kNormRectName[] = "norm_rect";
constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS";
constexpr char kNormLandmarksName[] = "norm_landmarks";
constexpr char kSegmentationMaskTag[] = "SEGMENTATION_MASK";
constexpr char kSegmentationMaskName[] = "segmentation_mask";
constexpr float kLiteModelFractionDiff = 0.05; // percentage
constexpr float kGoldenMaskSimilarity = 1.0;
template <typename ProtoT>
ProtoT GetExpectedProto(absl::string_view filename) {
@ -125,6 +132,9 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreatePoseLandmarkerGraphTaskRunner(
pose_landmarker.Out(kNormLandmarksTag).SetName(kNormLandmarksName) >>
graph[Output<std::vector<NormalizedLandmarkList>>(kNormLandmarksTag)];
pose_landmarker.Out(kSegmentationMaskTag).SetName(kSegmentationMaskName) >>
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
return TaskRunner::Create(
graph.GetConfig(),
absl::make_unique<tasks::core::MediaPipeBuiltinOpResolver>());
@ -145,6 +155,21 @@ NormalizedRect MakeNormRect(float x_center, float y_center, float width,
class PoseLandmarkerGraphTest
: public testing::TestWithParam<PoseLandmarkerGraphTestParams> {};
// Convert pixels from float range [0,1] to uint8 range [0,255].
ImageFrame CreateUint8ImageFrame(const Image& image) {
auto* image_frame_ptr = image.GetImageFrameSharedPtr().get();
ImageFrame output_image_frame(ImageFormat::GRAY8, image_frame_ptr->Width(),
image_frame_ptr->Height(), 1);
float* pixelData =
reinterpret_cast<float*>(image_frame_ptr->MutablePixelData());
uint8_t* uint8PixelData = output_image_frame.MutablePixelData();
const int total_pixels = image_frame_ptr->Width() * image_frame_ptr->Height();
for (int i = 0; i < total_pixels; ++i) {
uint8PixelData[i] = static_cast<uint8_t>(pixelData[i] * 255.0f);
}
return output_image_frame;
}
TEST_P(PoseLandmarkerGraphTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
@ -167,6 +192,53 @@ TEST_P(PoseLandmarkerGraphTest, Succeeds) {
GetParam().landmarks_diff_threshold),
*GetParam().expected_landmarks_list));
}
const std::vector<Image>& segmentation_masks =
(*output_packets)[kSegmentationMaskName].Get<std::vector<Image>>();
EXPECT_EQ(segmentation_masks.size(), 1);
const Image& segmentation_mask = segmentation_masks[0];
const ImageFrame segmentation_mask_image_frame =
CreateUint8ImageFrame(segmentation_mask);
auto expected_image_frame = LoadTestPng(
JoinPath("./", kTestDataDirectory, "pose_segmentation_mask_golden.png"),
ImageFormat::GRAY8);
ASSERT_EQ(segmentation_mask_image_frame.Width(),
expected_image_frame->Width());
ASSERT_EQ(segmentation_mask_image_frame.Height(),
expected_image_frame->Height());
ASSERT_EQ(segmentation_mask_image_frame.Format(),
expected_image_frame->Format());
ASSERT_EQ(segmentation_mask_image_frame.NumberOfChannels(),
expected_image_frame->NumberOfChannels());
ASSERT_EQ(segmentation_mask_image_frame.ByteDepth(),
expected_image_frame->ByteDepth());
ASSERT_EQ(segmentation_mask_image_frame.NumberOfChannels(), 1);
ASSERT_EQ(segmentation_mask_image_frame.ByteDepth(), 1);
int consistent_pixels = 0;
int num_pixels = segmentation_mask_image_frame.Width() *
segmentation_mask_image_frame.Height();
for (int i = 0; i < segmentation_mask_image_frame.Height(); ++i) {
for (int j = 0; j < segmentation_mask_image_frame.Width(); ++j) {
consistent_pixels +=
(segmentation_mask_image_frame
.PixelData()[segmentation_mask_image_frame.WidthStep() * i +
j] ==
expected_image_frame
->PixelData()[expected_image_frame->WidthStep() * i + j]);
}
}
EXPECT_GE(static_cast<float>(consistent_pixels) / num_pixels,
kGoldenMaskSimilarity);
// For visual comparison of segmentation mask output.
MP_ASSERT_OK_AND_ASSIGN(auto output_path,
SavePngTestOutput(segmentation_mask_image_frame,
"segmentation_mask_output"));
}
INSTANTIATE_TEST_SUITE_P(

View File

@ -80,6 +80,7 @@ mediapipe_files(srcs = [
"pose_detection.tflite",
"pose_landmark_lite.tflite",
"pose_landmarker.task",
"pose_segmentation_mask_golden.png",
"right_hands.jpg",
"right_hands_rotated.jpg",
"segmentation_golden_rotation0.png",
@ -143,6 +144,7 @@ filegroup(
"portrait_selfie_segmentation_expected_confidence_mask.jpg",
"portrait_selfie_segmentation_landscape_expected_category_mask.jpg",
"pose.jpg",
"pose_segmentation_mask_golden.png",
"right_hands.jpg",
"right_hands_rotated.jpg",
"segmentation_golden_rotation0.png",

View File

@ -1018,6 +1018,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmarks.pbtxt?generation=1681425322701589"],
)
http_file(
name = "com_google_mediapipe_pose_segmentation_mask_golden_png",
sha256 = "62ee418e18f317327572da5fcc988af703eb31e6f0b9e0bf3d55e6f4797d6953",
urls = ["https://storage.googleapis.com/mediapipe-assets/pose_segmentation_mask_golden.png?generation=1682541414235372"],
)
http_file(
name = "com_google_mediapipe_ptm_512_hdt_ptm_woid_tflite",
sha256 = "2baa1c9783d03dd26f91e3c49efbcab11dd1361ff80e40e7209e81f84f281b6a",