mediapipe/mediapipe/calculators/util/landmark_projection_calculator_test.cc
MediaPipe Team e6c19885c6 Project import generated by Copybara.
GitOrigin-RevId: bb059a0721c92e8154d33ce8057b3915a25b3d7d
2021-12-13 15:56:02 -08:00

245 lines
8.2 KiB
C++

#include <array>
#include <vector>
#include "absl/memory/memory.h"
#include "mediapipe/calculators/tensor/image_to_tensor_utils.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS";
absl::StatusOr<mediapipe::NormalizedLandmarkList> RunCalculator(
mediapipe::NormalizedLandmarkList input, mediapipe::NormalizedRect rect) {
mediapipe::CalculatorRunner runner(
ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(R"pb(
calculator: "LandmarkProjectionCalculator"
input_stream: "NORM_LANDMARKS:landmarks"
input_stream: "NORM_RECT:rect"
output_stream: "NORM_LANDMARKS:projected_landmarks"
)pb"));
runner.MutableInputs()
->Tag(kNormLandmarksTag)
.packets.push_back(
MakePacket<mediapipe::NormalizedLandmarkList>(std::move(input))
.At(Timestamp(1)));
runner.MutableInputs()
->Tag(kNormRectTag)
.packets.push_back(MakePacket<mediapipe::NormalizedRect>(std::move(rect))
.At(Timestamp(1)));
MP_RETURN_IF_ERROR(runner.Run());
const auto& output_packets = runner.Outputs().Tag(kNormLandmarksTag).packets;
RET_CHECK_EQ(output_packets.size(), 1);
return output_packets[0].Get<mediapipe::NormalizedLandmarkList>();
}
TEST(LandmarkProjectionCalculatorTest, ProjectingWithDefaultRect) {
mediapipe::NormalizedLandmarkList landmarks =
ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 10, y: 20, z: -0.5 }
)pb");
mediapipe::NormalizedRect rect =
ParseTextProtoOrDie<mediapipe::NormalizedRect>(
R"pb(
x_center: 0.5,
y_center: 0.5,
width: 1.0,
height: 1.0,
rotation: 0.0
)pb");
auto status_or_result = RunCalculator(std::move(landmarks), std::move(rect));
MP_ASSERT_OK(status_or_result);
EXPECT_THAT(
status_or_result.value(),
EqualsProto(ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 10, y: 20, z: -0.5 }
)pb")));
}
mediapipe::NormalizedRect GetCroppedRect() {
return ParseTextProtoOrDie<mediapipe::NormalizedRect>(
R"pb(
x_center: 0.5, y_center: 0.5, width: 0.5, height: 2, rotation: 0.0
)pb");
}
mediapipe::NormalizedLandmarkList GetCroppedRectTestInput() {
return ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 1.0, y: 1.0, z: -0.5 }
)pb");
}
mediapipe::NormalizedLandmarkList GetCroppedRectTestExpectedResult() {
return ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 0.75, y: 1.5, z: -0.25 }
)pb");
}
TEST(LandmarkProjectionCalculatorTest, ProjectingWithCroppedRect) {
auto status_or_result =
RunCalculator(GetCroppedRectTestInput(), GetCroppedRect());
MP_ASSERT_OK(status_or_result);
EXPECT_THAT(status_or_result.value(),
EqualsProto(GetCroppedRectTestExpectedResult()));
}
absl::StatusOr<mediapipe::NormalizedLandmarkList> RunCalculator(
mediapipe::NormalizedLandmarkList input, std::array<float, 16> matrix) {
mediapipe::CalculatorRunner runner(
ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(R"pb(
calculator: "LandmarkProjectionCalculator"
input_stream: "NORM_LANDMARKS:landmarks"
input_stream: "PROJECTION_MATRIX:matrix"
output_stream: "NORM_LANDMARKS:projected_landmarks"
)pb"));
runner.MutableInputs()
->Tag(kNormLandmarksTag)
.packets.push_back(
MakePacket<mediapipe::NormalizedLandmarkList>(std::move(input))
.At(Timestamp(1)));
runner.MutableInputs()
->Tag(kProjectionMatrixTag)
.packets.push_back(MakePacket<std::array<float, 16>>(std::move(matrix))
.At(Timestamp(1)));
MP_RETURN_IF_ERROR(runner.Run());
const auto& output_packets = runner.Outputs().Tag(kNormLandmarksTag).packets;
RET_CHECK_EQ(output_packets.size(), 1);
return output_packets[0].Get<mediapipe::NormalizedLandmarkList>();
}
TEST(LandmarkProjectionCalculatorTest, ProjectingWithIdentityMatrix) {
mediapipe::NormalizedLandmarkList landmarks =
ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 10, y: 20, z: -0.5 }
)pb");
// clang-format off
std::array<float, 16> matrix = {
1.0f, 0.0f, 0.0f, 0.0f,
0.0f, 1.0f, 0.0f, 0.0f,
0.0f, 0.0f, 1.0f, 0.0f,
0.0f, 0.0f, 0.0f, 1.0f,
};
// clang-format on
auto status_or_result =
RunCalculator(std::move(landmarks), std::move(matrix));
MP_ASSERT_OK(status_or_result);
EXPECT_THAT(
status_or_result.value(),
EqualsProto(ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 10, y: 20, z: -0.5 }
)pb")));
}
TEST(LandmarkProjectionCalculatorTest, ProjectingWithCroppedRectMatrix) {
constexpr int kRectWidth = 1280;
constexpr int kRectHeight = 720;
auto roi = GetRoi(kRectWidth, kRectHeight, GetCroppedRect());
std::array<float, 16> matrix;
GetRotatedSubRectToRectTransformMatrix(roi, kRectWidth, kRectHeight,
/*flip_horizontaly=*/false, &matrix);
auto status_or_result = RunCalculator(GetCroppedRectTestInput(), matrix);
MP_ASSERT_OK(status_or_result);
EXPECT_THAT(status_or_result.value(),
EqualsProto(GetCroppedRectTestExpectedResult()));
}
TEST(LandmarkProjectionCalculatorTest, ProjectingWithScaleMatrix) {
mediapipe::NormalizedLandmarkList landmarks =
ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 10, y: 20, z: -0.5 }
landmark { x: 5, y: 6, z: 7 }
)pb");
// clang-format off
std::array<float, 16> matrix = {
10.0f, 0.0f, 0.0f, 0.0f,
0.0f, 100.0f, 0.0f, 0.0f,
0.0f, 0.0f, 1.0f, 0.0f,
0.0f, 0.0f, 0.0f, 1.0f,
};
// clang-format on
auto status_or_result =
RunCalculator(std::move(landmarks), std::move(matrix));
MP_ASSERT_OK(status_or_result);
EXPECT_THAT(
status_or_result.value(),
EqualsProto(ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 100, y: 2000, z: -5 }
landmark { x: 50, y: 600, z: 70 }
)pb")));
}
TEST(LandmarkProjectionCalculatorTest, ProjectingWithTranslateMatrix) {
mediapipe::NormalizedLandmarkList landmarks =
ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 10, y: 20, z: -0.5 }
)pb");
// clang-format off
std::array<float, 16> matrix = {
1.0f, 0.0f, 0.0f, 1.0f,
0.0f, 1.0f, 0.0f, 2.0f,
0.0f, 0.0f, 1.0f, 0.0f,
0.0f, 0.0f, 0.0f, 1.0f,
};
// clang-format on
auto status_or_result =
RunCalculator(std::move(landmarks), std::move(matrix));
MP_ASSERT_OK(status_or_result);
EXPECT_THAT(
status_or_result.value(),
EqualsProto(ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 11, y: 22, z: -0.5 }
)pb")));
}
TEST(LandmarkProjectionCalculatorTest, ProjectingWithRotationMatrix) {
mediapipe::NormalizedLandmarkList landmarks =
ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 4, y: 0, z: -0.5 }
)pb");
// clang-format off
// 90 degrees rotation matrix
std::array<float, 16> matrix = {
0.0f, -1.0f, 0.0f, 0.0f,
1.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 1.0f, 0.0f,
0.0f, 0.0f, 0.0f, 1.0f,
};
// clang-format on
auto status_or_result =
RunCalculator(std::move(landmarks), std::move(matrix));
MP_ASSERT_OK(status_or_result);
EXPECT_THAT(
status_or_result.value(),
EqualsProto(ParseTextProtoOrDie<mediapipe::NormalizedLandmarkList>(R"pb(
landmark { x: 0, y: 4, z: -0.5 }
)pb")));
}
} // namespace
} // namespace mediapipe