Add optional face blendshapes to face landmarks detector graph.

PiperOrigin-RevId: 513867488
This commit is contained in:
MediaPipe Team 2023-03-03 10:46:38 -08:00 committed by Copybara-Service
parent c9c1bf21ae
commit b7ec83efb5
6 changed files with 575 additions and 48 deletions

View File

@ -69,12 +69,14 @@ cc_library(
name = "face_landmarks_detector_graph",
srcs = ["face_landmarks_detector_graph.cc"],
deps = [
":face_blendshapes_graph",
":tensors_to_face_landmarks_graph",
"//mediapipe/calculators/core:begin_loop_calculator",
"//mediapipe/calculators/core:end_loop_calculator",
"//mediapipe/calculators/core:split_vector_calculator",
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
"//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
"//mediapipe/calculators/tensor:inference_calculator",
"//mediapipe/calculators/tensor:tensors_to_floats_calculator",
"//mediapipe/calculators/tensor:tensors_to_floats_calculator_cc_proto",
@ -89,8 +91,11 @@ cc_library(
"//mediapipe/calculators/util:rect_transformation_calculator_cc_proto",
"//mediapipe/calculators/util:thresholding_calculator",
"//mediapipe/calculators/util:thresholding_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_options_cc_proto",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
@ -101,9 +106,11 @@ cc_library(
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:tensors_to_face_landmarks_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"//mediapipe/util:graph_builder_utils",
],
alwayslink = 1,
)

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <memory>
#include <optional>
#include <type_traits>
#include <utility>
#include <vector>
@ -26,6 +27,9 @@ limitations under the License.
#include "mediapipe/calculators/util/thresholding_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_options.pb.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
@ -36,9 +40,11 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/tensors_to_face_landmarks_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
#include "mediapipe/util/graph_builder_utils.h"
namespace mediapipe {
namespace tasks {
@ -72,6 +78,7 @@ constexpr char kIterableTag[] = "ITERABLE";
constexpr char kBatchEndTag[] = "BATCH_END";
constexpr char kItemTag[] = "ITEM";
constexpr char kDetectionTag[] = "DETECTION";
constexpr char kBlendshapesTag[] = "BLENDSHAPES";
// a landmarks tensor and a scores tensor
constexpr int kFaceLandmarksOutputTensorsNum = 2;
@ -83,6 +90,7 @@ struct SingleFaceLandmarksOutputs {
Stream<NormalizedRect> rect_next_frame;
Stream<bool> presence;
Stream<float> presence_score;
std::optional<Stream<ClassificationList>> face_blendshapes;
};
struct MultiFaceLandmarksOutputs {
@ -90,6 +98,7 @@ struct MultiFaceLandmarksOutputs {
Stream<std::vector<NormalizedRect>> rects_next_frame;
Stream<std::vector<bool>> presences;
Stream<std::vector<float>> presence_scores;
std::optional<Stream<std::vector<ClassificationList>>> face_blendshapes;
};
absl::Status SanityCheckOptions(
@ -180,6 +189,62 @@ bool IsAttentionModel(const core::ModelResources& model_resources) {
// Boolean value indicates whether the face is present.
// PRESENCE_SCORE - float
// Float value indicates the probability that the face is present.
// BLENDSHAPES - ClassificationList @optional
// Blendshape classification, available when face_blendshapes_graph_options
// is set.
// All 52 blendshape coefficients:
// 0 - _neutral (ignore it)
// 1 - browDownLeft
// 2 - browDownRight
// 3 - browInnerUp
// 4 - browOuterUpLeft
// 5 - browOuterUpRight
// 6 - cheekPuff
// 7 - cheekSquintLeft
// 8 - cheekSquintRight
// 9 - eyeBlinkLeft
// 10 - eyeBlinkRight
// 11 - eyeLookDownLeft
// 12 - eyeLookDownRight
// 13 - eyeLookInLeft
// 14 - eyeLookInRight
// 15 - eyeLookOutLeft
// 16 - eyeLookOutRight
// 17 - eyeLookUpLeft
// 18 - eyeLookUpRight
// 19 - eyeSquintLeft
// 20 - eyeSquintRight
// 21 - eyeWideLeft
// 22 - eyeWideRight
// 23 - jawForward
// 24 - jawLeft
// 25 - jawOpen
// 26 - jawRight
// 27 - mouthClose
// 28 - mouthDimpleLeft
// 29 - mouthDimpleRight
// 30 - mouthFrownLeft
// 31 - mouthFrownRight
// 32 - mouthFunnel
// 33 - mouthLeft
// 34 - mouthLowerDownLeft
// 35 - mouthLowerDownRight
// 36 - mouthPressLeft
// 37 - mouthPressRight
// 38 - mouthPucker
// 39 - mouthRight
// 40 - mouthRollLower
// 41 - mouthRollUpper
// 42 - mouthShrugLower
// 43 - mouthShrugUpper
// 44 - mouthSmileLeft
// 45 - mouthSmileRight
// 46 - mouthStretchLeft
// 47 - mouthStretchRight
// 48 - mouthUpperUpLeft
// 49 - mouthUpperUpRight
// 50 - noseSneerLeft
// 51 - noseSneerRight
//
// Example:
// node {
@ -191,6 +256,7 @@ bool IsAttentionModel(const core::ModelResources& model_resources) {
// output_stream: "FACE_RECT_NEXT_FRAME:face_rect_next_frame"
// output_stream: "PRESENCE:presence"
// output_stream: "PRESENCE_SCORE:presence_score"
// output_stream: "BLENDSHAPES:blendshapes"
// options {
// [mediapipe.tasks.vision.face_landmarker.proto.FaceLandmarksDetectorGraphOptions.ext]
// {
@ -200,6 +266,13 @@ bool IsAttentionModel(const core::ModelResources& model_resources) {
// }
// }
// min_detection_confidence: 0.5
// face_blendshapes_graph_options {
// base_options {
// model_asset {
// file_name: "face_blendshape.tflite"
// }
// }
// }
// }
// }
// }
@ -214,7 +287,7 @@ class SingleFaceLandmarksDetectorGraph : public core::ModelTaskGraph {
ASSIGN_OR_RETURN(
auto outs,
BuildSingleFaceLandmarksDetectorGraph(
sc->Options<proto::FaceLandmarksDetectorGraphOptions>(),
*sc->MutableOptions<proto::FaceLandmarksDetectorGraphOptions>(),
*model_resources, graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
outs.landmarks >>
@ -223,6 +296,10 @@ class SingleFaceLandmarksDetectorGraph : public core::ModelTaskGraph {
graph.Out(kFaceRectNextFrameTag).Cast<NormalizedRect>();
outs.presence >> graph.Out(kPresenceTag).Cast<bool>();
outs.presence_score >> graph.Out(kPresenceScoreTag).Cast<float>();
if (outs.face_blendshapes) {
outs.face_blendshapes.value() >>
graph.Out(kBlendshapesTag).Cast<ClassificationList>();
}
return graph.GetConfig();
}
@ -239,7 +316,7 @@ class SingleFaceLandmarksDetectorGraph : public core::ModelTaskGraph {
// graph: the mediapipe graph instance to be updated.
absl::StatusOr<SingleFaceLandmarksOutputs>
BuildSingleFaceLandmarksDetectorGraph(
const proto::FaceLandmarksDetectorGraphOptions& subgraph_options,
proto::FaceLandmarksDetectorGraphOptions& subgraph_options,
const core::ModelResources& model_resources, Stream<Image> image_in,
Stream<NormalizedRect> face_rect, Graph& graph) {
MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options));
@ -351,11 +428,26 @@ class SingleFaceLandmarksDetectorGraph : public core::ModelTaskGraph {
auto face_rect_next_frame =
AllowIf(face_rect_transformation.Out("").Cast<NormalizedRect>(),
presence, graph);
std::optional<Stream<ClassificationList>> face_blendshapes;
if (subgraph_options.has_face_blendshapes_graph_options()) {
auto& face_blendshapes_graph = graph.AddNode(
"mediapipe.tasks.vision.face_landmarker.FaceBlendshapesGraph");
face_blendshapes_graph.GetOptions<proto::FaceBlendshapesGraphOptions>()
.Swap(subgraph_options.mutable_face_blendshapes_graph_options());
projected_landmarks >> face_blendshapes_graph.In(kLandmarksTag);
image_size >> face_blendshapes_graph.In(kImageSizeTag);
face_blendshapes =
std::make_optional(face_blendshapes_graph.Out(kBlendshapesTag)
.Cast<ClassificationList>());
}
return {{
/* landmarks= */ projected_landmarks,
/* rect_next_frame= */ face_rect_next_frame,
/* presence= */ presence,
/* presence_score= */ presence_score,
/* face_blendshapes= */ face_blendshapes,
}};
}
};
@ -390,6 +482,9 @@ REGISTER_MEDIAPIPE_GRAPH(
// Vector of boolean value indicates whether the face is present.
// PRESENCE_SCORE - std::vector<float>
// Vector of float value indicates the probability that the face is present.
// BLENDSHAPES - std::vector<ClassificationList> @optional
// Vector of face blendshape classification, available when
// face_blendshapes_graph_options is set.
//
// Example:
// node {
@ -401,6 +496,7 @@ REGISTER_MEDIAPIPE_GRAPH(
// output_stream: "FACE_RECTS_NEXT_FRAME:face_rects_next_frame"
// output_stream: "PRESENCE:presence"
// output_stream: "PRESENCE_SCORE:presence_score"
// output_stream: "BLENDSHAPES:blendshapes"
// options {
// [mediapipe.tasks.vision.face_landmarker.proto.FaceLandmarksDetectorGraphOptions.ext]
// {
@ -410,6 +506,13 @@ REGISTER_MEDIAPIPE_GRAPH(
// }
// }
// min_detection_confidence: 0.5
// face_blendshapes_graph_options {
// base_options {
// model_asset {
// file_name: "face_blendshape.tflite"
// }
// }
// }
// }
// }
// }
@ -421,7 +524,7 @@ class MultiFaceLandmarksDetectorGraph : public core::ModelTaskGraph {
ASSIGN_OR_RETURN(
auto outs,
BuildFaceLandmarksDetectorGraph(
sc->Options<proto::FaceLandmarksDetectorGraphOptions>(),
*sc->MutableOptions<proto::FaceLandmarksDetectorGraphOptions>(),
graph[Input<Image>(kImageTag)],
graph[Input<std::vector<NormalizedRect>>(kNormRectTag)], graph));
outs.landmarks_lists >> graph.Out(kNormLandmarksTag)
@ -431,13 +534,16 @@ class MultiFaceLandmarksDetectorGraph : public core::ModelTaskGraph {
outs.presences >> graph.Out(kPresenceTag).Cast<std::vector<bool>>();
outs.presence_scores >>
graph.Out(kPresenceScoreTag).Cast<std::vector<float>>();
if (outs.face_blendshapes) {
outs.face_blendshapes.value() >>
graph.Out(kBlendshapesTag).Cast<std::vector<ClassificationList>>();
}
return graph.GetConfig();
}
private:
absl::StatusOr<MultiFaceLandmarksOutputs> BuildFaceLandmarksDetectorGraph(
const proto::FaceLandmarksDetectorGraphOptions& subgraph_options,
proto::FaceLandmarksDetectorGraphOptions& subgraph_options,
Stream<Image> image_in,
Stream<std::vector<NormalizedRect>> multi_face_rects, Graph& graph) {
auto& face_landmark_subgraph = graph.AddNode(
@ -445,7 +551,7 @@ class MultiFaceLandmarksDetectorGraph : public core::ModelTaskGraph {
"SingleFaceLandmarksDetectorGraph");
face_landmark_subgraph
.GetOptions<proto::FaceLandmarksDetectorGraphOptions>()
.CopyFrom(subgraph_options);
.Swap(&subgraph_options);
auto& begin_loop_multi_face_rects =
graph.AddNode("BeginLoopNormalizedRectCalculator");
@ -490,11 +596,27 @@ class MultiFaceLandmarksDetectorGraph : public core::ModelTaskGraph {
auto face_rects_next_frame = end_loop_rects_next_frame.Out(kIterableTag)
.Cast<std::vector<NormalizedRect>>();
std::optional<Stream<std::vector<ClassificationList>>>
face_blendshapes_vector;
if (face_landmark_subgraph
.GetOptions<proto::FaceLandmarksDetectorGraphOptions>()
.has_face_blendshapes_graph_options()) {
auto blendshapes = face_landmark_subgraph.Out(kBlendshapesTag);
auto& end_loop_blendshapes =
graph.AddNode("EndLoopClassificationListCalculator");
batch_end >> end_loop_blendshapes.In(kBatchEndTag);
blendshapes >> end_loop_blendshapes.In(kItemTag);
face_blendshapes_vector =
std::make_optional(end_loop_blendshapes.Out(kIterableTag)
.Cast<std::vector<ClassificationList>>());
}
return {{
/* landmarks_lists= */ landmark_lists,
/* face_rects_next_frame= */ face_rects_next_frame,
/* presences= */ presences,
/* presence_scores= */ presence_scores,
/* face_blendshapes= */ face_blendshapes_vector,
}};
}
};

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <optional>
#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
@ -21,6 +23,7 @@ limitations under the License.
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
@ -33,6 +36,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
@ -70,6 +74,9 @@ constexpr char kPortraitExpectedFaceLandamrksName[] =
"portrait_expected_face_landmarks.pbtxt";
constexpr char kPortraitExpectedFaceLandamrksWithAttentionName[] =
"portrait_expected_face_landmarks_with_attention.pbtxt";
constexpr char kFaceBlendshapesModel[] = "face_blendshapes.tflite";
constexpr char kPortraitExpectedBlendshapesName[] =
"portrait_expected_blendshapes_with_attention.pbtxt";
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageName[] = "image";
@ -86,13 +93,17 @@ constexpr char kPresenceTag[] = "PRESENCE";
constexpr char kPresenceName[] = "presence";
constexpr char kPresenceScoreTag[] = "PRESENCE_SCORE";
constexpr char kPresenceScoreName[] = "presence_score";
constexpr char kBlendshapesTag[] = "BLENDSHAPES";
constexpr char kBlendshapesName[] = "blendshapes";
constexpr float kFractionDiff = 0.05; // percentage
constexpr float kAbsMargin = 0.03;
constexpr float kBlendshapesDiffMargin = 0.1;
// Helper function to create a Single Face Landmark TaskRunner.
absl::StatusOr<std::unique_ptr<TaskRunner>> CreateSingleFaceLandmarksTaskRunner(
absl::string_view model_name) {
absl::string_view landmarks_model_name,
std::optional<absl::string_view> blendshapes_model_name) {
Graph graph;
auto& face_landmark_detection = graph.AddNode(
@ -101,8 +112,17 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateSingleFaceLandmarksTaskRunner(
auto options = std::make_unique<proto::FaceLandmarksDetectorGraphOptions>();
options->mutable_base_options()->mutable_model_asset()->set_file_name(
JoinPath("./", kTestDataDirectory, model_name));
JoinPath("./", kTestDataDirectory, landmarks_model_name));
options->set_min_detection_confidence(0.5);
if (blendshapes_model_name.has_value()) {
options->mutable_face_blendshapes_graph_options()
->mutable_base_options()
->mutable_model_asset()
->set_file_name(
JoinPath("./", kTestDataDirectory, *blendshapes_model_name));
}
face_landmark_detection.GetOptions<proto::FaceLandmarksDetectorGraphOptions>()
.Swap(options.get());
@ -120,6 +140,10 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateSingleFaceLandmarksTaskRunner(
face_landmark_detection.Out(kFaceRectNextFrameTag)
.SetName(kFaceRectNextFrameName) >>
graph[Output<NormalizedRect>(kFaceRectNextFrameTag)];
if (blendshapes_model_name.has_value()) {
face_landmark_detection.Out(kBlendshapesTag).SetName(kBlendshapesName) >>
graph[Output<ClassificationList>(kBlendshapesTag)];
}
return TaskRunner::Create(
graph.GetConfig(), absl::make_unique<core::MediaPipeBuiltinOpResolver>());
@ -127,7 +151,8 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateSingleFaceLandmarksTaskRunner(
// Helper function to create a Multi Face Landmark TaskRunner.
absl::StatusOr<std::unique_ptr<TaskRunner>> CreateMultiFaceLandmarksTaskRunner(
absl::string_view model_name) {
absl::string_view landmarks_model_name,
std::optional<absl::string_view> blendshapes_model_name) {
Graph graph;
auto& face_landmark_detection = graph.AddNode(
@ -136,8 +161,15 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateMultiFaceLandmarksTaskRunner(
auto options = std::make_unique<proto::FaceLandmarksDetectorGraphOptions>();
options->mutable_base_options()->mutable_model_asset()->set_file_name(
JoinPath("./", kTestDataDirectory, model_name));
JoinPath("./", kTestDataDirectory, landmarks_model_name));
options->set_min_detection_confidence(0.5);
if (blendshapes_model_name.has_value()) {
options->mutable_face_blendshapes_graph_options()
->mutable_base_options()
->mutable_model_asset()
->set_file_name(
JoinPath("./", kTestDataDirectory, *blendshapes_model_name));
}
face_landmark_detection.GetOptions<proto::FaceLandmarksDetectorGraphOptions>()
.Swap(options.get());
@ -156,6 +188,10 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateMultiFaceLandmarksTaskRunner(
face_landmark_detection.Out(kFaceRectsNextFrameTag)
.SetName(kFaceRectsNextFrameName) >>
graph[Output<std::vector<NormalizedRect>>(kFaceRectsNextFrameTag)];
if (blendshapes_model_name.has_value()) {
face_landmark_detection.Out(kBlendshapesTag).SetName(kBlendshapesName) >>
graph[Output<ClassificationList>(kBlendshapesTag)];
}
return TaskRunner::Create(
graph.GetConfig(), absl::make_unique<core::MediaPipeBuiltinOpResolver>());
@ -168,6 +204,13 @@ NormalizedLandmarkList GetExpectedLandmarkList(absl::string_view filename) {
return expected_landmark_list;
}
ClassificationList GetBlendshapes(absl::string_view filename) {
ClassificationList blendshapes;
MP_EXPECT_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, filename),
&blendshapes, Defaults()));
return blendshapes;
}
// Helper function to construct NormalizeRect proto.
NormalizedRect MakeNormRect(float x_center, float y_center, float width,
float height, float rotation) {
@ -185,8 +228,10 @@ NormalizedRect MakeNormRect(float x_center, float y_center, float width,
struct SingeFaceTestParams {
// The name of this test, for convenience when displaying test results.
std::string test_name;
// The filename of the model to test.
std::string input_model_name;
// The filename of landmarks model name.
std::string landmarks_model_name;
// The filename of blendshape model name.
std::optional<std::string> blendshape_model_name;
// The filename of the test image.
std::string test_image_name;
// RoI on image to detect hands.
@ -195,15 +240,22 @@ struct SingeFaceTestParams {
bool expected_presence;
// The expected output landmarks positions.
NormalizedLandmarkList expected_landmarks;
// The expected output blendshape classification;
std::optional<ClassificationList> expected_blendshapes;
// The max value difference between expected_positions and detected positions.
float landmarks_diff_threshold;
// The max value difference between expected blendshapes and actual
// blendshapes.
float blendshapes_diff_threshold;
};
struct MultiFaceTestParams {
// The name of this test, for convenience when displaying test results.
std::string test_name;
// The filename of the model to test.
std::string input_model_name;
// The filename of landmarks model name.
std::string landmarks_model_name;
// The filename of blendshape model name.
std::optional<std::string> blendshape_model_name;
// The filename of the test image.
std::string test_image_name;
// RoI on image to detect hands.
@ -212,8 +264,13 @@ struct MultiFaceTestParams {
std::vector<bool> expected_presence;
// The expected output landmarks positions.
std::optional<std::vector<NormalizedLandmarkList>> expected_landmarks_lists;
// The expected output blendshape classification;
std::optional<std::vector<ClassificationList>> expected_blendshapes;
// The max value difference between expected_positions and detected positions.
float landmarks_diff_threshold;
// The max value difference between expected blendshapes and actual
// blendshapes.
float blendshapes_diff_threshold;
};
class SingleFaceLandmarksDetectionTest
@ -223,8 +280,10 @@ TEST_P(SingleFaceLandmarksDetectionTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
GetParam().test_image_name)));
MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateSingleFaceLandmarksTaskRunner(
GetParam().input_model_name));
MP_ASSERT_OK_AND_ASSIGN(
auto task_runner,
CreateSingleFaceLandmarksTaskRunner(GetParam().landmarks_model_name,
GetParam().blendshape_model_name));
auto output_packets = task_runner->Process(
{{kImageName, MakePacket<Image>(std::move(image))},
@ -246,6 +305,15 @@ TEST_P(SingleFaceLandmarksDetectionTest, Succeeds) {
Approximately(Partially(EqualsProto(expected_landmarks)),
/*margin=*/kAbsMargin,
/*fraction=*/GetParam().landmarks_diff_threshold));
if (GetParam().expected_blendshapes) {
const ClassificationList& actual_blendshapes =
(*output_packets)[kBlendshapesName].Get<ClassificationList>();
const ClassificationList& expected_blendshapes =
*GetParam().expected_blendshapes;
EXPECT_THAT(actual_blendshapes,
Approximately(EqualsProto(expected_blendshapes),
GetParam().blendshapes_diff_threshold));
}
}
}
@ -256,8 +324,10 @@ TEST_P(MultiFaceLandmarksDetectionTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
GetParam().test_image_name)));
MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateMultiFaceLandmarksTaskRunner(
GetParam().input_model_name));
MP_ASSERT_OK_AND_ASSIGN(
auto task_runner,
CreateMultiFaceLandmarksTaskRunner(GetParam().landmarks_model_name,
GetParam().blendshape_model_name));
auto output_packets = task_runner->Process(
{{kImageName, MakePacket<Image>(std::move(image))},
@ -278,29 +348,63 @@ TEST_P(MultiFaceLandmarksDetectionTest, Succeeds) {
/*fraction=*/GetParam().landmarks_diff_threshold),
*GetParam().expected_landmarks_lists));
}
if (GetParam().expected_blendshapes) {
const std::vector<ClassificationList>& actual_blendshapes =
(*output_packets)[kBlendshapesName]
.Get<std::vector<ClassificationList>>();
const std::vector<ClassificationList>& expected_blendshapes =
*GetParam().expected_blendshapes;
EXPECT_THAT(actual_blendshapes,
Pointwise(Approximately(EqualsProto(),
GetParam().blendshapes_diff_threshold),
expected_blendshapes));
}
}
INSTANTIATE_TEST_SUITE_P(
FaceLandmarksDetectionTest, SingleFaceLandmarksDetectionTest,
Values(SingeFaceTestParams{
/* test_name= */ "Portrait",
/*input_model_name= */ kFaceLandmarksDetectionModel,
/*test_image_name=*/kPortraitImageName,
/*norm_rect= */ MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0),
/*expected_presence = */ true,
/*expected_landmarks = */
/* landmarks_model_name= */ kFaceLandmarksDetectionModel,
/* blendshape_model_name= */ std::nullopt,
/* test_image_name=*/kPortraitImageName,
/* norm_rect= */ MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0),
/* expected_presence= */ true,
/* expected_landmarks= */
GetExpectedLandmarkList(kPortraitExpectedFaceLandamrksName),
/*landmarks_diff_threshold = */ kFractionDiff},
/* expected_blendshapes= */ std::nullopt,
/* landmarks_diff_threshold= */ kFractionDiff,
/* blendshapes_diff_threshold= */ kBlendshapesDiffMargin},
SingeFaceTestParams{
/* test_name= */ "PortraitWithAttention",
/*input_model_name= */ kFaceLandmarksDetectionWithAttentionModel,
/*test_image_name=*/kPortraitImageName,
/*norm_rect= */ MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0),
/*expected_presence = */ true,
/*expected_landmarks = */
/* landmarks_model_name= */
kFaceLandmarksDetectionWithAttentionModel,
/* blendshape_model_name= */ std::nullopt,
/* test_image_name= */ kPortraitImageName,
/* norm_rect= */ MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0),
/* expected_presence= */ true,
/* expected_landmarks= */
GetExpectedLandmarkList(
kPortraitExpectedFaceLandamrksWithAttentionName),
/*landmarks_diff_threshold = */ kFractionDiff}),
/* expected_blendshapes= */ std::nullopt,
/* landmarks_diff_threshold= */ kFractionDiff,
/* blendshapes_diff_threshold= */ kBlendshapesDiffMargin},
SingeFaceTestParams{
/* test_name= */ "PortraitWithAttentionWithBlendshapes",
/* landmarks_model_name= */
kFaceLandmarksDetectionWithAttentionModel,
/* blendshape_model_name= */ kFaceBlendshapesModel,
/* test_image_name= */ kPortraitImageName,
/* norm_rect= */ MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0),
/* expected_presence= */ true,
/* expected_landmarks= */
GetExpectedLandmarkList(
kPortraitExpectedFaceLandamrksWithAttentionName),
/* expected_blendshapes= */
GetBlendshapes(kPortraitExpectedBlendshapesName),
/* landmarks_diff_threshold= */ kFractionDiff,
/* blendshapes_diff_threshold= */ kBlendshapesDiffMargin}),
[](const TestParamInfo<SingleFaceLandmarksDetectionTest::ParamType>& info) {
return info.param.test_name;
});
@ -310,31 +414,57 @@ INSTANTIATE_TEST_SUITE_P(
Values(
MultiFaceTestParams{
/* test_name= */ "Portrait",
/*input_model_name= */ kFaceLandmarksDetectionModel,
/*test_image_name=*/kPortraitImageName,
/*norm_rects= */ {MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0)},
/*expected_presence = */ {true},
/*expected_landmarks_list = */
/* landmarks_model_name= */ kFaceLandmarksDetectionModel,
/* blendshape_model_name= */ std::nullopt,
/* test_image_name= */ kPortraitImageName,
/* norm_rects= */ {MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0)},
/* expected_presence= */ {true},
/* expected_landmarks_list= */
{{GetExpectedLandmarkList(kPortraitExpectedFaceLandamrksName)}},
/*landmarks_diff_threshold = */ kFractionDiff},
/* expected_blendshapes= */ std::nullopt,
/* landmarks_diff_threshold= */ kFractionDiff,
/* blendshapes_diff_threshold= */ kBlendshapesDiffMargin},
MultiFaceTestParams{
/* test_name= */ "PortraitWithAttention",
/*input_model_name= */ kFaceLandmarksDetectionWithAttentionModel,
/*test_image_name=*/kPortraitImageName,
/*norm_rects= */ {MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0)},
/*expected_presence = */ {true},
/*expected_landmarks_list = */
/* landmarks_model_name= */
kFaceLandmarksDetectionWithAttentionModel,
/* blendshape_model_name= */ std::nullopt,
/* test_image_name= */ kPortraitImageName,
/* norm_rects= */ {MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0)},
/* expected_presence= */ {true},
/* expected_landmarks_list= */
{{GetExpectedLandmarkList(
kPortraitExpectedFaceLandamrksWithAttentionName)}},
/*landmarks_diff_threshold = */ kFractionDiff},
/* expected_blendshapes= */ std::nullopt,
/* landmarks_diff_threshold= */ kFractionDiff,
/* blendshapes_diff_threshold= */ kBlendshapesDiffMargin},
MultiFaceTestParams{
/* test_name= */ "PortraitWithAttentionWithBlendshapes",
/* landmarks_model_name= */
kFaceLandmarksDetectionWithAttentionModel,
/* blendshape_model_name= */ kFaceBlendshapesModel,
/* test_image_name= */ kPortraitImageName,
/* norm_rects= */ {MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0)},
/* expected_presence= */ {true},
/* expected_landmarks_list= */
{{GetExpectedLandmarkList(
kPortraitExpectedFaceLandamrksWithAttentionName)}},
/* expected_blendshapes= */
{{GetBlendshapes(kPortraitExpectedBlendshapesName)}},
/* landmarks_diff_threshold= */ kFractionDiff,
/* blendshapes_diff_threshold= */ kBlendshapesDiffMargin},
MultiFaceTestParams{
/* test_name= */ "NoFace",
/*input_model_name= */ kFaceLandmarksDetectionModel,
/*test_image_name=*/kCatImageName,
/*norm_rects= */ {MakeNormRect(0.5, 0.5, 1.0, 1.0, 0)},
/*expected_presence = */ {false},
/*expected_landmarks_list = */ std::nullopt,
/*landmarks_diff_threshold = */ kFractionDiff}),
/* landmarks_model_name= */
kFaceLandmarksDetectionModel,
/* blendshape_model_name= */ std::nullopt,
/* test_image_name= */ kCatImageName,
/* norm_rects= */ {MakeNormRect(0.5, 0.5, 1.0, 1.0, 0)},
/* expected_presence= */ {false},
/* expected_landmarks_list= */ std::nullopt,
/* expected_blendshapes= */ std::nullopt,
/*landmarks_diff_threshold= */ kFractionDiff,
/*blendshapes_diff_threshold= */ kBlendshapesDiffMargin}),
[](const TestParamInfo<MultiFaceLandmarksDetectionTest::ParamType>& info) {
return info.param.test_name;
});

View File

@ -34,6 +34,7 @@ mediapipe_proto_library(
name = "face_landmarks_detector_graph_options_proto",
srcs = ["face_landmarks_detector_graph_options.proto"],
deps = [
":face_blendshapes_graph_options_proto",
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",

View File

@ -20,6 +20,7 @@ package mediapipe.tasks.vision.face_landmarker.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
import "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.facelandmarker.proto";
option java_outer_classname = "FaceLandmarksDetectorGraphOptionsProto";
@ -35,4 +36,8 @@ message FaceLandmarksDetectorGraphOptions {
// Minimum confidence value ([0.0, 1.0]) for confidence score to be considered
// successfully detecting a face in the image.
optional float min_detection_confidence = 2 [default = 0.5];
// Optional options for FaceBlendshapeGraph. If this options is set, the
// FaceLandmarksDetectorGraph would output the face blendshapes.
optional FaceBlendshapesGraphOptions face_blendshapes_graph_options = 3;
}

View File

@ -0,0 +1,262 @@
# proto-file: mediapipe/framework/formats/classification.proto
# proto-message: ClassificationList
classification {
index: 0
score: 4.9559007e-06
label: "_neutral"
}
classification {
index: 1
score: 0.22943014
label: "browDownLeft"
}
classification {
index: 2
score: 0.22297752
label: "browDownRight"
}
classification {
index: 3
score: 0.015948873
label: "browInnerUp"
}
classification {
index: 4
score: 0.006946607
label: "browOuterUpLeft"
}
classification {
index: 5
score: 0.0070318673
label: "browOuterUpRight"
}
classification {
index: 6
score: 0.0013679645
label: "cheekPuff"
}
classification {
index: 7
score: 7.1003383e-06
label: "cheekSquintLeft"
}
classification {
index: 8
score: 5.78299e-06
label: "cheekSquintRight"
}
classification {
index: 9
score: 0.20132238
label: "eyeBlinkLeft"
}
classification {
index: 10
score: 0.16521452
label: "eyeBlinkRight"
}
classification {
index: 11
score: 0.03764786
label: "eyeLookDownLeft"
}
classification {
index: 12
score: 0.04828824
label: "eyeLookDownRight"
}
classification {
index: 13
score: 0.016539993
label: "eyeLookInLeft"
}
classification {
index: 14
score: 0.20026363
label: "eyeLookInRight"
}
classification {
index: 15
score: 0.21363346
label: "eyeLookOutLeft"
}
classification {
index: 16
score: 0.024430025
label: "eyeLookOutRight"
}
classification {
index: 17
score: 0.30147508
label: "eyeLookUpLeft"
}
classification {
index: 18
score: 0.28701693
label: "eyeLookUpRight"
}
classification {
index: 19
score: 0.67143106
label: "eyeSquintLeft"
}
classification {
index: 20
score: 0.5306328
label: "eyeSquintRight"
}
classification {
index: 21
score: 0.0041342233
label: "eyeWideLeft"
}
classification {
index: 22
score: 0.005231879
label: "eyeWideRight"
}
classification {
index: 23
score: 0.009427094
label: "jawForward"
}
classification {
index: 24
score: 0.0015789346
label: "jawLeft"
}
classification {
index: 25
score: 0.073719256
label: "jawOpen"
}
classification {
index: 26
score: 0.00046979196
label: "jawRight"
}
classification {
index: 27
score: 0.0011400756
label: "mouthClose"
}
classification {
index: 28
score: 0.0060502808
label: "mouthDimpleLeft"
}
classification {
index: 29
score: 0.013351685
label: "mouthDimpleRight"
}
classification {
index: 30
score: 0.09859665
label: "mouthFrownLeft"
}
classification {
index: 31
score: 0.08897466
label: "mouthFrownRight"
}
classification {
index: 32
score: 0.0020718675
label: "mouthFunnel"
}
classification {
index: 33
score: 6.42887e-06
label: "mouthLeft"
}
classification {
index: 34
score: 0.68950605
label: "mouthLowerDownLeft"
}
classification {
index: 35
score: 0.7864029
label: "mouthLowerDownRight"
}
classification {
index: 36
score: 0.056456964
label: "mouthPressLeft"
}
classification {
index: 37
score: 0.037348792
label: "mouthPressRight"
}
classification {
index: 38
score: 0.00067001814
label: "mouthPucker"
}
classification {
index: 39
score: 0.005189785
label: "mouthRight"
}
classification {
index: 40
score: 0.018723497
label: "mouthRollLower"
}
classification {
index: 41
score: 0.052819636
label: "mouthRollUpper"
}
classification {
index: 42
score: 0.0033772716
label: "mouthShrugLower"
}
classification {
index: 43
score: 0.0031609535
label: "mouthShrugUpper"
}
classification {
index: 44
score: 0.49639142
label: "mouthSmileLeft"
}
classification {
index: 45
score: 0.4014515
label: "mouthSmileRight"
}
classification {
index: 46
score: 0.5825701
label: "mouthStretchLeft"
}
classification {
index: 47
score: 0.73058575
label: "mouthStretchRight"
}
classification {
index: 48
score: 0.13561466
label: "mouthUpperUpLeft"
}
classification {
index: 49
score: 0.20078722
label: "mouthUpperUpRight"
}
classification {
index: 50
score: 3.3396598e-06
label: "noseSneerLeft"
}
classification {
index: 51
score: 1.3096546e-05
label: "noseSneerRight"
}