migrate mediapipe/modules/face_geometry to mediapipe/tasks

PiperOrigin-RevId: 513284254
MediaPipe Team 2023-03-01 10:57:33 -08:00 committed by Copybara-Service
parent 0a937eba98
commit 22fce9e136
20 changed files with 8656 additions and 0 deletions


@@ -0,0 +1,59 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework:encode_binary_proto.bzl", "encode_binary_proto")
licenses(["notice"])
package(default_visibility = ["//visibility:public"])
encode_binary_proto(
name = "geometry_pipeline_metadata_detection",
input = "geometry_pipeline_metadata_detection.pbtxt",
message_type = "mediapipe.tasks.vision.face_geometry.proto.GeometryPipelineMetadata",
output = "geometry_pipeline_metadata_detection.binarypb",
deps = [
"//mediapipe/tasks/cc/vision/face_geometry/proto:geometry_pipeline_metadata_proto",
],
)
encode_binary_proto(
name = "geometry_pipeline_metadata_landmarks",
input = "geometry_pipeline_metadata_landmarks.pbtxt",
message_type = "mediapipe.tasks.vision.face_geometry.proto.GeometryPipelineMetadata",
output = "geometry_pipeline_metadata_landmarks.binarypb",
deps = [
"//mediapipe/tasks/cc/vision/face_geometry/proto:geometry_pipeline_metadata_proto",
],
)
# For backward-compatibility reasons, generate `geometry_pipeline_metadata.binarypb` from
# the `geometry_pipeline_metadata_landmarks.pbtxt` definition.
encode_binary_proto(
name = "geometry_pipeline_metadata",
input = "geometry_pipeline_metadata_landmarks.pbtxt",
message_type = "mediapipe.tasks.vision.face_geometry.proto.GeometryPipelineMetadata",
output = "geometry_pipeline_metadata.binarypb",
deps = [
"//mediapipe/tasks/cc/vision/face_geometry/proto:geometry_pipeline_metadata_proto",
],
)
# These canonical face model files are not meant to be used at runtime, but rather for asset
# creation and/or reference.
exports_files([
"canonical_face_model.fbx",
"canonical_face_model.obj",
"canonical_face_model_uv_visualization.png",
])

File diff suppressed because it is too large.

Binary file not shown (added image, 731 KiB).


@@ -0,0 +1,78 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
input_source: FACE_DETECTION_PIPELINE
procrustes_landmark_basis { landmark_id: 0 weight: 1.0 }
procrustes_landmark_basis { landmark_id: 1 weight: 1.0 }
procrustes_landmark_basis { landmark_id: 2 weight: 1.0 }
procrustes_landmark_basis { landmark_id: 3 weight: 1.0 }
procrustes_landmark_basis { landmark_id: 4 weight: 1.0 }
procrustes_landmark_basis { landmark_id: 5 weight: 1.0 }
# NOTE: the triangular topology of the face meshes is only useful when derived
# from the 468 face landmarks, not from the 6 face detection landmarks
# (keypoints). The latter don't cover the entire face and this mesh is
# defined here only to comply with the API. It should be considered as
# a placeholder and/or for debugging purposes.
#
# Use the face geometry derived from the face detection landmarks
# (keypoints) for the face pose transformation matrix, not the mesh.
canonical_mesh: {
vertex_type: VERTEX_PT
primitive_type: TRIANGLE
vertex_buffer: -3.1511454582214355
vertex_buffer: 2.6246179342269897
vertex_buffer: 3.4656630754470825
vertex_buffer: 0.349575996398926
vertex_buffer: 0.38137748837470997
vertex_buffer: 3.1511454582214355
vertex_buffer: 2.6246179342269897
vertex_buffer: 3.4656630754470825
vertex_buffer: 0.650443494319916
vertex_buffer: 0.38137999176979054
vertex_buffer: 0.0
vertex_buffer: -1.126865029335022
vertex_buffer: 7.475604057312012
vertex_buffer: 0.500025987625122
vertex_buffer: 0.547487020492554
vertex_buffer: 0.0
vertex_buffer: -4.304508209228516
vertex_buffer: 4.162498950958252
vertex_buffer: 0.499989986419678
vertex_buffer: 0.694203019142151
vertex_buffer: -7.664182186126709
vertex_buffer: 0.673132002353668
vertex_buffer: -2.435867071151733
vertex_buffer: 0.007561000064015
vertex_buffer: 0.480777025222778
vertex_buffer: 7.664182186126709
vertex_buffer: 0.673132002353668
vertex_buffer: -2.435867071151733
vertex_buffer: 0.992439985275269
vertex_buffer: 0.480777025222778
index_buffer: 0
index_buffer: 1
index_buffer: 2
index_buffer: 1
index_buffer: 5
index_buffer: 2
index_buffer: 4
index_buffer: 0
index_buffer: 2
index_buffer: 4
index_buffer: 2
index_buffer: 3
index_buffer: 2
index_buffer: 5
index_buffer: 3
}
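
For orientation, the placeholder mesh above decodes as follows: the VERTEX_PT layout packs five floats per vertex (XYZ position followed by UV), so the 30 vertex_buffer values define 6 vertices, one per face detection keypoint, and the 15 index_buffer entries define 5 triangles.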


@@ -0,0 +1,80 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:public"])
cc_library(
name = "geometry_pipeline",
srcs = ["geometry_pipeline.cc"],
hdrs = ["geometry_pipeline.h"],
deps = [
":mesh_3d_utils",
":procrustes_solver",
":validation_utils",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:matrix_data_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto",
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_cc_proto",
"//mediapipe/tasks/cc/vision/face_geometry/proto:geometry_pipeline_metadata_cc_proto",
"//mediapipe/tasks/cc/vision/face_geometry/proto:mesh_3d_cc_proto",
"@com_google_absl//absl/memory",
"@eigen_archive//:eigen3",
],
)
cc_library(
name = "mesh_3d_utils",
srcs = ["mesh_3d_utils.cc"],
hdrs = ["mesh_3d_utils.h"],
deps = [
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:statusor",
"//mediapipe/tasks/cc/vision/face_geometry/proto:mesh_3d_cc_proto",
],
)
cc_library(
name = "procrustes_solver",
srcs = ["procrustes_solver.cc"],
hdrs = ["procrustes_solver.h"],
deps = [
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"@com_google_absl//absl/memory",
"@eigen_archive//:eigen3",
],
)
cc_library(
name = "validation_utils",
srcs = ["validation_utils.cc"],
hdrs = ["validation_utils.h"],
deps = [
":mesh_3d_utils",
"//mediapipe/framework/formats:matrix_data_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto",
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_cc_proto",
"//mediapipe/tasks/cc/vision/face_geometry/proto:geometry_pipeline_metadata_cc_proto",
"//mediapipe/tasks/cc/vision/face_geometry/proto:mesh_3d_cc_proto",
],
)


@@ -0,0 +1,471 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/tasks/cc/vision/face_geometry/libs/geometry_pipeline.h"
#include <cmath>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>
#include "Eigen/Core"
#include "absl/memory/memory.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/matrix_data.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/statusor.h"
#include "mediapipe/tasks/cc/vision/face_geometry/libs/mesh_3d_utils.h"
#include "mediapipe/tasks/cc/vision/face_geometry/libs/procrustes_solver.h"
#include "mediapipe/tasks/cc/vision/face_geometry/libs/validation_utils.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/geometry_pipeline_metadata.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/mesh_3d.pb.h"
namespace mediapipe::tasks::vision::face_geometry {
namespace {
struct PerspectiveCameraFrustum {
// NOTE: all arguments must be validated prior to calling this constructor.
PerspectiveCameraFrustum(const proto::PerspectiveCamera& perspective_camera,
int frame_width, int frame_height) {
static constexpr float kDegreesToRadians = 3.14159265358979323846f / 180.f;
const float height_at_near =
2.f * perspective_camera.near() *
std::tan(0.5f * kDegreesToRadians *
perspective_camera.vertical_fov_degrees());
const float width_at_near = frame_width * height_at_near / frame_height;
left = -0.5f * width_at_near;
right = 0.5f * width_at_near;
bottom = -0.5f * height_at_near;
top = 0.5f * height_at_near;
near = perspective_camera.near();
far = perspective_camera.far();
}
float left;
float right;
float bottom;
float top;
float near;
float far;
};
class ScreenToMetricSpaceConverter {
public:
ScreenToMetricSpaceConverter(
proto::OriginPointLocation origin_point_location, //
proto::InputSource input_source, //
Eigen::Matrix3Xf&& canonical_metric_landmarks, //
Eigen::VectorXf&& landmark_weights, //
std::unique_ptr<ProcrustesSolver> procrustes_solver)
: origin_point_location_(origin_point_location),
input_source_(input_source),
canonical_metric_landmarks_(std::move(canonical_metric_landmarks)),
landmark_weights_(std::move(landmark_weights)),
procrustes_solver_(std::move(procrustes_solver)) {}
// Converts `screen_landmark_list` into `metric_landmark_list` and estimates
// the `pose_transform_mat`.
//
// Here's the algorithm summary:
//
// (1) Project X- and Y- screen landmark coordinates at the Z near plane.
//
// (2) Estimate a canonical-to-runtime landmark set scale by running the
// Procrustes solver using the screen runtime landmarks.
//
// On this iteration, screen landmarks are used instead of unprojected
// metric landmarks as it is not safe to unproject due to the relative
// nature of the input screen landmark Z coordinate.
//
// (3) Use the canonical-to-runtime scale from (2) to unproject the screen
// landmarks. The result is referenced as "intermediate landmarks" because
// they are the first estimation of the resulting metric landmarks, but are
// not quite there yet.
//
// (4) Estimate a canonical-to-runtime landmark set scale by running the
// Procrustes solver using the intermediate runtime landmarks.
//
// (5) Use the product of the scale factors from (2) and (4) to unproject
// the screen landmarks the second time. This is the second and the final
// estimation of the metric landmarks.
//
// (6) Multiply each of the metric landmarks by the inverse pose
// transformation matrix to align the runtime metric face landmarks with
// the canonical metric face landmarks.
//
// Note: the input screen landmarks are in the left-handed coordinate system,
// however any metric landmarks - including the canonical metric
// landmarks, the final runtime metric landmarks and any intermediate
// runtime metric landmarks - are in the right-handed coordinate system.
//
// To keep the logic correct, the landmark set handedness is changed any
// time the screen-to-metric semantic barrier is passed.
absl::Status Convert(
const mediapipe::NormalizedLandmarkList& screen_landmark_list, //
const PerspectiveCameraFrustum& pcf, //
mediapipe::LandmarkList& metric_landmark_list, //
Eigen::Matrix4f& pose_transform_mat) const {
RET_CHECK_EQ(screen_landmark_list.landmark_size(),
canonical_metric_landmarks_.cols())
<< "The number of landmarks doesn't match the number passed upon "
"initialization!";
Eigen::Matrix3Xf screen_landmarks;
ConvertLandmarkListToEigenMatrix(screen_landmark_list, screen_landmarks);
ProjectXY(pcf, screen_landmarks);
const float depth_offset = screen_landmarks.row(2).mean();
// 1st iteration: don't unproject XY because it's unsafe to do so due to
// the relative nature of the Z coordinate. Instead, run the
// first estimation on the projected XY and use that scale to
// unproject for the 2nd iteration.
Eigen::Matrix3Xf intermediate_landmarks(screen_landmarks);
ChangeHandedness(intermediate_landmarks);
ASSIGN_OR_RETURN(const float first_iteration_scale,
EstimateScale(intermediate_landmarks),
_ << "Failed to estimate first iteration scale!");
// 2nd iteration: unproject XY using the scale from the 1st iteration.
intermediate_landmarks = screen_landmarks;
MoveAndRescaleZ(pcf, depth_offset, first_iteration_scale,
intermediate_landmarks);
UnprojectXY(pcf, intermediate_landmarks);
ChangeHandedness(intermediate_landmarks);
// For face detection input landmarks, re-write Z-coord from the canonical
// landmarks.
if (input_source_ == proto::InputSource::FACE_DETECTION_PIPELINE) {
Eigen::Matrix4f intermediate_pose_transform_mat;
MP_RETURN_IF_ERROR(procrustes_solver_->SolveWeightedOrthogonalProblem(
canonical_metric_landmarks_, intermediate_landmarks,
landmark_weights_, intermediate_pose_transform_mat))
<< "Failed to estimate pose transform matrix!";
intermediate_landmarks.row(2) =
(intermediate_pose_transform_mat *
canonical_metric_landmarks_.colwise().homogeneous())
.row(2);
}
ASSIGN_OR_RETURN(const float second_iteration_scale,
EstimateScale(intermediate_landmarks),
_ << "Failed to estimate second iteration scale!");
// Use the total scale to unproject the screen landmarks.
const float total_scale = first_iteration_scale * second_iteration_scale;
MoveAndRescaleZ(pcf, depth_offset, total_scale, screen_landmarks);
UnprojectXY(pcf, screen_landmarks);
ChangeHandedness(screen_landmarks);
// At this point, screen landmarks are converted into metric landmarks.
Eigen::Matrix3Xf& metric_landmarks = screen_landmarks;
MP_RETURN_IF_ERROR(procrustes_solver_->SolveWeightedOrthogonalProblem(
canonical_metric_landmarks_, metric_landmarks, landmark_weights_,
pose_transform_mat))
<< "Failed to estimate pose transform matrix!";
// For face detection input landmarks, re-write Z-coord from the canonical
// landmarks and run the pose transform estimation again.
if (input_source_ == proto::InputSource::FACE_DETECTION_PIPELINE) {
metric_landmarks.row(2) =
(pose_transform_mat *
canonical_metric_landmarks_.colwise().homogeneous())
.row(2);
MP_RETURN_IF_ERROR(procrustes_solver_->SolveWeightedOrthogonalProblem(
canonical_metric_landmarks_, metric_landmarks, landmark_weights_,
pose_transform_mat))
<< "Failed to estimate pose transform matrix!";
}
// Multiply each of the metric landmarks by the inverse pose
// transformation matrix to align the runtime metric face landmarks with
// the canonical metric face landmarks.
metric_landmarks = (pose_transform_mat.inverse() *
metric_landmarks.colwise().homogeneous())
.topRows(3);
ConvertEigenMatrixToLandmarkList(metric_landmarks, metric_landmark_list);
return absl::OkStatus();
}
private:
void ProjectXY(const PerspectiveCameraFrustum& pcf,
Eigen::Matrix3Xf& landmarks) const {
float x_scale = pcf.right - pcf.left;
float y_scale = pcf.top - pcf.bottom;
float x_translation = pcf.left;
float y_translation = pcf.bottom;
if (origin_point_location_ == proto::OriginPointLocation::TOP_LEFT_CORNER) {
landmarks.row(1) = 1.f - landmarks.row(1).array();
}
landmarks =
landmarks.array().colwise() * Eigen::Array3f(x_scale, y_scale, x_scale);
landmarks.colwise() += Eigen::Vector3f(x_translation, y_translation, 0.f);
}
absl::StatusOr<float> EstimateScale(Eigen::Matrix3Xf& landmarks) const {
Eigen::Matrix4f transform_mat;
MP_RETURN_IF_ERROR(procrustes_solver_->SolveWeightedOrthogonalProblem(
canonical_metric_landmarks_, landmarks, landmark_weights_,
transform_mat))
<< "Failed to estimate canonical-to-runtime landmark set transform!";
return transform_mat.col(0).norm();
}
static void MoveAndRescaleZ(const PerspectiveCameraFrustum& pcf,
float depth_offset, float scale,
Eigen::Matrix3Xf& landmarks) {
landmarks.row(2) =
(landmarks.array().row(2) - depth_offset + pcf.near) / scale;
}
static void UnprojectXY(const PerspectiveCameraFrustum& pcf,
Eigen::Matrix3Xf& landmarks) {
landmarks.row(0) =
landmarks.row(0).cwiseProduct(landmarks.row(2)) / pcf.near;
landmarks.row(1) =
landmarks.row(1).cwiseProduct(landmarks.row(2)) / pcf.near;
}
static void ChangeHandedness(Eigen::Matrix3Xf& landmarks) {
landmarks.row(2) *= -1.f;
}
static void ConvertLandmarkListToEigenMatrix(
const mediapipe::NormalizedLandmarkList& landmark_list,
Eigen::Matrix3Xf& eigen_matrix) {
eigen_matrix = Eigen::Matrix3Xf(3, landmark_list.landmark_size());
for (int i = 0; i < landmark_list.landmark_size(); ++i) {
const auto& landmark = landmark_list.landmark(i);
eigen_matrix(0, i) = landmark.x();
eigen_matrix(1, i) = landmark.y();
eigen_matrix(2, i) = landmark.z();
}
}
static void ConvertEigenMatrixToLandmarkList(
const Eigen::Matrix3Xf& eigen_matrix,
mediapipe::LandmarkList& landmark_list) {
landmark_list.Clear();
for (int i = 0; i < eigen_matrix.cols(); ++i) {
auto& landmark = *landmark_list.add_landmark();
landmark.set_x(eigen_matrix(0, i));
landmark.set_y(eigen_matrix(1, i));
landmark.set_z(eigen_matrix(2, i));
}
}
const proto::OriginPointLocation origin_point_location_;
const proto::InputSource input_source_;
Eigen::Matrix3Xf canonical_metric_landmarks_;
Eigen::VectorXf landmark_weights_;
std::unique_ptr<ProcrustesSolver> procrustes_solver_;
};
class GeometryPipelineImpl : public GeometryPipeline {
public:
GeometryPipelineImpl(
const proto::PerspectiveCamera& perspective_camera, //
const proto::Mesh3d& canonical_mesh, //
uint32_t canonical_mesh_vertex_size, //
uint32_t canonical_mesh_num_vertices,
uint32_t canonical_mesh_vertex_position_offset,
std::unique_ptr<ScreenToMetricSpaceConverter> space_converter)
: perspective_camera_(perspective_camera),
canonical_mesh_(canonical_mesh),
canonical_mesh_vertex_size_(canonical_mesh_vertex_size),
canonical_mesh_num_vertices_(canonical_mesh_num_vertices),
canonical_mesh_vertex_position_offset_(
canonical_mesh_vertex_position_offset),
space_converter_(std::move(space_converter)) {}
absl::StatusOr<std::vector<proto::FaceGeometry>> EstimateFaceGeometry(
const std::vector<mediapipe::NormalizedLandmarkList>&
multi_face_landmarks,
int frame_width, int frame_height) const override {
MP_RETURN_IF_ERROR(ValidateFrameDimensions(frame_width, frame_height))
<< "Invalid frame dimensions!";
// Create a perspective camera frustum to be shared for geometry estimation
// per each face.
PerspectiveCameraFrustum pcf(perspective_camera_, frame_width,
frame_height);
std::vector<proto::FaceGeometry> multi_face_geometry;
// From this point, the meaning of "face landmarks" is clarified further as
// "screen face landmarks". This is done do distinguish from "metric face
// landmarks" that are derived during the face geometry estimation process.
for (const mediapipe::NormalizedLandmarkList& screen_face_landmarks :
multi_face_landmarks) {
// Having a too compact screen landmark list will result in numerical
// instabilities; therefore, such faces are filtered.
if (IsScreenLandmarkListTooCompact(screen_face_landmarks)) {
continue;
}
// Convert the screen landmarks into the metric landmarks and get the pose
// transformation matrix.
mediapipe::LandmarkList metric_face_landmarks;
Eigen::Matrix4f pose_transform_mat;
MP_RETURN_IF_ERROR(space_converter_->Convert(screen_face_landmarks, pcf,
metric_face_landmarks,
pose_transform_mat))
<< "Failed to convert landmarks from the screen to the metric space!";
// Pack geometry data for this face.
proto::FaceGeometry face_geometry;
proto::Mesh3d* mutable_mesh = face_geometry.mutable_mesh();
// Copy the canonical face mesh as the face geometry mesh.
mutable_mesh->CopyFrom(canonical_mesh_);
// Replace the XYZ mesh vertex coordinates with the metric landmark positions.
for (int i = 0; i < canonical_mesh_num_vertices_; ++i) {
uint32_t vertex_buffer_offset = canonical_mesh_vertex_size_ * i +
canonical_mesh_vertex_position_offset_;
mutable_mesh->set_vertex_buffer(vertex_buffer_offset,
metric_face_landmarks.landmark(i).x());
mutable_mesh->set_vertex_buffer(vertex_buffer_offset + 1,
metric_face_landmarks.landmark(i).y());
mutable_mesh->set_vertex_buffer(vertex_buffer_offset + 2,
metric_face_landmarks.landmark(i).z());
}
// Populate the face pose transformation matrix.
mediapipe::MatrixDataProtoFromMatrix(
pose_transform_mat, face_geometry.mutable_pose_transform_matrix());
multi_face_geometry.push_back(face_geometry);
}
return multi_face_geometry;
}
private:
static bool IsScreenLandmarkListTooCompact(
const mediapipe::NormalizedLandmarkList& screen_landmarks) {
float mean_x = 0.f;
float mean_y = 0.f;
for (int i = 0; i < screen_landmarks.landmark_size(); ++i) {
const auto& landmark = screen_landmarks.landmark(i);
mean_x += (landmark.x() - mean_x) / static_cast<float>(i + 1);
mean_y += (landmark.y() - mean_y) / static_cast<float>(i + 1);
}
float max_sq_dist = 0.f;
for (const auto& landmark : screen_landmarks.landmark()) {
const float d_x = landmark.x() - mean_x;
const float d_y = landmark.y() - mean_y;
max_sq_dist = std::max(max_sq_dist, d_x * d_x + d_y * d_y);
}
static constexpr float kIsScreenLandmarkListTooCompactThreshold = 1e-3f;
return std::sqrt(max_sq_dist) <= kIsScreenLandmarkListTooCompactThreshold;
}
const proto::PerspectiveCamera perspective_camera_;
const proto::Mesh3d canonical_mesh_;
const uint32_t canonical_mesh_vertex_size_;
const uint32_t canonical_mesh_num_vertices_;
const uint32_t canonical_mesh_vertex_position_offset_;
std::unique_ptr<ScreenToMetricSpaceConverter> space_converter_;
};
} // namespace
absl::StatusOr<std::unique_ptr<GeometryPipeline>> CreateGeometryPipeline(
const proto::Environment& environment,
const proto::GeometryPipelineMetadata& metadata) {
MP_RETURN_IF_ERROR(ValidateEnvironment(environment))
<< "Invalid environment!";
MP_RETURN_IF_ERROR(ValidateGeometryPipelineMetadata(metadata))
<< "Invalid geometry pipeline metadata!";
const auto& canonical_mesh = metadata.canonical_mesh();
RET_CHECK(HasVertexComponent(canonical_mesh.vertex_type(),
VertexComponent::POSITION))
<< "Canonical face mesh must have the `POSITION` vertex component!";
RET_CHECK(HasVertexComponent(canonical_mesh.vertex_type(),
VertexComponent::TEX_COORD))
<< "Canonical face mesh must have the `TEX_COORD` vertex component!";
uint32_t canonical_mesh_vertex_size =
GetVertexSize(canonical_mesh.vertex_type());
uint32_t canonical_mesh_num_vertices =
canonical_mesh.vertex_buffer_size() / canonical_mesh_vertex_size;
uint32_t canonical_mesh_vertex_position_offset =
GetVertexComponentOffset(canonical_mesh.vertex_type(),
VertexComponent::POSITION)
.value();
// Put the Procrustes landmark basis into Eigen matrices for easier access.
Eigen::Matrix3Xf canonical_metric_landmarks =
Eigen::Matrix3Xf::Zero(3, canonical_mesh_num_vertices);
Eigen::VectorXf landmark_weights =
Eigen::VectorXf::Zero(canonical_mesh_num_vertices);
for (int i = 0; i < canonical_mesh_num_vertices; ++i) {
uint32_t vertex_buffer_offset =
canonical_mesh_vertex_size * i + canonical_mesh_vertex_position_offset;
canonical_metric_landmarks(0, i) =
canonical_mesh.vertex_buffer(vertex_buffer_offset);
canonical_metric_landmarks(1, i) =
canonical_mesh.vertex_buffer(vertex_buffer_offset + 1);
canonical_metric_landmarks(2, i) =
canonical_mesh.vertex_buffer(vertex_buffer_offset + 2);
}
for (const proto::WeightedLandmarkRef& wlr :
metadata.procrustes_landmark_basis()) {
uint32_t landmark_id = wlr.landmark_id();
landmark_weights(landmark_id) = wlr.weight();
}
std::unique_ptr<GeometryPipeline> result =
absl::make_unique<GeometryPipelineImpl>(
environment.perspective_camera(), canonical_mesh,
canonical_mesh_vertex_size, canonical_mesh_num_vertices,
canonical_mesh_vertex_position_offset,
absl::make_unique<ScreenToMetricSpaceConverter>(
environment.origin_point_location(),
metadata.input_source() == proto::InputSource::DEFAULT
? proto::InputSource::FACE_LANDMARK_PIPELINE
: metadata.input_source(),
std::move(canonical_metric_landmarks),
std::move(landmark_weights),
CreateFloatPrecisionProcrustesSolver()));
return result;
}
} // namespace mediapipe::tasks::vision::face_geometry
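
As a sanity check on the frustum arithmetic above, the following standalone sketch (not part of the library; the 63-degree vertical FOV, near plane of 1 and 640x480 frame are illustrative assumptions) mirrors the PerspectiveCameraFrustum constructor for one concrete input:

#include <cmath>
#include <cstdio>

int main() {
  // Same constant and formula as in PerspectiveCameraFrustum above.
  constexpr float kDegreesToRadians = 3.14159265358979323846f / 180.f;
  const float vertical_fov_degrees = 63.f;  // illustrative value
  const float near = 1.f;                   // illustrative value
  const int frame_width = 640;
  const int frame_height = 480;
  const float height_at_near =
      2.f * near * std::tan(0.5f * kDegreesToRadians * vertical_fov_degrees);
  const float width_at_near = frame_width * height_at_near / frame_height;
  // Expected output (approximately): height_at_near=1.226, width_at_near=1.634,
  // i.e. a near-plane rectangle of roughly [-0.817, 0.817] x [-0.613, 0.613].
  std::printf("height_at_near=%f width_at_near=%f\n", height_at_near,
              width_at_near);
  return 0;
}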


@@ -0,0 +1,69 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_TASKS_CC_VISION_FACE_GEOMETRY_LIBS_GEOMETRY_PIPELINE_H_
#define MEDIAPIPE_TASKS_CC_VISION_FACE_GEOMETRY_LIBS_GEOMETRY_PIPELINE_H_
#include <memory>
#include <vector>
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/statusor.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/geometry_pipeline_metadata.pb.h"
namespace mediapipe::tasks::vision::face_geometry {
// Encapsulates a stateless estimator of facial geometry in a Metric space based
// on the normalized face landmarks in the Screen space.
class GeometryPipeline {
public:
virtual ~GeometryPipeline() = default;
// Estimates geometry data for multiple faces.
//
// Returns an error status if any of the passed arguments is invalid.
//
// The result includes face geometry data for a subset of the input faces;
// geometry data for some faces might be missing. This may happen if
// it'd be unstable to estimate the facial geometry based on a corresponding
// face landmark list for any reason (for example, if the landmark list is too
// compact).
//
// Each face landmark list must have the same number of landmarks as was
// passed upon initialization via the canonical face mesh (as a part of the
// geometry pipeline metadata).
//
// Both `frame_width` and `frame_height` must be positive.
virtual absl::StatusOr<std::vector<proto::FaceGeometry>> EstimateFaceGeometry(
const std::vector<mediapipe::NormalizedLandmarkList>&
multi_face_landmarks,
int frame_width, int frame_height) const = 0;
};
// Creates an instance of `GeometryPipeline`.
//
// Both `environment` and `metadata` must be valid (for details, please refer to
// the proto message definition comments and/or `validation_utils.h/cc`).
//
// Canonical face mesh (defined as a part of `metadata`) must have the
// `POSITION` and the `TEX_COORD` vertex components.
absl::StatusOr<std::unique_ptr<GeometryPipeline>> CreateGeometryPipeline(
const proto::Environment& environment,
const proto::GeometryPipelineMetadata& metadata);
} // namespace mediapipe::tasks::vision::face_geometry
#endif // MEDIAPIPE_TASKS_CC_VISION_FACE_GEOMETRY_LIBS_GEOMETRY_PIPELINE_H_
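
For illustration, here is a minimal, hedged sketch of how a caller might wire this API together: build an Environment, parse a GeometryPipelineMetadata binary proto (for example the geometry_pipeline_metadata_landmarks.binarypb asset generated above), create the pipeline once and run EstimateFaceGeometry per frame. The metadata path, the camera values and the EstimateOnce wrapper are illustrative assumptions, not part of the library.

#include <fstream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/statusor.h"
#include "mediapipe/tasks/cc/vision/face_geometry/libs/geometry_pipeline.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/geometry_pipeline_metadata.pb.h"

namespace fg = mediapipe::tasks::vision::face_geometry;

absl::StatusOr<std::vector<fg::proto::FaceGeometry>> EstimateOnce(
    const std::string& metadata_path,
    const std::vector<mediapipe::NormalizedLandmarkList>& multi_face_landmarks,
    int frame_width, int frame_height) {
  // Describe the rendering environment; the camera values are illustrative.
  fg::proto::Environment environment;
  environment.set_origin_point_location(fg::proto::TOP_LEFT_CORNER);
  auto* camera = environment.mutable_perspective_camera();
  camera->set_vertical_fov_degrees(63.f);
  camera->set_near(1.f);
  camera->set_far(10000.f);

  // Parse the geometry pipeline metadata from a binary proto asset.
  std::ifstream input(metadata_path, std::ios::binary);
  std::stringstream contents;
  contents << input.rdbuf();
  fg::proto::GeometryPipelineMetadata metadata;
  if (!metadata.ParseFromString(contents.str())) {
    return absl::InvalidArgumentError("Failed to parse the metadata proto!");
  }

  // Create the pipeline once, then estimate geometry for all faces in a frame.
  ASSIGN_OR_RETURN(std::unique_ptr<fg::GeometryPipeline> pipeline,
                   fg::CreateGeometryPipeline(environment, metadata));
  return pipeline->EstimateFaceGeometry(multi_face_landmarks, frame_width,
                                        frame_height);
}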


@@ -0,0 +1,103 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/tasks/cc/vision/face_geometry/libs/mesh_3d_utils.h"
#include <cstdint>
#include <cstdlib>
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/statusor.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/mesh_3d.pb.h"
namespace mediapipe::tasks::vision::face_geometry {
namespace {
bool HasVertexComponentVertexPT(VertexComponent vertex_component) {
switch (vertex_component) {
case VertexComponent::POSITION:
case VertexComponent::TEX_COORD:
return true;
default:
return false;
}
}
uint32_t GetVertexComponentSizeVertexPT(VertexComponent vertex_component) {
switch (vertex_component) {
case VertexComponent::POSITION:
return 3;
case VertexComponent::TEX_COORD:
return 2;
}
}
uint32_t GetVertexComponentOffsetVertexPT(VertexComponent vertex_component) {
switch (vertex_component) {
case VertexComponent::POSITION:
return 0;
case VertexComponent::TEX_COORD:
return GetVertexComponentSizeVertexPT(VertexComponent::POSITION);
}
}
} // namespace
std::size_t GetVertexSize(proto::Mesh3d::VertexType vertex_type) {
switch (vertex_type) {
case proto::Mesh3d::VERTEX_PT:
return GetVertexComponentSizeVertexPT(VertexComponent::POSITION) +
GetVertexComponentSizeVertexPT(VertexComponent::TEX_COORD);
}
}
std::size_t GetPrimitiveSize(proto::Mesh3d::PrimitiveType primitive_type) {
switch (primitive_type) {
case proto::Mesh3d::TRIANGLE:
return 3;
}
}
bool HasVertexComponent(proto::Mesh3d::VertexType vertex_type,
VertexComponent vertex_component) {
switch (vertex_type) {
case proto::Mesh3d::VERTEX_PT:
return HasVertexComponentVertexPT(vertex_component);
}
}
absl::StatusOr<uint32_t> GetVertexComponentOffset(
proto::Mesh3d::VertexType vertex_type, VertexComponent vertex_component) {
RET_CHECK(HasVertexComponentVertexPT(vertex_component))
<< "A given vertex type doesn't have the requested component!";
switch (vertex_type) {
case proto::Mesh3d::VERTEX_PT:
return GetVertexComponentOffsetVertexPT(vertex_component);
}
}
absl::StatusOr<uint32_t> GetVertexComponentSize(
proto::Mesh3d::VertexType vertex_type, VertexComponent vertex_component) {
RET_CHECK(HasVertexComponentVertexPT(vertex_component))
<< "A given vertex type doesn't have the requested component!";
switch (vertex_type) {
case proto::Mesh3d::VERTEX_PT:
return GetVertexComponentSizeVertexPT(vertex_component);
}
}
} // namespace mediapipe::tasks::vision::face_geometry


@@ -0,0 +1,51 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_TASKS_CC_VISION_FACE_GEOMETRY_LIBS_MESH_3D_UTILS_H_
#define MEDIAPIPE_TASKS_CC_VISION_FACE_GEOMETRY_LIBS_MESH_3D_UTILS_H_
#include <cstdint>
#include <cstdlib>
#include "mediapipe/framework/port/statusor.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/mesh_3d.pb.h"
namespace mediapipe::tasks::vision::face_geometry {
enum class VertexComponent { POSITION, TEX_COORD };
std::size_t GetVertexSize(proto::Mesh3d::VertexType vertex_type);
std::size_t GetPrimitiveSize(proto::Mesh3d::PrimitiveType primitive_type);
bool HasVertexComponent(proto::Mesh3d::VertexType vertex_type,
VertexComponent vertex_component);
// Computes the vertex component offset.
//
// Returns an error status if a given vertex type doesn't have the requested
// component.
absl::StatusOr<uint32_t> GetVertexComponentOffset(
proto::Mesh3d::VertexType vertex_type, VertexComponent vertex_component);
// Computes the vertex component size.
//
// Returns an error status if a given vertex type doesn't have the requested
// component.
absl::StatusOr<uint32_t> GetVertexComponentSize(
proto::Mesh3d::VertexType vertex_type, VertexComponent vertex_component);
} // namespace mediapipe::tasks::vision::face_geometry
#endif // MEDIAPIPE_TASKS_CC_VISION_FACE_GEOMETRY_LIBS_MESH_3D_UTILS_H_
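
To make the intended usage of these helpers concrete, below is a hedged sketch that walks a packed VERTEX_PT vertex buffer and prints each vertex position; the PrintVertexPositions wrapper is made up for the example.

#include <cstdint>
#include <cstdio>

#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/vision/face_geometry/libs/mesh_3d_utils.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/mesh_3d.pb.h"

namespace fg = mediapipe::tasks::vision::face_geometry;

// Prints the XYZ position of each vertex in a packed Mesh3d vertex buffer.
absl::Status PrintVertexPositions(const fg::proto::Mesh3d& mesh) {
  const std::size_t vertex_size = fg::GetVertexSize(mesh.vertex_type());
  ASSIGN_OR_RETURN(const uint32_t position_offset,
                   fg::GetVertexComponentOffset(mesh.vertex_type(),
                                                fg::VertexComponent::POSITION));
  const int num_vertices = mesh.vertex_buffer_size() / vertex_size;
  for (int i = 0; i < num_vertices; ++i) {
    // Each vertex occupies `vertex_size` consecutive floats; the position
    // component starts at `position_offset` within that slice.
    const int base = static_cast<int>(i * vertex_size + position_offset);
    std::printf("vertex %d: (%f, %f, %f)\n", i, mesh.vertex_buffer(base),
                mesh.vertex_buffer(base + 1), mesh.vertex_buffer(base + 2));
  }
  return absl::OkStatus();
}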


@@ -0,0 +1,264 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/tasks/cc/vision/face_geometry/libs/procrustes_solver.h"
#include <cmath>
#include <memory>
#include "Eigen/Dense"
#include "absl/memory/memory.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/statusor.h"
namespace mediapipe::tasks::vision::face_geometry {
namespace {
class FloatPrecisionProcrustesSolver : public ProcrustesSolver {
public:
FloatPrecisionProcrustesSolver() = default;
absl::Status SolveWeightedOrthogonalProblem(
const Eigen::Matrix3Xf& source_points, //
const Eigen::Matrix3Xf& target_points, //
const Eigen::VectorXf& point_weights,
Eigen::Matrix4f& transform_mat) const override {
// Validate inputs.
MP_RETURN_IF_ERROR(ValidateInputPoints(source_points, target_points))
<< "Failed to validate weighted orthogonal problem input points!";
MP_RETURN_IF_ERROR(
ValidatePointWeights(source_points.cols(), point_weights))
<< "Failed to validate weighted orthogonal problem point weights!";
// Extract square root from the point weights.
Eigen::VectorXf sqrt_weights = ExtractSquareRoot(point_weights);
// Try to solve the WEOP problem.
MP_RETURN_IF_ERROR(InternalSolveWeightedOrthogonalProblem(
source_points, target_points, sqrt_weights, transform_mat))
<< "Failed to solve the WEOP problem!";
return absl::OkStatus();
}
private:
static constexpr float kAbsoluteErrorEps = 1e-9f;
static absl::Status ValidateInputPoints(
const Eigen::Matrix3Xf& source_points,
const Eigen::Matrix3Xf& target_points) {
RET_CHECK_GT(source_points.cols(), 0)
<< "The number of source points must be positive!";
RET_CHECK_EQ(source_points.cols(), target_points.cols())
<< "The number of source and target points must be equal!";
return absl::OkStatus();
}
static absl::Status ValidatePointWeights(
int num_points, const Eigen::VectorXf& point_weights) {
RET_CHECK_GT(point_weights.size(), 0)
<< "The number of point weights must be positive!";
RET_CHECK_EQ(point_weights.size(), num_points)
<< "The number of points and point weights must be equal!";
float total_weight = 0.f;
for (int i = 0; i < num_points; ++i) {
RET_CHECK_GE(point_weights(i), 0.f)
<< "Each point weight must be non-negative!";
total_weight += point_weights(i);
}
RET_CHECK_GT(total_weight, kAbsoluteErrorEps)
<< "The total point weight is too small!";
return absl::OkStatus();
}
static Eigen::VectorXf ExtractSquareRoot(
const Eigen::VectorXf& point_weights) {
Eigen::VectorXf sqrt_weights(point_weights);
for (int i = 0; i < sqrt_weights.size(); ++i) {
sqrt_weights(i) = std::sqrt(sqrt_weights(i));
}
return sqrt_weights;
}
// Combines a 3x3 rotation-and-scale matrix and a 3x1 translation vector into
// a single 4x4 transformation matrix.
static Eigen::Matrix4f CombineTransformMatrix(const Eigen::Matrix3f& r_and_s,
const Eigen::Vector3f& t) {
Eigen::Matrix4f result = Eigen::Matrix4f::Identity();
result.leftCols(3).topRows(3) = r_and_s;
result.col(3).topRows(3) = t;
return result;
}
// The weighted problem is thoroughly addressed in Section 2.4 of:
// D. Akca, Generalized Procrustes analysis and its applications
// in photogrammetry, 2003, https://doi.org/10.3929/ethz-a-004656648
//
// Notable differences in the code presented here are:
//
// * In the paper, the weights matrix W_p is Cholesky-decomposed as Q^T Q.
// Our W_p is diagonal (equal to diag(sqrt_weights^2)),
// so we can just set Q = diag(sqrt_weights) instead.
//
// * In the paper, the problem is presented as
// (for W_k = I and W_p = transposed(Q) Q):
// || Q (c A T + j transposed(t) - B) || -> min.
//
// We reformulate it as an equivalent minimization of the transpose's
// norm:
// || (c transposed(T) transposed(A) - transposed(B)) transposed(Q) || -> min,
// where transposed(A) and transposed(B) are the source and the target point
// clouds, respectively, c transposed(T) is the rotation+scaling R sought
// for, and Q is diag(sqrt_weights).
//
// Most of the derivations are therefore transposed.
//
// Note: the output `transform_mat` argument is used instead of `StatusOr<>`
// return type in order to avoid Eigen memory alignment issues. Details:
// https://eigen.tuxfamily.org/dox/group__TopicStructHavingEigenMembers.html
static absl::Status InternalSolveWeightedOrthogonalProblem(
const Eigen::Matrix3Xf& sources, const Eigen::Matrix3Xf& targets,
const Eigen::VectorXf& sqrt_weights, Eigen::Matrix4f& transform_mat) {
// transposed(A_w).
Eigen::Matrix3Xf weighted_sources =
sources.array().rowwise() * sqrt_weights.array().transpose();
// transposed(B_w).
Eigen::Matrix3Xf weighted_targets =
targets.array().rowwise() * sqrt_weights.array().transpose();
// w = transposed(j_w) j_w.
float total_weight = sqrt_weights.cwiseProduct(sqrt_weights).sum();
// Let C = (j_w transposed(j_w)) / (transposed(j_w) j_w).
// Note that C = transposed(C), hence (I - C) = transposed(I - C).
//
// transposed(A_w) C = transposed(A_w) j_w transposed(j_w) / w =
// (transposed(A_w) j_w) transposed(j_w) / w = c_w transposed(j_w),
//
// where c_w = transposed(A_w) j_w / w is a k x 1 vector calculated here:
Eigen::Matrix3Xf twice_weighted_sources =
weighted_sources.array().rowwise() * sqrt_weights.array().transpose();
Eigen::Vector3f source_center_of_mass =
twice_weighted_sources.rowwise().sum() / total_weight;
// transposed((I - C) A_w) = transposed(A_w) (I - C) =
// transposed(A_w) - transposed(A_w) C = transposed(A_w) - c_w transposed(j_w).
Eigen::Matrix3Xf centered_weighted_sources =
weighted_sources - source_center_of_mass * sqrt_weights.transpose();
Eigen::Matrix3f rotation;
MP_RETURN_IF_ERROR(ComputeOptimalRotation(
weighted_targets * centered_weighted_sources.transpose(), rotation))
<< "Failed to compute the optimal rotation!";
ASSIGN_OR_RETURN(
float scale,
ComputeOptimalScale(centered_weighted_sources, weighted_sources,
weighted_targets, rotation),
_ << "Failed to compute the optimal scale!");
// R = c transposed(T).
Eigen::Matrix3f rotation_and_scale = scale * rotation;
// Compute optimal translation for the weighted problem.
// transposed(B_w - c A_w T) = transposed(B_w) - R transposed(A_w) in (54).
const auto pointwise_diffs =
weighted_targets - rotation_and_scale * weighted_sources;
// Multiplication by j_w is a respectively weighted column sum.
// (54) from the paper.
const auto weighted_pointwise_diffs =
pointwise_diffs.array().rowwise() * sqrt_weights.array().transpose();
Eigen::Vector3f translation =
weighted_pointwise_diffs.rowwise().sum() / total_weight;
transform_mat = CombineTransformMatrix(rotation_and_scale, translation);
return absl::OkStatus();
}
// `design_matrix` is a transposed LHS of (51) in the paper.
//
// Note: the output `rotation` argument is used instead of `StatusOr<>`
// return type in order to avoid Eigen memory alignment issues. Details:
// https://eigen.tuxfamily.org/dox/group__TopicStructHavingEigenMembers.html
static absl::Status ComputeOptimalRotation(
const Eigen::Matrix3f& design_matrix, Eigen::Matrix3f& rotation) {
RET_CHECK_GT(design_matrix.norm(), kAbsoluteErrorEps)
<< "Design matrix norm is too small!";
Eigen::JacobiSVD<Eigen::Matrix3f> svd(
design_matrix, Eigen::ComputeFullU | Eigen::ComputeFullV);
Eigen::Matrix3f postrotation = svd.matrixU();
Eigen::Matrix3f prerotation = svd.matrixV().transpose();
// Disallow reflection by ensuring that det(`rotation`) = +1 (and not -1),
// see "4.6 Constrained orthogonal Procrustes problems"
// in the Gower & Dijksterhuis's book "Procrustes Analysis".
// We flip the sign of the least singular value along with a column in W.
//
// Note that now the sum of singular values doesn't work for scale
// estimation due to this sign flip.
if (postrotation.determinant() * prerotation.determinant() <
static_cast<float>(0)) {
postrotation.col(2) *= static_cast<float>(-1);
}
// Transposed (52) from the paper.
rotation = postrotation * prerotation;
return absl::OkStatus();
}
static absl::StatusOr<float> ComputeOptimalScale(
const Eigen::Matrix3Xf& centered_weighted_sources,
const Eigen::Matrix3Xf& weighted_sources,
const Eigen::Matrix3Xf& weighted_targets,
const Eigen::Matrix3f& rotation) {
// transposed(T) transposed(A_w) (I - C).
const auto rotated_centered_weighted_sources =
rotation * centered_weighted_sources;
// Use the identity trace(A B) = sum(A * B^T)
// to avoid building large intermediate matrices (* is Hadamard product).
// (53) from the paper.
float numerator =
rotated_centered_weighted_sources.cwiseProduct(weighted_targets).sum();
float denominator =
centered_weighted_sources.cwiseProduct(weighted_sources).sum();
RET_CHECK_GT(denominator, kAbsoluteErrorEps)
<< "Scale expression denominator is too small!";
RET_CHECK_GT(numerator / denominator, kAbsoluteErrorEps)
<< "Scale is too small!";
return numerator / denominator;
}
};
} // namespace
std::unique_ptr<ProcrustesSolver> CreateFloatPrecisionProcrustesSolver() {
return absl::make_unique<FloatPrecisionProcrustesSolver>();
}
} // namespace mediapipe::tasks::vision::face_geometry


@@ -0,0 +1,70 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_TASKS_CC_VISION_FACE_GEOMETRY_LIBS_PROCRUSTES_SOLVER_H_
#define MEDIAPIPE_TASKS_CC_VISION_FACE_GEOMETRY_LIBS_PROCRUSTES_SOLVER_H_
#include <memory>
#include "Eigen/Dense"
#include "mediapipe/framework/port/status.h"
namespace mediapipe::tasks::vision::face_geometry {
// Encapsulates a stateless solver for the Weighted Extended Orthogonal
// Procrustes (WEOP) Problem, as defined in Section 2.4 of
// https://doi.org/10.3929/ethz-a-004656648.
//
// Given the source and the target point clouds, the algorithm estimates
// a 4x4 transformation matrix featuring the following semantic components:
//
// * Uniform scale
// * Rotation
// * Translation
//
// The matrix maps the source point cloud into the target point cloud minimizing
// the Mean Squared Error.
class ProcrustesSolver {
public:
virtual ~ProcrustesSolver() = default;
// Solves the Weighted Extended Orthogonal Procrustes (WEOP) Problem.
//
// All `source_points`, `target_points` and `point_weights` must define the
// same number of points. Elements of `point_weights` must be non-negative.
//
// A too small diameter of either of the point clouds will likely lead to
// numerical instabilities and failure to estimate the transformation.
//
// A too small point cloud total weight will likely lead to numerical
// instabilities and failure to estimate the transformation too.
//
// A small point coordinate deviation for either of the point clouds will
// likely result in a failure, as it makes the solution very unstable (if
// solvable at all).
//
// Note: the output `transform_mat` argument is used instead of `StatusOr<>`
// return type in order to avoid Eigen memory alignment issues. Details:
// https://eigen.tuxfamily.org/dox/group__TopicStructHavingEigenMembers.html
virtual absl::Status SolveWeightedOrthogonalProblem(
const Eigen::Matrix3Xf& source_points, //
const Eigen::Matrix3Xf& target_points, //
const Eigen::VectorXf& point_weights, //
Eigen::Matrix4f& transform_mat) const = 0;
};
std::unique_ptr<ProcrustesSolver> CreateFloatPrecisionProcrustesSolver();
} // namespace mediapipe::tasks::vision::face_geometry
#endif // MEDIAPIPE_TASKS_CC_VISION_FACE_GEOMETRY_LIBS_PROCRUSTES_SOLVER_H_
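
The solver's behaviour is easiest to see on synthetic data. The following test-style sketch (an illustration, not part of the library) feeds it a small point cloud and the same cloud uniformly scaled by 2 and shifted along Z; the recovered transform should then be approximately that scale plus that translation.

#include <cstdio>
#include <memory>

#include "Eigen/Dense"
#include "mediapipe/tasks/cc/vision/face_geometry/libs/procrustes_solver.h"

namespace fg = mediapipe::tasks::vision::face_geometry;

int main() {
  // Four non-coplanar source points (one per column).
  Eigen::Matrix3Xf source(3, 4);
  source << 0.f, 1.f, 0.f, 0.f,
            0.f, 0.f, 1.f, 0.f,
            0.f, 0.f, 0.f, 1.f;
  // Target = uniform scale by 2, then translate by (0, 0, 5).
  Eigen::Matrix3Xf target = 2.f * source;
  target.colwise() += Eigen::Vector3f(0.f, 0.f, 5.f);
  const Eigen::VectorXf weights = Eigen::VectorXf::Ones(4);

  Eigen::Matrix4f transform;
  std::unique_ptr<fg::ProcrustesSolver> solver =
      fg::CreateFloatPrecisionProcrustesSolver();
  const absl::Status status = solver->SolveWeightedOrthogonalProblem(
      source, target, weights, transform);
  if (!status.ok()) {
    std::printf("solver failed: %s\n", status.ToString().c_str());
    return 1;
  }
  // `transform` should be close to:
  //   [2 0 0 0]
  //   [0 2 0 0]
  //   [0 0 2 5]
  //   [0 0 0 1]
  std::printf("scale ~ %f, translation z ~ %f\n", transform(0, 0),
              transform(2, 3));
  return 0;
}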


@@ -0,0 +1,127 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/tasks/cc/vision/face_geometry/libs/validation_utils.h"
#include <cstdint>
#include <cstdlib>
#include "mediapipe/framework/formats/matrix_data.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/vision/face_geometry/libs/mesh_3d_utils.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/geometry_pipeline_metadata.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/mesh_3d.pb.h"
namespace mediapipe::tasks::vision::face_geometry {
absl::Status ValidatePerspectiveCamera(
const proto::PerspectiveCamera& perspective_camera) {
static constexpr float kAbsoluteErrorEps = 1e-9f;
RET_CHECK_GT(perspective_camera.near(), kAbsoluteErrorEps)
<< "Near Z must be greater than 0 with a margin of 10^{-9}!";
RET_CHECK_GT(perspective_camera.far(),
perspective_camera.near() + kAbsoluteErrorEps)
<< "Far Z must be greater than Near Z with a margin of 10^{-9}!";
RET_CHECK_GT(perspective_camera.vertical_fov_degrees(), kAbsoluteErrorEps)
<< "Vertical FOV must be positive with a margin of 10^{-9}!";
RET_CHECK_LT(perspective_camera.vertical_fov_degrees() + kAbsoluteErrorEps,
180.f)
<< "Vertical FOV must be less than 180 degrees with a margin of 10^{-9}";
return absl::OkStatus();
}
absl::Status ValidateEnvironment(const proto::Environment& environment) {
MP_RETURN_IF_ERROR(
ValidatePerspectiveCamera(environment.perspective_camera()))
<< "Invalid perspective camera!";
return absl::OkStatus();
}
absl::Status ValidateMesh3d(const proto::Mesh3d& mesh_3d) {
const std::size_t vertex_size = GetVertexSize(mesh_3d.vertex_type());
const std::size_t primitive_size = GetPrimitiveSize(mesh_3d.primitive_type());
RET_CHECK_EQ(mesh_3d.vertex_buffer_size() % vertex_size, 0)
<< "Vertex buffer size must be a multiple of the vertex size!";
RET_CHECK_EQ(mesh_3d.index_buffer_size() % primitive_size, 0)
<< "Index buffer size must be a multiple of the primitive size!";
const int num_vertices = mesh_3d.vertex_buffer_size() / vertex_size;
for (uint32_t idx : mesh_3d.index_buffer()) {
RET_CHECK_LT(idx, num_vertices)
<< "All mesh indices must refer to an existing vertex!";
}
return absl::OkStatus();
}
absl::Status ValidateFaceGeometry(const proto::FaceGeometry& face_geometry) {
MP_RETURN_IF_ERROR(ValidateMesh3d(face_geometry.mesh())) << "Invalid mesh!";
static constexpr char kInvalid4x4MatrixMessage[] =
"Pose transformation matrix must be a 4x4 matrix!";
const mediapipe::MatrixData& pose_transform_matrix =
face_geometry.pose_transform_matrix();
RET_CHECK_EQ(pose_transform_matrix.rows(), 4) << kInvalid4x4MatrixMessage;
RET_CHECK_EQ(pose_transform_matrix.cols(), 4) << kInvalid4x4MatrixMessage;
RET_CHECK_EQ(pose_transform_matrix.packed_data_size(), 16)
<< kInvalid4x4MatrixMessage;
return absl::OkStatus();
}
absl::Status ValidateGeometryPipelineMetadata(
const proto::GeometryPipelineMetadata& metadata) {
MP_RETURN_IF_ERROR(ValidateMesh3d(metadata.canonical_mesh()))
<< "Invalid canonical mesh!";
RET_CHECK_GT(metadata.procrustes_landmark_basis_size(), 0)
<< "Procrustes landmark basis must be non-empty!";
const int num_vertices =
metadata.canonical_mesh().vertex_buffer_size() /
GetVertexSize(metadata.canonical_mesh().vertex_type());
for (const proto::WeightedLandmarkRef& wlr :
metadata.procrustes_landmark_basis()) {
RET_CHECK_LT(wlr.landmark_id(), num_vertices)
<< "All Procrustes basis indices must refer to an existing canonical "
"mesh vertex!";
RET_CHECK_GE(wlr.weight(), 0.f)
<< "All Procrustes basis landmarks must have a non-negative weight!";
}
return absl::OkStatus();
}
absl::Status ValidateFrameDimensions(int frame_width, int frame_height) {
RET_CHECK_GT(frame_width, 0) << "Frame width must be positive!";
RET_CHECK_GT(frame_height, 0) << "Frame height must be positive!";
return absl::OkStatus();
}
} // namespace mediapipe::tasks::vision::face_geometry


@@ -0,0 +1,70 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_TASKS_CC_VISION_FACE_GEOMETRY_LIBS_VALIDATION_UTILS_H_
#define MEDIAPIPE_TASKS_CC_VISION_FACE_GEOMETRY_LIBS_VALIDATION_UTILS_H_
#include "mediapipe/framework/port/status.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/geometry_pipeline_metadata.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/mesh_3d.pb.h"
namespace mediapipe::tasks::vision::face_geometry {
// Validates `perspective_camera`.
//
// Near Z must be greater than 0 with a margin of `1e-9`.
// Far Z must be greater than Near Z with a margin of `1e-9`.
// Vertical FOV must be in range (0, 180) with a margin of `1e-9` on the range
// edges.
absl::Status ValidatePerspectiveCamera(
const proto::PerspectiveCamera& perspective_camera);
// Validates `environment`.
//
// Environment's perspective camera must be valid.
absl::Status ValidateEnvironment(const proto::Environment& environment);
// Validates `mesh_3d`.
//
// Mesh vertex buffer size must be a multiple of the vertex size.
// Mesh index buffer size must be a multiple of the primitive size.
// All mesh indices must reference an existing mesh vertex.
absl::Status ValidateMesh3d(const proto::Mesh3d& mesh_3d);
// Validates `face_geometry`.
//
// Face mesh must be valid.
// Face pose transformation matrix must be a 4x4 matrix.
absl::Status ValidateFaceGeometry(const proto::FaceGeometry& face_geometry);
// Validates `metadata`.
//
// Canonical face mesh must be valid.
// Procrustes landmark basis must be non-empty.
// All Procrustes basis indices must reference an existing canonical mesh
// vertex.
// All Procrustes basis landmarks must have a non-negative weight.
absl::Status ValidateGeometryPipelineMetadata(
const proto::GeometryPipelineMetadata& metadata);
// Validates frame dimensions.
//
// Both frame width and frame height must be positive.
absl::Status ValidateFrameDimensions(int frame_width, int frame_height);
} // namespace mediapipe::tasks::vision::face_geometry
#endif // MEDIAPIPE_TASKS_CC_VISION_FACE_GEOMETRY_LIBS_VALIDATION_UTILS_H_


@@ -0,0 +1,46 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
licenses(["notice"])
package(default_visibility = ["//visibility:public"])
mediapipe_proto_library(
name = "environment_proto",
srcs = ["environment.proto"],
)
mediapipe_proto_library(
name = "face_geometry_proto",
srcs = ["face_geometry.proto"],
deps = [
":mesh_3d_proto",
"//mediapipe/framework/formats:matrix_data_proto",
],
)
mediapipe_proto_library(
name = "geometry_pipeline_metadata_proto",
srcs = ["geometry_pipeline_metadata.proto"],
deps = [
":mesh_3d_proto",
],
)
mediapipe_proto_library(
name = "mesh_3d_proto",
srcs = ["mesh_3d.proto"],
)


@@ -0,0 +1,84 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe.tasks.vision.face_geometry.proto;
option java_package = "mediapipe.tasks.vision.facegeometry.proto";
option java_outer_classname = "EnvironmentProto";
// Defines the (0, 0) origin point location of the environment.
//
// The variation in the origin point location can be traced back to the memory
// layout of the camera video frame buffers.
//
// Usually, the memory layout for most CPU (and also some GPU) camera video
// frame buffers results in having the (0, 0) origin point located in the
// Top Left corner.
//
// On the contrary, the memory layout for most GPU camera video frame buffers
// results in having the (0, 0) origin point located in the Bottom Left corner.
//
// Let's consider the following example:
//
// (A) ---------------+
// ___ |
// | (1) | | |
// | / \ | | |
// | |---|===|-| |
// | |---| | | |
// | / \ | | |
// | | | | | |
// | | (2) |=| | |
// | | | | | |
// | |_______| |_| |
// | |@| |@| | | |
// | ___________|_|_ |
// |
// (B) ---------------+
//
// In this example, (1) and (2) have the same X coordinate regardless of the
// origin point location. However, having the origin point located at (A)
// (Top Left corner) results in (1) having a smaller Y coordinate if compared to
// (2). Similarly, having the origin point located at (B) (Bottom Left corner)
// results in (1) having a greater Y coordinate if compared to (2).
//
// Providing the correct origin point location for your environment and making
// sure all the input landmarks are in-sync with this location is crucial
// for receiving the correct output face geometry and visual renders.
enum OriginPointLocation {
BOTTOM_LEFT_CORNER = 1;
TOP_LEFT_CORNER = 2;
}
// The perspective camera is defined through its vertical FOV angle and the
// Z-clipping planes. The aspect ratio is a runtime variable for the face
// geometry module and should be provided alongside the face landmarks in order
// to estimate the face geometry on a given frame.
//
// More info on Perspective Cameras:
// http://www.songho.ca/opengl/gl_projectionmatrix.html#perspective
message PerspectiveCamera {
// `0 < vertical_fov_degrees < 180`.
optional float vertical_fov_degrees = 1;
// `0 < near < far`.
optional float near = 2;
optional float far = 3;
}
message Environment {
optional OriginPointLocation origin_point_location = 1;
optional PerspectiveCamera perspective_camera = 2;
}
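
As a usage illustration, a hedged sketch of populating this message follows; the MakeEnvironment helper, its boolean flag and the camera values are assumptions for the example, not prescribed defaults.

#include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h"

namespace fg_proto = mediapipe::tasks::vision::face_geometry::proto;

// Builds an Environment proto. `frame_origin_is_top_left` would typically be
// true for CPU frame buffers and false for most GPU frame buffers, per the
// note above; the camera values are placeholders.
fg_proto::Environment MakeEnvironment(bool frame_origin_is_top_left) {
  fg_proto::Environment environment;
  environment.set_origin_point_location(frame_origin_is_top_left
                                            ? fg_proto::TOP_LEFT_CORNER
                                            : fg_proto::BOTTOM_LEFT_CORNER);
  fg_proto::PerspectiveCamera* camera =
      environment.mutable_perspective_camera();
  camera->set_vertical_fov_degrees(63.f);  // 0 < fov < 180
  camera->set_near(1.f);                   // 0 < near
  camera->set_far(10000.f);                // near < far
  return environment;
}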


@@ -0,0 +1,60 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe.tasks.vision.face_geometry.proto;
import "mediapipe/framework/formats/matrix_data.proto";
import "mediapipe/tasks/cc/vision/face_geometry/proto/mesh_3d.proto";
option java_package = "mediapipe.tasks.vision.facegeometry.proto";
option java_outer_classname = "FaceGeometryProto";
// Defines the face geometry pipeline estimation result format.
message FaceGeometry {
// Defines a mesh surface for a face. The face mesh vertex IDs are the same as
// the face landmark IDs.
//
// XYZ coordinates exist in the right-handed Metric 3D space configured by an
// environment. UV coordinates are taken from the canonical face mesh model.
//
// XY coordinates are guaranteed to match the screen positions of
// the input face landmarks after (1) being multiplied by the face pose
// transformation matrix and then (2) being projected with a perspective
// camera matrix of the same environment.
//
// NOTE: the triangular topology of the face mesh is only useful when derived
// from the 468 face landmarks, not from the 6 face detection landmarks
// (keypoints). The latter don't cover the entire face and this mesh is
// defined here only to comply with the API. It should be considered as
// a placeholder and/or for debugging purposes.
//
// Use the face geometry derived from the face detection landmarks
// (keypoints) for the face pose transformation matrix, not the mesh.
optional Mesh3d mesh = 1;
// Defines a face pose transformation matrix, which provides mapping from
// the static canonical face model to the runtime face. Tries to distinguish
// a head pose change from a facial expression change and to only reflect the
// former.
//
// Is a 4x4 matrix and contains only the following components:
// * Uniform scale
// * Rotation
// * Translation
//
// The last row is guaranteed to be `[0 0 0 1]`.
optional mediapipe.MatrixData pose_transform_matrix = 2;
}
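
For consumers of this message, a hedged sketch of unpacking the pose transform into an Eigen matrix follows. It assumes column-major packed data, which matches Eigen's default storage order and is expected to match how the geometry pipeline above writes the matrix; the PoseTransformToEigen helper is made up for the example.

#include "Eigen/Core"
#include "mediapipe/framework/formats/matrix_data.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"

namespace fg_proto = mediapipe::tasks::vision::face_geometry::proto;

// Rebuilds the 4x4 pose transform from a FaceGeometry result, assuming the
// packed data is stored in column-major order.
Eigen::Matrix4f PoseTransformToEigen(const fg_proto::FaceGeometry& geometry) {
  const mediapipe::MatrixData& data = geometry.pose_transform_matrix();
  Eigen::Matrix4f mat = Eigen::Matrix4f::Identity();
  const int n = data.packed_data_size() < 16 ? data.packed_data_size() : 16;
  for (int i = 0; i < n; ++i) {
    mat.data()[i] = data.packed_data(i);  // column-major element order
  }
  return mat;
}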


@@ -0,0 +1,63 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe.tasks.vision.face_geometry.proto;
import "mediapipe/tasks/cc/vision/face_geometry/proto/mesh_3d.proto";
option java_package = "mediapipe.tasks.vision.facegeometry.proto";
option java_outer_classname = "GeometryPipelineMetadataProto";
enum InputSource {
DEFAULT = 0; // FACE_LANDMARK_PIPELINE
FACE_LANDMARK_PIPELINE = 1;
FACE_DETECTION_PIPELINE = 2;
}
message WeightedLandmarkRef {
// Defines the landmark ID. References an existing face landmark ID.
optional uint32 landmark_id = 1;
// Defines the landmark weight. The larger the weight the more influence this
// landmark has in the basis.
//
// Is positive.
optional float weight = 2;
}
// Next field ID: 4
message GeometryPipelineMetadata {
// Defines the source of the input landmarks so that the underlying geometry
// pipeline can adjust in order to produce the best results.
//
// Face landmark pipeline is expected to produce 3D landmarks with relative Z
// coordinate, which is scaled as the X coordinate assuming the weak
// perspective projection camera model.
//
// Face detection pipeline is expected to produce 2D landmarks with Z
// coordinate being equal to 0.
optional InputSource input_source = 3;
// Defines a mesh surface for a canonical face. The canonical face mesh vertex
// IDs are the same as the face landmark IDs.
//
// XYZ coordinates are defined in centimeter units.
optional Mesh3d canonical_mesh = 1;
// Defines a weighted landmark basis for running the Procrustes solver
// algorithm inside the geometry pipeline.
//
// A good basis sets face landmark weights in a way that distinguishes a head
// pose change from a facial expression change and only responds to the former.
repeated WeightedLandmarkRef procrustes_landmark_basis = 2;
}


@@ -0,0 +1,41 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe.tasks.vision.face_geometry.proto;
option java_package = "mediapipe.tasks.vision.facegeometry.proto";
option java_outer_classname = "Mesh3dProto";
message Mesh3d {
enum VertexType {
// Is defined by 5 coordinates: Position (XYZ) + Texture coordinate (UV).
VERTEX_PT = 0;
}
enum PrimitiveType {
// Is defined by 3 indices: triangle vertex IDs.
TRIANGLE = 0;
}
optional VertexType vertex_type = 1;
optional PrimitiveType primitive_type = 2;
// Vertex buffer size is a multiple of the vertex size (e.g., 5 for
// VERTEX_PT).
repeated float vertex_buffer = 3;
// Index buffer size is a multiple of the primitive size (e.g., 3 for
// TRIANGLE).
repeated uint32 index_buffer = 4;
}