Add FaceToRectCalculator.

PiperOrigin-RevId: 522595752
This commit is contained in:
Jiuqiang Tang 2023-04-07 07:42:45 -07:00 committed by Copybara-Service
parent b0d3595291
commit bca0a92c2e
4 changed files with 423 additions and 0 deletions

View File

@ -1460,3 +1460,45 @@ cc_test(
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )
# Proto target for face_to_rect_calculator.proto, which declares
# FaceToRectCalculatorOptions.
# NOTE(review): the ":face_to_rect_calculator_cc_proto" label used by the
# targets below is presumably generated by this macro - confirm.
mediapipe_proto_library(
name = "face_to_rect_calculator_proto",
srcs = ["face_to_rect_calculator.proto"],
)
# FaceToRectCalculator: builds face_to_rect_calculator.cc on top of
# DetectionsToRectsCalculator.
cc_library(
name = "face_to_rect_calculator",
srcs = ["face_to_rect_calculator.cc"],
deps = [
":detections_to_rects_calculator",
":face_to_rect_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:location_data_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:vector",
"@com_google_absl//absl/status",
],
# The source file only registers the calculator through REGISTER_CALCULATOR
# and exposes no header, so nothing references its symbols directly.
# NOTE(review): alwayslink presumably keeps that registration from being
# dropped at link time - confirm.
alwayslink = 1,
)
# Unit tests for FaceToRectCalculator (face_to_rect_calculator_test.cc), which
# run the calculator through CalculatorRunner.
cc_test(
name = "face_to_rect_calculator_test",
size = "small",
srcs = ["face_to_rect_calculator_test.cc"],
deps = [
":face_to_rect_calculator",
":face_to_rect_calculator_cc_proto",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:location_data_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status_matchers",
"@com_google_absl//absl/status",
],
)

View File

@ -0,0 +1,182 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cmath>
#include "absl/status/status.h"
#include "mediapipe/calculators/util/detections_to_rects_calculator.h"
#include "mediapipe/calculators/util/face_to_rect_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/vector.h"
namespace mediapipe {
// A calculator that converts a face detection proto into a MediaPipe rect.
//
// The detection's relative keypoints must be ordered as: left-eye landmarks,
// right-eye landmarks, nose landmarks, mouth landmarks (sizes are given by
// FaceToRectCalculatorOptions). The resulting rect is a square whose center is
// the eye center shifted towards the mouth by `eye_to_mouth_mix`, and whose
// side is the larger of the scaled eye-to-mouth and eye-to-eye distances.
//
// Example config:
// node {
//   calculator: "FaceToRectCalculator"
//   input_stream: "DETECTION:detection"
//   input_stream: "IMAGE_SIZE:frame_size"
//   output_stream: "NORM_RECT:rect"
//   node_options: {
//     [type.googleapis.com/mediapipe.FaceToRectCalculatorOptions] {
//       eye_landmark_size: 1
//       nose_landmark_size: 2
//       mouth_landmark_size: 2
//       eye_to_mouth_scale: 3.42
//       eye_to_eye_scale: 3.8
//     }
//   }
// }
//
class FaceToRectCalculator : public DetectionsToRectsCalculator {
 public:
  // Loads the options (defaults first, then the node options merged on top),
  // validates the landmark counts and caches the number of keypoints a
  // detection must provide.
  absl::Status Open(CalculatorContext* cc) override {
    cc->SetOffset(TimestampDiff(0));

    // Default values.
    face_options_.set_eye_landmark_size(2);
    face_options_.set_nose_landmark_size(0);
    face_options_.set_mouth_landmark_size(2);
    face_options_.set_eye_to_mouth_mix(0.1f);
    face_options_.set_eye_to_mouth_scale(3.6f);
    face_options_.set_eye_to_eye_scale(4.0f);
    face_options_.MergeFrom(cc->Options<FaceToRectCalculatorOptions>());

    RET_CHECK(face_options_.eye_landmark_size() > 0 &&
              face_options_.mouth_landmark_size() > 0)
        << "Eye landmarks and mouth landmarks cannot be empty.";
    // A negative nose count would shrink total_landmarks_ and shift the mouth
    // offset backwards, defeating the bounds check in GetLandmarks.
    RET_CHECK(face_options_.nose_landmark_size() >= 0)
        << "Nose landmark size cannot be negative.";

    total_landmarks_ = face_options_.eye_landmark_size() * 2 +
                       face_options_.nose_landmark_size() +
                       face_options_.mouth_landmark_size();
    rotate_ = true;
    return absl::OkStatus();
  }

 private:
  // Computes the face rect in pixel coordinates.
  // Fails if the image size is missing/non-positive or if the detection does
  // not carry enough keypoints.
  absl::Status DetectionToRect(const Detection& detection,
                               const DetectionSpec& detection_spec,
                               Rect* rect) override {
    int width = 0;
    int height = 0;
    MP_RETURN_IF_ERROR(GetImageSize(detection_spec, &width, &height));
    return ComputeFaceRect(detection, width, height, rect);
  }

  // Computes the face rect in pixels, then divides by the image dimensions to
  // produce normalized coordinates.
  // Fails under the same conditions as DetectionToRect.
  absl::Status DetectionToNormalizedRect(const Detection& detection,
                                         const DetectionSpec& detection_spec,
                                         NormalizedRect* rect) override {
    int width = 0;
    int height = 0;
    MP_RETURN_IF_ERROR(GetImageSize(detection_spec, &width, &height));

    Rect rect_pix;
    MP_RETURN_IF_ERROR(ComputeFaceRect(detection, width, height, &rect_pix));

    // GetImageSize guarantees both dimensions are positive, so the
    // reciprocals are finite.
    const float width_recip = 1.f / width;
    const float height_recip = 1.f / height;
    rect->set_x_center(rect_pix.x_center() * width_recip);
    rect->set_y_center(rect_pix.y_center() * height_recip);
    rect->set_width(rect_pix.width() * width_recip);
    rect->set_height(rect_pix.height() * height_recip);
    return absl::OkStatus();
  }

  // Derives the rotation from the vectors cached by the most recent
  // ComputeFaceRect call; `detection` and `detection_spec` are not read here.
  // NOTE(review): this relies on the rect conversion having run for the same
  // detection beforehand - confirm against the base class.
  absl::Status ComputeRotation(const Detection& detection,
                               const DetectionSpec& detection_spec,
                               float* rotation) override {
    // eye_to_eye_ and eye_to_mouth_ are computed in ComputeFaceRect
    Vector2_f dir =
        eye_to_eye_ - Vector2_f(-eye_to_mouth_.y(), eye_to_mouth_.x());
    *rotation = NormalizeRadians(target_angle_ + std::atan2(dir.y(), dir.x()));
    return absl::OkStatus();
  }

  // Reads the image dimensions from `detection_spec` into `width`/`height`.
  // Returns an error if the image size is absent or either dimension is not
  // positive, instead of dereferencing an empty value or dividing by zero
  // later on.
  absl::Status GetImageSize(const DetectionSpec& detection_spec, int* width,
                            int* height) {
    RET_CHECK(detection_spec.image_size)
        << "Image size is required to compute the face rect.";
    *width = detection_spec.image_size->first;
    *height = detection_spec.image_size->second;
    RET_CHECK(*width > 0 && *height > 0)
        << "Image size must be positive, got " << *width << "x" << *height
        << ".";
    return absl::OkStatus();
  }

  // Compute a face rectangle from detection landmarks.
  // Also caches eye_to_eye_ and eye_to_mouth_ (in pixels) for
  // ComputeRotation. Returns InvalidArgument when the detection has too few
  // keypoints.
  absl::Status ComputeFaceRect(const Detection& detection, int width,
                               int height, Rect* rect) {
    Vector2_f left_eye(0.f, 0.f), right_eye(0.f, 0.f), mouth(0.f, 0.f);
    if (!GetLandmarks(detection, width, height, &left_eye, &right_eye,
                      &mouth)) {
      return absl::InvalidArgumentError(
          "Detection has wrong number of keypoints.");
    }

    const Vector2_f eye_center = (left_eye + right_eye) * 0.5f;
    eye_to_eye_ = right_eye - left_eye;
    eye_to_mouth_ = mouth - eye_center;

    // The center sits on the eye-center -> mouth segment, at the configured
    // mixing ratio.
    const Vector2_f center =
        eye_center + eye_to_mouth_ * face_options_.eye_to_mouth_mix();
    rect->set_x_center(std::round(center.x()));
    rect->set_y_center(std::round(center.y()));

    // Square crop: the larger of the two scaled distances wins.
    const float scale =
        std::max(eye_to_mouth_.Norm() * face_options_.eye_to_mouth_scale(),
                 eye_to_eye_.Norm() * face_options_.eye_to_eye_scale());
    rect->set_width(std::round(scale));
    rect->set_height(std::round(scale));
    return absl::OkStatus();
  }

  // Gets eyes and mouth landmarks from a face detection.
  bool GetLandmarks(const Detection& detection, int width, int height,
                    Vector2_f* left_eye, Vector2_f* right_eye,
                    Vector2_f* mouth);

  // Effective options: built-in defaults overridden by the node options.
  FaceToRectCalculatorOptions face_options_;
  // Minimum number of keypoints a detection must carry.
  int total_landmarks_ = 0;
  // Pixel-space vectors from the last ComputeFaceRect call.
  Vector2_f eye_to_eye_;
  Vector2_f eye_to_mouth_;
};
REGISTER_CALCULATOR(FaceToRectCalculator);
bool FaceToRectCalculator::GetLandmarks(const Detection& detection, int width,
int height, Vector2_f* left_eye,
Vector2_f* right_eye,
Vector2_f* mouth) {
const auto& location_data = detection.location_data();
if (location_data.relative_keypoints_size() < total_landmarks_) {
return false;
}
// eyes
Vector2_f le(0.f, 0.f);
Vector2_f re(0.f, 0.f);
int i = 0;
for (; i < face_options_.eye_landmark_size(); ++i) {
const auto& left_point = location_data.relative_keypoints(i);
le += Vector2_f(left_point.x() * width, left_point.y() * height);
const auto& right_point =
location_data.relative_keypoints(i + face_options_.eye_landmark_size());
re += Vector2_f(right_point.x() * width, right_point.y() * height);
}
*left_eye = le / face_options_.eye_landmark_size();
*right_eye = re / face_options_.eye_landmark_size();
// mouth
Vector2_f m(0.f, 0.f);
i += face_options_.eye_landmark_size() + face_options_.nose_landmark_size();
for (int j = 0; j < face_options_.mouth_landmark_size(); ++j) {
const auto& point = location_data.relative_keypoints(i + j);
m += Vector2_f(point.x() * width, point.y() * height);
}
*mouth = m / face_options_.mouth_landmark_size();
return true;
}
} // namespace mediapipe

View File

@ -0,0 +1,47 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package mediapipe;
// Options for FaceToRectCalculator.
//
// The calculator reads the detection's relative keypoints in this order:
// left-eye landmarks, right-eye landmarks, nose landmarks, mouth landmarks.
// The three *_landmark_size fields describe how many keypoints each group has.
message FaceToRectCalculatorOptions {
// The number of landmarks for each eye.
// FaceToRectCalculator uses 2 when this is left unset.
int32 eye_landmark_size = 1;
// The number of landmarks for nose. Nose landmarks are skipped, not used.
// FaceToRectCalculator uses 0 when this is left unset.
int32 nose_landmark_size = 2;
// The number of landmarks for mouth.
// FaceToRectCalculator uses 2 when this is left unset.
int32 mouth_landmark_size = 3;
// Coefficient to compute center by mixing eye center and mouth center:
// center = eye_center + (mouth_center - eye_center) * eye_to_mouth_mix.
// FaceToRectCalculator uses 0.1 when this is left unset.
// Suggested values:
// - human face: 0.1
// - cartoon face: 0.2
float eye_to_mouth_mix = 4;
// The scale transforms eye-mouth distance to crop size.
// FaceToRectCalculator uses 3.6 when this is left unset.
// Suggested values:
// - human face landmarks: 3.6 (100%)
// - human face detection: 3.42 (95%)
float eye_to_mouth_scale = 5;
// The scale transforms eye-eye distance to crop size.
// FaceToRectCalculator uses 4.0 when this is left unset.
// Suggested values:
// - human face landmarks: 4.0 (100%)
// - human face detection: 3.8 (95%)
// - cartoon face landmarks: 3.0 (75%)
float eye_to_eye_scale = 6;
}

View File

@ -0,0 +1,152 @@
#include <algorithm>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "mediapipe/calculators/util/face_to_rect_calculator.pb.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
// Image size sent on the IMAGE_SIZE input stream of every test run.
constexpr int kImageWidth = 1280;
constexpr int kImageHeight = 720;
// Absolute tolerance used when comparing rect fields against expected values.
constexpr float kEps = 1e-5f;
// Builds a Detection whose location data contains one relative keypoint per
// (x, y) pair in `key_points`, preserving their order.
Detection DetectionWithKeyPoints(
    const std::vector<std::pair<float, float>>& key_points) {
  Detection result;
  LocationData* location = result.mutable_location_data();
  for (const std::pair<float, float>& xy : key_points) {
    auto* keypoint = location->add_relative_keypoints();
    keypoint->set_x(xy.first);
    keypoint->set_y(xy.second);
  }
  return result;
}
// Returns FaceToRectCalculatorOptions with only the landmark counts set:
// `eye` landmarks per eye, `nose` nose landmarks and `mouth` mouth landmarks.
// All other fields keep their proto defaults.
FaceToRectCalculatorOptions CreateOptions(int eye, int nose, int mouth) {
  FaceToRectCalculatorOptions options;
  options.set_mouth_landmark_size(mouth);
  options.set_nose_landmark_size(nose);
  options.set_eye_landmark_size(eye);
  return options;
}
// Runs FaceToRectCalculator once on `detection` with the given `options` and a
// kImageWidth x kImageHeight image size, both sent at Timestamp::PostStream().
// Returns the single NormalizedRect emitted on NORM_RECT, or the error status
// if the run fails or the output does not contain exactly one packet.
absl::StatusOr<NormalizedRect> RunCalculator(
    Detection detection, FaceToRectCalculatorOptions options) {
  // The literal below is kept flush-left so its bytes match the config text.
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "FaceToRectCalculator"
input_stream: "DETECTION:detection"
input_stream: "IMAGE_SIZE:frame_size"
output_stream: "NORM_RECT:rect"
)pb");
  node_config.mutable_node_options()->Add()->PackFrom(options);

  CalculatorRunner runner(node_config);

  auto& detection_packets = runner.MutableInputs()->Tag("DETECTION").packets;
  detection_packets.push_back(
      MakePacket<Detection>(std::move(detection)).At(Timestamp::PostStream()));

  auto& image_size_packets = runner.MutableInputs()->Tag("IMAGE_SIZE").packets;
  image_size_packets.push_back(
      MakePacket<std::pair<int, int>>(kImageWidth, kImageHeight)
          .At(Timestamp::PostStream()));

  MP_RETURN_IF_ERROR(runner.Run());

  const std::vector<Packet>& rect_packets =
      runner.Outputs().Tag("NORM_RECT").packets;
  RET_CHECK_EQ(rect_packets.size(), 1);
  return rect_packets[0].Get<NormalizedRect>();
}
} // namespace
// Options (1 per eye, 1 nose, 2 mouth) require five keypoints; a detection
// with only four must fail with kInvalidArgument.
TEST(FaceToRectCalculator, WrongNumberOfKeyPoints) {
  const Detection detection = DetectionWithKeyPoints({
      {0.3f, 0.5f},   // left eye
      {0.6f, 0.45f},  // right eye
      {1.0f, 1.0f},   // nose - not used
      {0.5f, 0.65f},  // mouth
  });
  const auto result = RunCalculator(detection, CreateOptions(1, 1, 2));
  EXPECT_FALSE(result.ok());
  EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument);
}
// Detection-style input: one keypoint per eye, one (ignored) nose keypoint and
// one mouth keypoint.
TEST(FaceToRectCalculator, DetectionToNormalizedRect) {
  const Detection detection = DetectionWithKeyPoints({
      {0.3f, 0.5f},   // left eye
      {0.6f, 0.45f},  // right eye
      {1.0f, 1.0f},   // nose - not used
      {0.5f, 0.65f},  // mouth
  });
  const auto result = RunCalculator(detection, CreateOptions(1, 1, 1));
  MP_ASSERT_OK(result);

  const NormalizedRect& face_rect = result.value();
  EXPECT_THAT(face_rect.x_center(), testing::FloatNear(0.454688f, kEps));
  EXPECT_THAT(face_rect.y_center(), testing::FloatNear(0.493056f, kEps));
  EXPECT_THAT(face_rect.height(), testing::FloatNear(2.14306f, kEps));
  EXPECT_THAT(face_rect.width(), testing::FloatNear(1.20547f, kEps));
  EXPECT_THAT(face_rect.rotation(), testing::FloatNear(-0.193622f, kEps));
}
// Landmark-style input: two keypoints per eye, no nose, two mouth keypoints.
// The duplicated points average to the same positions as in
// DetectionToNormalizedRect, so the expected rect is identical.
TEST(FaceToRectCalculator, LandmarksToNormalizedRect) {
  const Detection detection = DetectionWithKeyPoints({
      {0.3f, 0.5f},   // left eye
      {0.3f, 0.5f},   // left eye
      {0.6f, 0.45f},  // right eye
      {0.6f, 0.45f},  // right eye
      {0.5f, 0.65f},  // mouth
      {0.5f, 0.65f},  // mouth
  });
  const auto result = RunCalculator(detection, CreateOptions(2, 0, 2));
  MP_ASSERT_OK(result);

  const NormalizedRect& face_rect = result.value();
  EXPECT_THAT(face_rect.x_center(), testing::FloatNear(0.454688f, kEps));
  EXPECT_THAT(face_rect.y_center(), testing::FloatNear(0.493056f, kEps));
  EXPECT_THAT(face_rect.height(), testing::FloatNear(2.14306f, kEps));
  EXPECT_THAT(face_rect.width(), testing::FloatNear(1.20547f, kEps));
  EXPECT_THAT(face_rect.rotation(), testing::FloatNear(-0.193622f, kEps));
}
// Landmark-style input with the keypoints arranged so that the expected
// rotation is about -2.35619 rad (magnitude above pi/2).
TEST(FaceToRectCalculator, LandmarksToNormalizedRectObtuseAngle) {
  const Detection detection = DetectionWithKeyPoints({
      {0.6f, 0.8f},  // left eye
      {0.6f, 0.8f},  // left eye
      {0.2f, 0.4f},  // right eye
      {0.2f, 0.4f},  // right eye
      {0.8f, 0.2f},  // mouth
      {0.8f, 0.2f},  // mouth
  });
  const auto result = RunCalculator(detection, CreateOptions(2, 0, 2));
  MP_ASSERT_OK(result);

  const NormalizedRect& face_rect = result.value();
  EXPECT_THAT(face_rect.x_center(), testing::FloatNear(0.439844f, kEps));
  EXPECT_THAT(face_rect.y_center(), testing::FloatNear(0.559722f, kEps));
  EXPECT_THAT(face_rect.height(), testing::FloatNear(3.26389f, kEps));
  EXPECT_THAT(face_rect.width(), testing::FloatNear(1.83594f, kEps));
  EXPECT_THAT(face_rect.rotation(), testing::FloatNear(-2.35619f, kEps));
}
} // namespace mediapipe