Internal change.
PiperOrigin-RevId: 527374728
This commit is contained in:
parent
c44cc30ece
commit
a45d1f5e90
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/flags/flag.h"
|
#include "absl/flags/flag.h"
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
|
@ -24,12 +25,15 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/deps/file_path.h"
|
#include "mediapipe/framework/deps/file_path.h"
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
|
#include "mediapipe/framework/formats/image_format.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/image_frame.h"
|
||||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
#include "mediapipe/framework/packet.h"
|
#include "mediapipe/framework/packet.h"
|
||||||
#include "mediapipe/framework/port/file_helpers.h"
|
#include "mediapipe/framework/port/file_helpers.h"
|
||||||
#include "mediapipe/framework/port/gmock.h"
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/tool/test_util.h"
|
||||||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
|
@ -76,8 +80,11 @@ constexpr char kNormRectTag[] = "NORM_RECT";
|
||||||
constexpr char kNormRectName[] = "norm_rect";
|
constexpr char kNormRectName[] = "norm_rect";
|
||||||
constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS";
|
constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS";
|
||||||
constexpr char kNormLandmarksName[] = "norm_landmarks";
|
constexpr char kNormLandmarksName[] = "norm_landmarks";
|
||||||
|
constexpr char kSegmentationMaskTag[] = "SEGMENTATION_MASK";
|
||||||
|
constexpr char kSegmentationMaskName[] = "segmentation_mask";
|
||||||
|
|
||||||
constexpr float kLiteModelFractionDiff = 0.05; // percentage
|
constexpr float kLiteModelFractionDiff = 0.05; // percentage
|
||||||
|
constexpr float kGoldenMaskSimilarity = 1.0;
|
||||||
|
|
||||||
template <typename ProtoT>
|
template <typename ProtoT>
|
||||||
ProtoT GetExpectedProto(absl::string_view filename) {
|
ProtoT GetExpectedProto(absl::string_view filename) {
|
||||||
|
@ -125,6 +132,9 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreatePoseLandmarkerGraphTaskRunner(
|
||||||
pose_landmarker.Out(kNormLandmarksTag).SetName(kNormLandmarksName) >>
|
pose_landmarker.Out(kNormLandmarksTag).SetName(kNormLandmarksName) >>
|
||||||
graph[Output<std::vector<NormalizedLandmarkList>>(kNormLandmarksTag)];
|
graph[Output<std::vector<NormalizedLandmarkList>>(kNormLandmarksTag)];
|
||||||
|
|
||||||
|
pose_landmarker.Out(kSegmentationMaskTag).SetName(kSegmentationMaskName) >>
|
||||||
|
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
|
||||||
|
|
||||||
return TaskRunner::Create(
|
return TaskRunner::Create(
|
||||||
graph.GetConfig(),
|
graph.GetConfig(),
|
||||||
absl::make_unique<tasks::core::MediaPipeBuiltinOpResolver>());
|
absl::make_unique<tasks::core::MediaPipeBuiltinOpResolver>());
|
||||||
|
@ -145,6 +155,21 @@ NormalizedRect MakeNormRect(float x_center, float y_center, float width,
|
||||||
class PoseLandmarkerGraphTest
|
class PoseLandmarkerGraphTest
|
||||||
: public testing::TestWithParam<PoseLandmarkerGraphTestParams> {};
|
: public testing::TestWithParam<PoseLandmarkerGraphTestParams> {};
|
||||||
|
|
||||||
|
// Convert pixels from float range [0,1] to uint8 range [0,255].
|
||||||
|
ImageFrame CreateUint8ImageFrame(const Image& image) {
|
||||||
|
auto* image_frame_ptr = image.GetImageFrameSharedPtr().get();
|
||||||
|
ImageFrame output_image_frame(ImageFormat::GRAY8, image_frame_ptr->Width(),
|
||||||
|
image_frame_ptr->Height(), 1);
|
||||||
|
float* pixelData =
|
||||||
|
reinterpret_cast<float*>(image_frame_ptr->MutablePixelData());
|
||||||
|
uint8_t* uint8PixelData = output_image_frame.MutablePixelData();
|
||||||
|
const int total_pixels = image_frame_ptr->Width() * image_frame_ptr->Height();
|
||||||
|
for (int i = 0; i < total_pixels; ++i) {
|
||||||
|
uint8PixelData[i] = static_cast<uint8_t>(pixelData[i] * 255.0f);
|
||||||
|
}
|
||||||
|
return output_image_frame;
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(PoseLandmarkerGraphTest, Succeeds) {
|
TEST_P(PoseLandmarkerGraphTest, Succeeds) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
|
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
|
||||||
|
@ -167,6 +192,53 @@ TEST_P(PoseLandmarkerGraphTest, Succeeds) {
|
||||||
GetParam().landmarks_diff_threshold),
|
GetParam().landmarks_diff_threshold),
|
||||||
*GetParam().expected_landmarks_list));
|
*GetParam().expected_landmarks_list));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const std::vector<Image>& segmentation_masks =
|
||||||
|
(*output_packets)[kSegmentationMaskName].Get<std::vector<Image>>();
|
||||||
|
|
||||||
|
EXPECT_EQ(segmentation_masks.size(), 1);
|
||||||
|
|
||||||
|
const Image& segmentation_mask = segmentation_masks[0];
|
||||||
|
const ImageFrame segmentation_mask_image_frame =
|
||||||
|
CreateUint8ImageFrame(segmentation_mask);
|
||||||
|
|
||||||
|
auto expected_image_frame = LoadTestPng(
|
||||||
|
JoinPath("./", kTestDataDirectory, "pose_segmentation_mask_golden.png"),
|
||||||
|
ImageFormat::GRAY8);
|
||||||
|
|
||||||
|
ASSERT_EQ(segmentation_mask_image_frame.Width(),
|
||||||
|
expected_image_frame->Width());
|
||||||
|
ASSERT_EQ(segmentation_mask_image_frame.Height(),
|
||||||
|
expected_image_frame->Height());
|
||||||
|
ASSERT_EQ(segmentation_mask_image_frame.Format(),
|
||||||
|
expected_image_frame->Format());
|
||||||
|
ASSERT_EQ(segmentation_mask_image_frame.NumberOfChannels(),
|
||||||
|
expected_image_frame->NumberOfChannels());
|
||||||
|
ASSERT_EQ(segmentation_mask_image_frame.ByteDepth(),
|
||||||
|
expected_image_frame->ByteDepth());
|
||||||
|
ASSERT_EQ(segmentation_mask_image_frame.NumberOfChannels(), 1);
|
||||||
|
ASSERT_EQ(segmentation_mask_image_frame.ByteDepth(), 1);
|
||||||
|
int consistent_pixels = 0;
|
||||||
|
int num_pixels = segmentation_mask_image_frame.Width() *
|
||||||
|
segmentation_mask_image_frame.Height();
|
||||||
|
for (int i = 0; i < segmentation_mask_image_frame.Height(); ++i) {
|
||||||
|
for (int j = 0; j < segmentation_mask_image_frame.Width(); ++j) {
|
||||||
|
consistent_pixels +=
|
||||||
|
(segmentation_mask_image_frame
|
||||||
|
.PixelData()[segmentation_mask_image_frame.WidthStep() * i +
|
||||||
|
j] ==
|
||||||
|
expected_image_frame
|
||||||
|
->PixelData()[expected_image_frame->WidthStep() * i + j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_GE(static_cast<float>(consistent_pixels) / num_pixels,
|
||||||
|
kGoldenMaskSimilarity);
|
||||||
|
|
||||||
|
// For visual comparison of segmentation mask output.
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto output_path,
|
||||||
|
SavePngTestOutput(segmentation_mask_image_frame,
|
||||||
|
"segmentation_mask_output"));
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
|
2
mediapipe/tasks/testdata/vision/BUILD
vendored
2
mediapipe/tasks/testdata/vision/BUILD
vendored
|
@ -80,6 +80,7 @@ mediapipe_files(srcs = [
|
||||||
"pose_detection.tflite",
|
"pose_detection.tflite",
|
||||||
"pose_landmark_lite.tflite",
|
"pose_landmark_lite.tflite",
|
||||||
"pose_landmarker.task",
|
"pose_landmarker.task",
|
||||||
|
"pose_segmentation_mask_golden.png",
|
||||||
"right_hands.jpg",
|
"right_hands.jpg",
|
||||||
"right_hands_rotated.jpg",
|
"right_hands_rotated.jpg",
|
||||||
"segmentation_golden_rotation0.png",
|
"segmentation_golden_rotation0.png",
|
||||||
|
@ -143,6 +144,7 @@ filegroup(
|
||||||
"portrait_selfie_segmentation_expected_confidence_mask.jpg",
|
"portrait_selfie_segmentation_expected_confidence_mask.jpg",
|
||||||
"portrait_selfie_segmentation_landscape_expected_category_mask.jpg",
|
"portrait_selfie_segmentation_landscape_expected_category_mask.jpg",
|
||||||
"pose.jpg",
|
"pose.jpg",
|
||||||
|
"pose_segmentation_mask_golden.png",
|
||||||
"right_hands.jpg",
|
"right_hands.jpg",
|
||||||
"right_hands_rotated.jpg",
|
"right_hands_rotated.jpg",
|
||||||
"segmentation_golden_rotation0.png",
|
"segmentation_golden_rotation0.png",
|
||||||
|
|
6
third_party/external_files.bzl
vendored
6
third_party/external_files.bzl
vendored
|
@ -1018,6 +1018,12 @@ def external_files():
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmarks.pbtxt?generation=1681425322701589"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmarks.pbtxt?generation=1681425322701589"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
http_file(
|
||||||
|
name = "com_google_mediapipe_pose_segmentation_mask_golden_png",
|
||||||
|
sha256 = "62ee418e18f317327572da5fcc988af703eb31e6f0b9e0bf3d55e6f4797d6953",
|
||||||
|
urls = ["https://storage.googleapis.com/mediapipe-assets/pose_segmentation_mask_golden.png?generation=1682541414235372"],
|
||||||
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_ptm_512_hdt_ptm_woid_tflite",
|
name = "com_google_mediapipe_ptm_512_hdt_ptm_woid_tflite",
|
||||||
sha256 = "2baa1c9783d03dd26f91e3c49efbcab11dd1361ff80e40e7209e81f84f281b6a",
|
sha256 = "2baa1c9783d03dd26f91e3c49efbcab11dd1361ff80e40e7209e81f84f281b6a",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user