Blendshapes graph take smoothed face landmarks as input.

PiperOrigin-RevId: 527640341
This commit is contained in:
MediaPipe Team 2023-04-27 11:43:58 -07:00 committed by Copybara-Service
parent 82b8e4d7bf
commit 3ca2427cc8
5 changed files with 178 additions and 187 deletions

View File

@ -73,6 +73,8 @@ cc_library(
":tensors_to_face_landmarks_graph",
"//mediapipe/calculators/core:begin_loop_calculator",
"//mediapipe/calculators/core:end_loop_calculator",
"//mediapipe/calculators/core:get_vector_item_calculator",
"//mediapipe/calculators/core:get_vector_item_calculator_cc_proto",
"//mediapipe/calculators/core:split_vector_calculator",
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
"//mediapipe/calculators/image:image_properties_calculator",
@ -86,6 +88,8 @@ cc_library(
"//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto",
"//mediapipe/calculators/util:landmark_letterbox_removal_calculator",
"//mediapipe/calculators/util:landmark_projection_calculator",
"//mediapipe/calculators/util:landmarks_smoothing_calculator",
"//mediapipe/calculators/util:landmarks_smoothing_calculator_cc_proto",
"//mediapipe/calculators/util:landmarks_to_detection_calculator",
"//mediapipe/calculators/util:rect_transformation_calculator",
"//mediapipe/calculators/util:rect_transformation_calculator_cc_proto",
@ -194,8 +198,6 @@ cc_library(
"//mediapipe/calculators/util:association_norm_rect_calculator",
"//mediapipe/calculators/util:collection_has_min_size_calculator",
"//mediapipe/calculators/util:collection_has_min_size_calculator_cc_proto",
"//mediapipe/calculators/util:landmarks_smoothing_calculator",
"//mediapipe/calculators/util:landmarks_smoothing_calculator_cc_proto",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto",

View File

@ -26,7 +26,6 @@ limitations under the License.
#include "mediapipe/calculators/core/get_vector_item_calculator.pb.h"
#include "mediapipe/calculators/util/association_calculator.pb.h"
#include "mediapipe/calculators/util/collection_has_min_size_calculator.pb.h"
#include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/classification.pb.h"
@ -172,19 +171,6 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
return absl::OkStatus();
}
void ConfigureLandmarksSmoothingCalculator(
mediapipe::LandmarksSmoothingCalculatorOptions& options) {
// Min cutoff 0.05 results into ~0.01 alpha in landmark EMA filter when
// landmark is static.
options.mutable_one_euro_filter()->set_min_cutoff(0.05f);
// Beta 80.0 in combintation with min_cutoff 0.05 results into ~0.94
// alpha in landmark EMA filter when landmark is moving fast.
options.mutable_one_euro_filter()->set_beta(80.0f);
// Derivative cutoff 1.0 results into ~0.17 alpha in landmark velocity
// EMA filter.
options.mutable_one_euro_filter()->set_derivate_cutoff(1.0f);
}
} // namespace
// A "mediapipe.tasks.vision.face_landmarker.FaceLandmarkerGraph" performs face
@ -464,32 +450,17 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph {
auto image_size = image_properties.Out(kSizeTag);
// Apply smoothing filter only on the single face landmarks, because
// landmakrs smoothing calculator doesn't support multiple landmarks yet.
// landmarks smoothing calculator doesn't support multiple landmarks yet.
if (face_detector_options.num_faces() == 1) {
// Get the single face landmarks
auto& get_vector_item =
graph.AddNode("GetNormalizedLandmarkListVectorItemCalculator");
get_vector_item.GetOptions<mediapipe::GetVectorItemCalculatorOptions>()
.set_item_index(0);
face_landmarks >> get_vector_item.In(kVectorTag);
auto single_face_landmarks = get_vector_item.Out(kItemTag);
// Apply smoothing filter on face landmarks.
auto& landmarks_smoothing = graph.AddNode("LandmarksSmoothingCalculator");
ConfigureLandmarksSmoothingCalculator(
landmarks_smoothing
.GetOptions<mediapipe::LandmarksSmoothingCalculatorOptions>());
single_face_landmarks >> landmarks_smoothing.In(kNormLandmarksTag);
image_size >> landmarks_smoothing.In(kImageSizeTag);
auto smoothed_single_face_landmarks =
landmarks_smoothing.Out(kNormFilteredLandmarksTag);
// Wrap the single face landmarks into a vector of landmarks.
auto& concatenate_vector =
graph.AddNode("ConcatenateNormalizedLandmarkListVectorCalculator");
smoothed_single_face_landmarks >> concatenate_vector.In("");
face_landmarks = concatenate_vector.Out("")
.Cast<std::vector<NormalizedLandmarkList>>();
face_landmarks_detector_graph
.GetOptions<FaceLandmarksDetectorGraphOptions>()
.set_smooth_landmarks(true);
} else if (face_detector_options.num_faces() > 1 &&
face_landmarks_detector_graph
.GetOptions<FaceLandmarksDetectorGraphOptions>()
.smooth_landmarks()) {
return absl::InvalidArgumentError(
"Currently face landmarks smoothing only support a single face.");
}
if (tasks_options.base_options().use_stream_mode()) {
@ -533,9 +504,10 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph {
// Back edge.
face_rects_for_next_frame >> previous_loopback.In(kLoopTag);
} else {
// While not in stream mode, the input images are not guaranteed to be in
// series, and we don't want to enable the tracking and rect associations
// between input images. Always use the face detector graph.
// While not in stream mode, the input images are not guaranteed to be
// in series, and we don't want to enable the tracking and rect
// associations between input images. Always use the face detector
// graph.
image_in >> face_detector.In(kImageTag);
if (norm_rect_in) {
*norm_rect_in >> face_detector.In(kNormRectTag);
@ -571,7 +543,8 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph {
}
// TODO: Replace PassThroughCalculator with a calculator that
// converts the pixel data to be stored on the target storage (CPU vs GPU).
// converts the pixel data to be stored on the target storage (CPU vs
// GPU).
auto& pass_through = graph.AddNode("PassThroughCalculator");
image_in >> pass_through.In("");

View File

@ -19,10 +19,13 @@ limitations under the License.
#include <utility>
#include <vector>
#include "mediapipe/calculators/core/get_vector_item_calculator.h"
#include "mediapipe/calculators/core/get_vector_item_calculator.pb.h"
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_floats_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_landmarks_calculator.pb.h"
#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h"
#include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h"
#include "mediapipe/calculators/util/rect_transformation_calculator.pb.h"
#include "mediapipe/calculators/util/thresholding_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
@ -79,6 +82,9 @@ constexpr char kBatchEndTag[] = "BATCH_END";
constexpr char kItemTag[] = "ITEM";
constexpr char kDetectionTag[] = "DETECTION";
constexpr char kBlendshapesTag[] = "BLENDSHAPES";
constexpr char kNormFilteredLandmarksTag[] = "NORM_FILTERED_LANDMARKS";
constexpr char kSizeTag[] = "SIZE";
constexpr char kVectorTag[] = "VECTOR";
// a landmarks tensor and a scores tensor
constexpr int kFaceLandmarksOutputTensorsNum = 2;
@ -88,7 +94,6 @@ struct SingleFaceLandmarksOutputs {
Stream<NormalizedRect> rect_next_frame;
Stream<bool> presence;
Stream<float> presence_score;
std::optional<Stream<ClassificationList>> face_blendshapes;
};
struct MultiFaceLandmarksOutputs {
@ -148,6 +153,19 @@ void ConfigureFaceRectTransformationCalculator(
options->set_square_long(true);
}
void ConfigureLandmarksSmoothingCalculator(
mediapipe::LandmarksSmoothingCalculatorOptions& options) {
// Min cutoff 0.05 results into ~0.01 alpha in landmark EMA filter when
// landmark is static.
options.mutable_one_euro_filter()->set_min_cutoff(0.05f);
// Beta 80.0 in combintation with min_cutoff 0.05 results into ~0.94
// alpha in landmark EMA filter when landmark is moving fast.
options.mutable_one_euro_filter()->set_beta(80.0f);
// Derivative cutoff 1.0 results into ~0.17 alpha in landmark velocity
// EMA filter.
options.mutable_one_euro_filter()->set_derivate_cutoff(1.0f);
}
} // namespace
// A "mediapipe.tasks.vision.face_landmarker.SingleFaceLandmarksDetectorGraph"
@ -171,62 +189,6 @@ void ConfigureFaceRectTransformationCalculator(
// Boolean value indicates whether the face is present.
// PRESENCE_SCORE - float
// Float value indicates the probability that the face is present.
// BLENDSHAPES - ClassificationList @optional
// Blendshape classification, available when face_blendshapes_graph_options
// is set.
// All 52 blendshape coefficients:
// 0 - _neutral (ignore it)
// 1 - browDownLeft
// 2 - browDownRight
// 3 - browInnerUp
// 4 - browOuterUpLeft
// 5 - browOuterUpRight
// 6 - cheekPuff
// 7 - cheekSquintLeft
// 8 - cheekSquintRight
// 9 - eyeBlinkLeft
// 10 - eyeBlinkRight
// 11 - eyeLookDownLeft
// 12 - eyeLookDownRight
// 13 - eyeLookInLeft
// 14 - eyeLookInRight
// 15 - eyeLookOutLeft
// 16 - eyeLookOutRight
// 17 - eyeLookUpLeft
// 18 - eyeLookUpRight
// 19 - eyeSquintLeft
// 20 - eyeSquintRight
// 21 - eyeWideLeft
// 22 - eyeWideRight
// 23 - jawForward
// 24 - jawLeft
// 25 - jawOpen
// 26 - jawRight
// 27 - mouthClose
// 28 - mouthDimpleLeft
// 29 - mouthDimpleRight
// 30 - mouthFrownLeft
// 31 - mouthFrownRight
// 32 - mouthFunnel
// 33 - mouthLeft
// 34 - mouthLowerDownLeft
// 35 - mouthLowerDownRight
// 36 - mouthPressLeft
// 37 - mouthPressRight
// 38 - mouthPucker
// 39 - mouthRight
// 40 - mouthRollLower
// 41 - mouthRollUpper
// 42 - mouthShrugLower
// 43 - mouthShrugUpper
// 44 - mouthSmileLeft
// 45 - mouthSmileRight
// 46 - mouthStretchLeft
// 47 - mouthStretchRight
// 48 - mouthUpperUpLeft
// 49 - mouthUpperUpRight
// 50 - noseSneerLeft
// 51 - noseSneerRight
//
// Example:
// node {
@ -238,7 +200,6 @@ void ConfigureFaceRectTransformationCalculator(
// output_stream: "FACE_RECT_NEXT_FRAME:face_rect_next_frame"
// output_stream: "PRESENCE:presence"
// output_stream: "PRESENCE_SCORE:presence_score"
// output_stream: "BLENDSHAPES:blendshapes"
// options {
// [mediapipe.tasks.vision.face_landmarker.proto.FaceLandmarksDetectorGraphOptions.ext]
// {
@ -278,10 +239,6 @@ class SingleFaceLandmarksDetectorGraph : public core::ModelTaskGraph {
graph.Out(kFaceRectNextFrameTag).Cast<NormalizedRect>();
outs.presence >> graph.Out(kPresenceTag).Cast<bool>();
outs.presence_score >> graph.Out(kPresenceScoreTag).Cast<float>();
if (outs.face_blendshapes) {
outs.face_blendshapes.value() >>
graph.Out(kBlendshapesTag).Cast<ClassificationList>();
}
return graph.GetConfig();
}
@ -378,7 +335,7 @@ class SingleFaceLandmarksDetectorGraph : public core::ModelTaskGraph {
auto& landmark_projection = graph.AddNode("LandmarkProjectionCalculator");
landmarks_letterbox_removed >> landmark_projection.In(kNormLandmarksTag);
face_rect >> landmark_projection.In(kNormRectTag);
auto projected_landmarks = AllowIf(
Stream<NormalizedLandmarkList> projected_landmarks = AllowIf(
landmark_projection[Output<NormalizedLandmarkList>(kNormLandmarksTag)],
presence, graph);
@ -409,25 +366,11 @@ class SingleFaceLandmarksDetectorGraph : public core::ModelTaskGraph {
AllowIf(face_rect_transformation.Out("").Cast<NormalizedRect>(),
presence, graph);
std::optional<Stream<ClassificationList>> face_blendshapes;
if (subgraph_options.has_face_blendshapes_graph_options()) {
auto& face_blendshapes_graph = graph.AddNode(
"mediapipe.tasks.vision.face_landmarker.FaceBlendshapesGraph");
face_blendshapes_graph.GetOptions<proto::FaceBlendshapesGraphOptions>()
.Swap(subgraph_options.mutable_face_blendshapes_graph_options());
projected_landmarks >> face_blendshapes_graph.In(kLandmarksTag);
image_size >> face_blendshapes_graph.In(kImageSizeTag);
face_blendshapes =
std::make_optional(face_blendshapes_graph.Out(kBlendshapesTag)
.Cast<ClassificationList>());
}
return {{
/* landmarks= */ projected_landmarks,
/* rect_next_frame= */ face_rect_next_frame,
/* presence= */ presence,
/* presence_score= */ presence_score,
/* face_blendshapes= */ face_blendshapes,
}};
}
};
@ -465,6 +408,59 @@ REGISTER_MEDIAPIPE_GRAPH(
// BLENDSHAPES - std::vector<ClassificationList> @optional
// Vector of face blendshape classification, available when
// face_blendshapes_graph_options is set.
// All 52 blendshape coefficients:
// 0 - _neutral (ignore it)
// 1 - browDownLeft
// 2 - browDownRight
// 3 - browInnerUp
// 4 - browOuterUpLeft
// 5 - browOuterUpRight
// 6 - cheekPuff
// 7 - cheekSquintLeft
// 8 - cheekSquintRight
// 9 - eyeBlinkLeft
// 10 - eyeBlinkRight
// 11 - eyeLookDownLeft
// 12 - eyeLookDownRight
// 13 - eyeLookInLeft
// 14 - eyeLookInRight
// 15 - eyeLookOutLeft
// 16 - eyeLookOutRight
// 17 - eyeLookUpLeft
// 18 - eyeLookUpRight
// 19 - eyeSquintLeft
// 20 - eyeSquintRight
// 21 - eyeWideLeft
// 22 - eyeWideRight
// 23 - jawForward
// 24 - jawLeft
// 25 - jawOpen
// 26 - jawRight
// 27 - mouthClose
// 28 - mouthDimpleLeft
// 29 - mouthDimpleRight
// 30 - mouthFrownLeft
// 31 - mouthFrownRight
// 32 - mouthFunnel
// 33 - mouthLeft
// 34 - mouthLowerDownLeft
// 35 - mouthLowerDownRight
// 36 - mouthPressLeft
// 37 - mouthPressRight
// 38 - mouthPucker
// 39 - mouthRight
// 40 - mouthRollLower
// 41 - mouthRollUpper
// 42 - mouthShrugLower
// 43 - mouthShrugUpper
// 44 - mouthSmileLeft
// 45 - mouthSmileRight
// 46 - mouthStretchLeft
// 47 - mouthStretchRight
// 48 - mouthUpperUpLeft
// 49 - mouthUpperUpRight
// 50 - noseSneerLeft
// 51 - noseSneerRight
//
// Example:
// node {
@ -566,8 +562,9 @@ class MultiFaceLandmarksDetectorGraph : public core::ModelTaskGraph {
graph.AddNode("EndLoopNormalizedLandmarkListVectorCalculator");
batch_end >> end_loop_landmarks.In(kBatchEndTag);
landmarks >> end_loop_landmarks.In(kItemTag);
auto landmark_lists = end_loop_landmarks.Out(kIterableTag)
.Cast<std::vector<NormalizedLandmarkList>>();
Stream<std::vector<NormalizedLandmarkList>> landmark_lists =
end_loop_landmarks.Out(kIterableTag)
.Cast<std::vector<NormalizedLandmarkList>>();
auto& end_loop_rects_next_frame =
graph.AddNode("EndLoopNormalizedRectCalculator");
@ -576,16 +573,78 @@ class MultiFaceLandmarksDetectorGraph : public core::ModelTaskGraph {
auto face_rects_next_frame = end_loop_rects_next_frame.Out(kIterableTag)
.Cast<std::vector<NormalizedRect>>();
// Apply smoothing filter only on the single face landmarks, because
// landmarks smoothing calculator doesn't support multiple landmarks yet.
// Notice the landmarks smoothing calculator cannot be put inside the for
// loop calculator, because the smoothing calculator utilize the timestamp
// to smoote landmarks across frames but the for loop calculator makes fake
// timestamps for the streams.
if (face_landmark_subgraph
.GetOptions<proto::FaceLandmarksDetectorGraphOptions>()
.smooth_landmarks()) {
// Get the single face landmarks
auto& get_vector_item =
graph.AddNode("GetNormalizedLandmarkListVectorItemCalculator");
get_vector_item.GetOptions<mediapipe::GetVectorItemCalculatorOptions>()
.set_item_index(0);
landmark_lists >> get_vector_item.In(kVectorTag);
Stream<NormalizedLandmarkList> single_landmarks =
get_vector_item.Out(kItemTag).Cast<NormalizedLandmarkList>();
auto& image_properties = graph.AddNode("ImagePropertiesCalculator");
image_in >> image_properties.In(kImageTag);
auto image_size = image_properties.Out(kSizeTag);
// Apply smoothing filter on face landmarks.
auto& landmarks_smoothing = graph.AddNode("LandmarksSmoothingCalculator");
ConfigureLandmarksSmoothingCalculator(
landmarks_smoothing
.GetOptions<mediapipe::LandmarksSmoothingCalculatorOptions>());
single_landmarks >> landmarks_smoothing.In(kNormLandmarksTag);
image_size >> landmarks_smoothing.In(kImageSizeTag);
single_landmarks = landmarks_smoothing.Out(kNormFilteredLandmarksTag)
.Cast<NormalizedLandmarkList>();
// Wrap the single face landmarks into a vector of landmarks.
auto& concatenate_vector =
graph.AddNode("ConcatenateNormalizedLandmarkListVectorCalculator");
single_landmarks >> concatenate_vector.In("");
landmark_lists = concatenate_vector.Out("")
.Cast<std::vector<NormalizedLandmarkList>>();
}
std::optional<Stream<std::vector<ClassificationList>>>
face_blendshapes_vector;
if (face_landmark_subgraph
.GetOptions<proto::FaceLandmarksDetectorGraphOptions>()
.has_face_blendshapes_graph_options()) {
auto blendshapes = face_landmark_subgraph.Out(kBlendshapesTag);
auto& begin_loop_multi_face_landmarks =
graph.AddNode("BeginLoopNormalizedLandmarkListVectorCalculator");
landmark_lists >> begin_loop_multi_face_landmarks.In(kIterableTag);
image_in >> begin_loop_multi_face_landmarks.In(kCloneTag);
auto image = begin_loop_multi_face_landmarks.Out(kCloneTag);
auto batch_end = begin_loop_multi_face_landmarks.Out(kBatchEndTag);
auto landmarks = begin_loop_multi_face_landmarks.Out(kItemTag);
auto& image_properties = graph.AddNode("ImagePropertiesCalculator");
image >> image_properties.In(kImageTag);
auto image_size = image_properties.Out(kSizeTag);
auto& face_blendshapes_graph = graph.AddNode(
"mediapipe.tasks.vision.face_landmarker.FaceBlendshapesGraph");
face_blendshapes_graph.GetOptions<proto::FaceBlendshapesGraphOptions>()
.Swap(face_landmark_subgraph
.GetOptions<proto::FaceLandmarksDetectorGraphOptions>()
.mutable_face_blendshapes_graph_options());
landmarks >> face_blendshapes_graph.In(kLandmarksTag);
image_size >> face_blendshapes_graph.In(kImageSizeTag);
auto face_blendshapes = face_blendshapes_graph.Out(kBlendshapesTag)
.Cast<ClassificationList>();
auto& end_loop_blendshapes =
graph.AddNode("EndLoopClassificationListCalculator");
batch_end >> end_loop_blendshapes.In(kBatchEndTag);
blendshapes >> end_loop_blendshapes.In(kItemTag);
face_blendshapes >> end_loop_blendshapes.In(kItemTag);
face_blendshapes_vector =
std::make_optional(end_loop_blendshapes.Out(kIterableTag)
.Cast<std::vector<ClassificationList>>());

View File

@ -99,8 +99,7 @@ constexpr float kBlendshapesDiffMargin = 0.1;
// Helper function to create a Single Face Landmark TaskRunner.
absl::StatusOr<std::unique_ptr<TaskRunner>> CreateSingleFaceLandmarksTaskRunner(
absl::string_view landmarks_model_name,
std::optional<absl::string_view> blendshapes_model_name) {
absl::string_view landmarks_model_name) {
Graph graph;
auto& face_landmark_detection = graph.AddNode(
@ -112,14 +111,6 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateSingleFaceLandmarksTaskRunner(
JoinPath("./", kTestDataDirectory, landmarks_model_name));
options->set_min_detection_confidence(0.5);
if (blendshapes_model_name.has_value()) {
options->mutable_face_blendshapes_graph_options()
->mutable_base_options()
->mutable_model_asset()
->set_file_name(
JoinPath("./", kTestDataDirectory, *blendshapes_model_name));
}
face_landmark_detection.GetOptions<proto::FaceLandmarksDetectorGraphOptions>()
.Swap(options.get());
@ -137,11 +128,6 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateSingleFaceLandmarksTaskRunner(
face_landmark_detection.Out(kFaceRectNextFrameTag)
.SetName(kFaceRectNextFrameName) >>
graph[Output<NormalizedRect>(kFaceRectNextFrameTag)];
if (blendshapes_model_name.has_value()) {
face_landmark_detection.Out(kBlendshapesTag).SetName(kBlendshapesName) >>
graph[Output<ClassificationList>(kBlendshapesTag)];
}
return TaskRunner::Create(
graph.GetConfig(), absl::make_unique<core::MediaPipeBuiltinOpResolver>());
}
@ -227,8 +213,6 @@ struct SingeFaceTestParams {
std::string test_name;
// The filename of landmarks model name.
std::string landmarks_model_name;
// The filename of blendshape model name.
std::optional<std::string> blendshape_model_name;
// The filename of the test image.
std::string test_image_name;
// RoI on image to detect faces.
@ -237,13 +221,8 @@ struct SingeFaceTestParams {
bool expected_presence;
// The expected output landmarks positions.
NormalizedLandmarkList expected_landmarks;
// The expected output blendshape classification;
std::optional<ClassificationList> expected_blendshapes;
// The max value difference between expected_positions and detected positions.
float landmarks_diff_threshold;
// The max value difference between expected blendshapes and actual
// blendshapes.
float blendshapes_diff_threshold;
};
struct MultiFaceTestParams {
@ -279,8 +258,7 @@ TEST_P(SingleFaceLandmarksDetectionTest, Succeeds) {
GetParam().test_image_name)));
MP_ASSERT_OK_AND_ASSIGN(
auto task_runner,
CreateSingleFaceLandmarksTaskRunner(GetParam().landmarks_model_name,
GetParam().blendshape_model_name));
CreateSingleFaceLandmarksTaskRunner(GetParam().landmarks_model_name));
auto output_packets = task_runner->Process(
{{kImageName, MakePacket<Image>(std::move(image))},
@ -301,15 +279,6 @@ TEST_P(SingleFaceLandmarksDetectionTest, Succeeds) {
Approximately(Partially(EqualsProto(expected_landmarks)),
/*margin=*/kAbsMargin,
/*fraction=*/GetParam().landmarks_diff_threshold));
if (GetParam().expected_blendshapes) {
const ClassificationList& actual_blendshapes =
(*output_packets)[kBlendshapesName].Get<ClassificationList>();
const ClassificationList& expected_blendshapes =
*GetParam().expected_blendshapes;
EXPECT_THAT(actual_blendshapes,
Approximately(EqualsProto(expected_blendshapes),
GetParam().blendshapes_diff_threshold));
}
}
}
@ -360,34 +329,15 @@ TEST_P(MultiFaceLandmarksDetectionTest, Succeeds) {
INSTANTIATE_TEST_SUITE_P(
FaceLandmarksDetectionTest, SingleFaceLandmarksDetectionTest,
Values(SingeFaceTestParams{
/* test_name= */ "PortraitV2",
/* landmarks_model_name= */
kFaceLandmarksV2Model,
/* blendshape_model_name= */ std::nullopt,
/* test_image_name= */ kPortraitImageName,
/* norm_rect= */ MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0),
/* expected_presence= */ true,
/* expected_landmarks= */
GetExpectedLandmarkList(kPortraitExpectedFaceLandmarksName),
/* expected_blendshapes= */ std::nullopt,
/* landmarks_diff_threshold= */ kFractionDiff,
/* blendshapes_diff_threshold= */ kBlendshapesDiffMargin},
SingeFaceTestParams{
/* test_name= */ "PortraitV2WithBlendshapes",
/* landmarks_model_name= */
kFaceLandmarksV2Model,
/* blendshape_model_name= */ kFaceBlendshapesModel,
/* test_image_name= */ kPortraitImageName,
/* norm_rect= */
MakeNormRect(0.48906386, 0.22731927, 0.42905223, 0.34357703,
0.008304443),
/* expected_presence= */ true,
/* expected_landmarks= */
GetExpectedLandmarkList(kPortraitExpectedFaceLandmarksName),
/* expected_blendshapes= */
GetBlendshapes(kPortraitExpectedBlendshapesName),
/* landmarks_diff_threshold= */ kFractionDiff,
/* blendshapes_diff_threshold= */ kBlendshapesDiffMargin}),
/* test_name= */ "PortraitV2",
/* landmarks_model_name= */
kFaceLandmarksV2Model,
/* test_image_name= */ kPortraitImageName,
/* norm_rect= */ MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0),
/* expected_presence= */ true,
/* expected_landmarks= */
GetExpectedLandmarkList(kPortraitExpectedFaceLandmarksName),
/* landmarks_diff_threshold= */ kFractionDiff}),
[](const TestParamInfo<SingleFaceLandmarksDetectionTest::ParamType>& info) {
return info.param.test_name;

View File

@ -37,6 +37,13 @@ message FaceLandmarksDetectorGraphOptions {
// successfully detecting a face in the image.
optional float min_detection_confidence = 2 [default = 0.5];
// Whether to smooth the detected landmarks over timestamps. Note that
// landmarks smoothing is only applicable for a single face. If multiple faces
// landmarks are given, and smooth_landmarks is true, only the first face
// landmarks would be smoothed, and the remaining landmarks are discarded in
// the returned landmarks list.
optional bool smooth_landmarks = 4;
// Optional options for FaceBlendshapeGraph. If this options is set, the
// FaceLandmarksDetectorGraph would output the face blendshapes.
optional FaceBlendshapesGraphOptions face_blendshapes_graph_options = 3;