GeometryPipelineCalculator support single face landmarks input.

PiperOrigin-RevId: 516701488
This commit is contained in:
MediaPipe Team 2023-03-14 20:00:59 -07:00 committed by Copybara-Service
parent 141cf843ae
commit cafff14135
2 changed files with 106 additions and 30 deletions

View File

@ -60,6 +60,7 @@ cc_library(
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor", "//mediapipe/framework/port:statusor",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/core:external_file_handler", "//mediapipe/tasks/cc/core:external_file_handler",
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto", "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
"//mediapipe/tasks/cc/vision/face_geometry/libs:geometry_pipeline", "//mediapipe/tasks/cc/vision/face_geometry/libs:geometry_pipeline",
@ -69,6 +70,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/face_geometry/proto:geometry_pipeline_metadata_cc_proto", "//mediapipe/tasks/cc/vision/face_geometry/proto:geometry_pipeline_metadata_cc_proto",
"//mediapipe/util:resource_util", "//mediapipe/util:resource_util",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings:str_format",
], ],
alwayslink = 1, alwayslink = 1,
) )

View File

@ -18,12 +18,14 @@
#include <vector> #include <vector>
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/statusor.h" #include "mediapipe/framework/port/statusor.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/external_file_handler.h" #include "mediapipe/tasks/cc/core/external_file_handler.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h" #include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h"
@ -41,13 +43,50 @@ static constexpr char kEnvironmentTag[] = "ENVIRONMENT";
static constexpr char kImageSizeTag[] = "IMAGE_SIZE"; static constexpr char kImageSizeTag[] = "IMAGE_SIZE";
static constexpr char kMultiFaceGeometryTag[] = "MULTI_FACE_GEOMETRY"; static constexpr char kMultiFaceGeometryTag[] = "MULTI_FACE_GEOMETRY";
static constexpr char kMultiFaceLandmarksTag[] = "MULTI_FACE_LANDMARKS"; static constexpr char kMultiFaceLandmarksTag[] = "MULTI_FACE_LANDMARKS";
static constexpr char kFaceGeometryTag[] = "FACE_GEOMETRY";
static constexpr char kFaceLandmarksTag[] = "FACE_LANDMARKS";
using ::mediapipe::tasks::vision::face_geometry::proto::Environment; using ::mediapipe::tasks::vision::face_geometry::proto::Environment;
using ::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry; using ::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry;
using ::mediapipe::tasks::vision::face_geometry::proto:: using ::mediapipe::tasks::vision::face_geometry::proto::
GeometryPipelineMetadata; GeometryPipelineMetadata;
// A calculator that renders a visual effect for multiple faces. absl::Status SanityCheck(CalculatorContract* cc) {
if (!(cc->Inputs().HasTag(kFaceLandmarksTag) ^
cc->Inputs().HasTag(kMultiFaceLandmarksTag))) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat("Only one of %s and %s can be set at a time.",
kFaceLandmarksTag, kMultiFaceLandmarksTag));
}
if (!(cc->Outputs().HasTag(kFaceGeometryTag) ^
cc->Outputs().HasTag(kMultiFaceGeometryTag))) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat("Only one of %s and %s can be set at a time.",
kFaceGeometryTag, kMultiFaceGeometryTag));
}
if (cc->Inputs().HasTag(kFaceLandmarksTag) !=
cc->Outputs().HasTag(kFaceGeometryTag)) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat(
"%s and %s must both be set or neither be set and a time.",
kFaceLandmarksTag, kFaceGeometryTag));
}
if (cc->Inputs().HasTag(kMultiFaceLandmarksTag) !=
cc->Outputs().HasTag(kMultiFaceGeometryTag)) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat(
"%s and %s must both be set or neither be set and a time.",
kMultiFaceLandmarksTag, kMultiFaceGeometryTag));
}
return absl::OkStatus();
}
// A calculator that renders a visual effect for multiple faces. Supports
// either a single face landmark list or multiple face landmark lists.
// //
// Inputs: // Inputs:
// IMAGE_SIZE (`std::pair<int, int>`, required): // IMAGE_SIZE (`std::pair<int, int>`, required):
@ -58,8 +97,12 @@ using ::mediapipe::tasks::vision::face_geometry::proto::
// ratio. If used as-is, the resulting face geometry visualization should be // ratio. If used as-is, the resulting face geometry visualization should be
// happening on a frame with the same ratio as well. // happening on a frame with the same ratio as well.
// //
// MULTI_FACE_LANDMARKS (`std::vector<NormalizedLandmarkList>`, required): // MULTI_FACE_LANDMARKS (`std::vector<NormalizedLandmarkList>`, optional):
// A vector of face landmark lists. // A vector of face landmark lists. If connected, the output stream
// MULTI_FACE_GEOMETRY must be connected.
// FACE_LANDMARKS (`NormalizedLandmarkList`, optional):
// The landmark list of a single face. If connected, the output stream
// FACE_GEOMETRY must be connected.
// //
// Input side packets: // Input side packets:
// ENVIRONMENT (`proto::Environment`, required) // ENVIRONMENT (`proto::Environment`, required)
@ -67,8 +110,10 @@ using ::mediapipe::tasks::vision::face_geometry::proto::
// as well as virtual camera parameters. // as well as virtual camera parameters.
// //
// Output: // Output:
// MULTI_FACE_GEOMETRY (`std::vector<FaceGeometry>`, required): // MULTI_FACE_GEOMETRY (`std::vector<FaceGeometry>`, optional):
// A vector of face geometry data. // A vector of face geometry data if MULTI_FACE_LANDMARKS is connected.
// FACE_GEOMETRY (FaceGeometry, optional):
// A FaceGeometry of the face landmarks if FACE_LANDMARKS is connected.
// //
// Options: // Options:
// metadata_file (`ExternalFile`, optional): // metadata_file (`ExternalFile`, optional):
@ -81,13 +126,21 @@ class GeometryPipelineCalculator : public CalculatorBase {
public: public:
// Declares the packet types of every stream and side packet the calculator
// uses. The ENVIRONMENT side packet and the IMAGE_SIZE input are always
// declared; the landmarks input / geometry output pair depends on whether the
// graph connects MULTI_FACE_LANDMARKS (multi-face mode) or not (single-face
// mode). Returns the SanityCheck() error if the wiring is inconsistent.
static absl::Status GetContract(CalculatorContract* cc) {
  cc->InputSidePackets().Tag(kEnvironmentTag).Set<Environment>();

  // Bail out early unless exactly one landmarks input is paired with its
  // matching geometry output.
  MP_RETURN_IF_ERROR(SanityCheck(cc));

  cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>();

  const bool single_face_mode = !cc->Inputs().HasTag(kMultiFaceLandmarksTag);
  if (single_face_mode) {
    // Single-face mode: one landmark list in, one FaceGeometry out.
    cc->Inputs()
        .Tag(kFaceLandmarksTag)
        .Set<mediapipe::NormalizedLandmarkList>();
    cc->Outputs().Tag(kFaceGeometryTag).Set<FaceGeometry>();
    return absl::OkStatus();
  }

  // Multi-face mode: a vector of landmark lists in, a vector of geometries
  // out.
  cc->Inputs()
      .Tag(kMultiFaceLandmarksTag)
      .Set<std::vector<mediapipe::NormalizedLandmarkList>>();
  cc->Outputs().Tag(kMultiFaceGeometryTag).Set<std::vector<FaceGeometry>>();
  return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) override { absl::Status Open(CalculatorContext* cc) override {
@ -112,7 +165,6 @@ class GeometryPipelineCalculator : public CalculatorBase {
ASSIGN_OR_RETURN(geometry_pipeline_, ASSIGN_OR_RETURN(geometry_pipeline_,
CreateGeometryPipeline(environment, metadata), CreateGeometryPipeline(environment, metadata),
_ << "Failed to create a geometry pipeline!"); _ << "Failed to create a geometry pipeline!");
return absl::OkStatus(); return absl::OkStatus();
} }
@ -121,12 +173,15 @@ class GeometryPipelineCalculator : public CalculatorBase {
// to have a non-empty packet. In case this requirement is not met, there's // to have a non-empty packet. In case this requirement is not met, there's
// nothing to be processed at the current timestamp. // nothing to be processed at the current timestamp.
// Only query a landmarks stream that is actually connected: SanityCheck()
// guarantees exactly one of MULTI_FACE_LANDMARKS / FACE_LANDMARKS exists, so
// each Tag() lookup is gated on HasTag().
// NOTE(review): Tag() on an unconnected tag is expected to trip the
// collection's fatal error handler - confirm against the framework.
if (cc->Inputs().Tag(kImageSizeTag).IsEmpty() ||
    (cc->Inputs().HasTag(kMultiFaceLandmarksTag) &&
     cc->Inputs().Tag(kMultiFaceLandmarksTag).IsEmpty()) ||
    (cc->Inputs().HasTag(kFaceLandmarksTag) &&
     cc->Inputs().Tag(kFaceLandmarksTag).IsEmpty())) {
  return absl::OkStatus();
}
const auto& image_size = const auto& image_size =
cc->Inputs().Tag(kImageSizeTag).Get<std::pair<int, int>>(); cc->Inputs().Tag(kImageSizeTag).Get<std::pair<int, int>>();
if (cc->Inputs().HasTag(kMultiFaceLandmarksTag)) {
const auto& multi_face_landmarks = const auto& multi_face_landmarks =
cc->Inputs() cc->Inputs()
.Tag(kMultiFaceLandmarksTag) .Tag(kMultiFaceLandmarksTag)
@ -147,6 +202,25 @@ class GeometryPipelineCalculator : public CalculatorBase {
.AddPacket(mediapipe::Adopt<std::vector<FaceGeometry>>( .AddPacket(mediapipe::Adopt<std::vector<FaceGeometry>>(
multi_face_geometry.release()) multi_face_geometry.release())
.At(cc->InputTimestamp())); .At(cc->InputTimestamp()));
} else {
  // Single-face mode: the landmark list arrives on FACE_LANDMARKS, which is
  // the only landmarks stream connected when MULTI_FACE_LANDMARKS is absent.
  const auto& face_landmarks =
      cc->Inputs()
          .Tag(kFaceLandmarksTag)
          .Get<mediapipe::NormalizedLandmarkList>();

  // The pipeline API takes a list of faces, so the single face is wrapped in
  // a one-element list.
  ASSIGN_OR_RETURN(
      std::vector<FaceGeometry> face_geometries,
      geometry_pipeline_->EstimateFaceGeometry(
          {face_landmarks},  //
          /*frame_width*/ image_size.first,
          /*frame_height*/ image_size.second),
      _ << "Failed to estimate face geometry for a single face!");

  // Never index into an empty result; when no geometry was produced there is
  // nothing to emit at this timestamp.
  // NOTE(review): presumably the pipeline may drop a face it cannot process
  // and return an empty vector - confirm against EstimateFaceGeometry.
  if (!face_geometries.empty()) {
    cc->Outputs()
        .Tag(kFaceGeometryTag)
        .AddPacket(
            mediapipe::MakePacket<FaceGeometry>(face_geometries.front())
                .At(cc->InputTimestamp()));
  }
}
return absl::OkStatus(); return absl::OkStatus();
} }