diff --git a/mediapipe/tasks/cc/vision/face_landmarker/BUILD b/mediapipe/tasks/cc/vision/face_landmarker/BUILD index a5b8508aa..8dc4e0397 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/face_landmarker/BUILD @@ -69,12 +69,14 @@ cc_library( name = "face_landmarks_detector_graph", srcs = ["face_landmarks_detector_graph.cc"], deps = [ + ":face_blendshapes_graph", ":tensors_to_face_landmarks_graph", "//mediapipe/calculators/core:begin_loop_calculator", "//mediapipe/calculators/core:end_loop_calculator", "//mediapipe/calculators/core:split_vector_calculator", "//mediapipe/calculators/core:split_vector_calculator_cc_proto", "//mediapipe/calculators/image:image_properties_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", "//mediapipe/calculators/tensor:inference_calculator", "//mediapipe/calculators/tensor:tensors_to_floats_calculator", "//mediapipe/calculators/tensor:tensors_to_floats_calculator_cc_proto", @@ -89,8 +91,11 @@ cc_library( "//mediapipe/calculators/util:rect_transformation_calculator_cc_proto", "//mediapipe/calculators/util:thresholding_calculator", "//mediapipe/calculators/util:thresholding_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", @@ -101,9 +106,11 @@ cc_library( "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto", 
"//mediapipe/tasks/cc/vision/face_landmarker/proto:tensors_to_face_landmarks_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", + "//mediapipe/util:graph_builder_utils", ], alwayslink = 1, ) diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_detector_graph.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_detector_graph.cc index 61ab45fd5..a898f2fe9 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_detector_graph.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include #include @@ -26,6 +27,9 @@ limitations under the License. #include "mediapipe/calculators/util/thresholding_calculator.pb.h" #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_options.pb.h" +#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/rect.pb.h" @@ -36,9 +40,11 @@ limitations under the License. 
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/face_landmarker/proto/tensors_to_face_landmarks_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" +#include "mediapipe/util/graph_builder_utils.h" namespace mediapipe { namespace tasks { @@ -72,6 +78,7 @@ constexpr char kIterableTag[] = "ITERABLE"; constexpr char kBatchEndTag[] = "BATCH_END"; constexpr char kItemTag[] = "ITEM"; constexpr char kDetectionTag[] = "DETECTION"; +constexpr char kBlendshapesTag[] = "BLENDSHAPES"; // a landmarks tensor and a scores tensor constexpr int kFaceLandmarksOutputTensorsNum = 2; @@ -83,6 +90,7 @@ struct SingleFaceLandmarksOutputs { Stream rect_next_frame; Stream presence; Stream presence_score; + std::optional> face_blendshapes; }; struct MultiFaceLandmarksOutputs { @@ -90,6 +98,7 @@ struct MultiFaceLandmarksOutputs { Stream> rects_next_frame; Stream> presences; Stream> presence_scores; + std::optional>> face_blendshapes; }; absl::Status SanityCheckOptions( @@ -180,6 +189,62 @@ bool IsAttentionModel(const core::ModelResources& model_resources) { // Boolean value indicates whether the face is present. // PRESENCE_SCORE - float // Float value indicates the probability that the face is present. +// BLENDSHAPES - ClassificationList @optional +// Blendshape classification, available when face_blendshapes_graph_options +// is set. 
+// All 52 blendshape coefficients: +// 0 - _neutral (ignore it) +// 1 - browDownLeft +// 2 - browDownRight +// 3 - browInnerUp +// 4 - browOuterUpLeft +// 5 - browOuterUpRight +// 6 - cheekPuff +// 7 - cheekSquintLeft +// 8 - cheekSquintRight +// 9 - eyeBlinkLeft +// 10 - eyeBlinkRight +// 11 - eyeLookDownLeft +// 12 - eyeLookDownRight +// 13 - eyeLookInLeft +// 14 - eyeLookInRight +// 15 - eyeLookOutLeft +// 16 - eyeLookOutRight +// 17 - eyeLookUpLeft +// 18 - eyeLookUpRight +// 19 - eyeSquintLeft +// 20 - eyeSquintRight +// 21 - eyeWideLeft +// 22 - eyeWideRight +// 23 - jawForward +// 24 - jawLeft +// 25 - jawOpen +// 26 - jawRight +// 27 - mouthClose +// 28 - mouthDimpleLeft +// 29 - mouthDimpleRight +// 30 - mouthFrownLeft +// 31 - mouthFrownRight +// 32 - mouthFunnel +// 33 - mouthLeft +// 34 - mouthLowerDownLeft +// 35 - mouthLowerDownRight +// 36 - mouthPressLeft +// 37 - mouthPressRight +// 38 - mouthPucker +// 39 - mouthRight +// 40 - mouthRollLower +// 41 - mouthRollUpper +// 42 - mouthShrugLower +// 43 - mouthShrugUpper +// 44 - mouthSmileLeft +// 45 - mouthSmileRight +// 46 - mouthStretchLeft +// 47 - mouthStretchRight +// 48 - mouthUpperUpLeft +// 49 - mouthUpperUpRight +// 50 - noseSneerLeft +// 51 - noseSneerRight // // Example: // node { @@ -191,6 +256,7 @@ bool IsAttentionModel(const core::ModelResources& model_resources) { // output_stream: "FACE_RECT_NEXT_FRAME:face_rect_next_frame" // output_stream: "PRESENCE:presence" // output_stream: "PRESENCE_SCORE:presence_score" +// output_stream: "BLENDSHAPES:blendshapes" // options { // [mediapipe.tasks.vision.face_landmarker.proto.FaceLandmarksDetectorGraphOptions.ext] // { @@ -200,6 +266,13 @@ bool IsAttentionModel(const core::ModelResources& model_resources) { // } // } // min_detection_confidence: 0.5 +// face_blendshapes_graph_options { +// base_options { +// model_asset { +// file_name: "face_blendshape.tflite" +// } +// } +// } // } // } // } @@ -214,7 +287,7 @@ class 
SingleFaceLandmarksDetectorGraph : public core::ModelTaskGraph { ASSIGN_OR_RETURN( auto outs, BuildSingleFaceLandmarksDetectorGraph( - sc->Options(), + *sc->MutableOptions(), *model_resources, graph[Input(kImageTag)], graph[Input::Optional(kNormRectTag)], graph)); outs.landmarks >> @@ -223,6 +296,10 @@ class SingleFaceLandmarksDetectorGraph : public core::ModelTaskGraph { graph.Out(kFaceRectNextFrameTag).Cast(); outs.presence >> graph.Out(kPresenceTag).Cast(); outs.presence_score >> graph.Out(kPresenceScoreTag).Cast(); + if (outs.face_blendshapes) { + outs.face_blendshapes.value() >> + graph.Out(kBlendshapesTag).Cast(); + } return graph.GetConfig(); } @@ -239,7 +316,7 @@ class SingleFaceLandmarksDetectorGraph : public core::ModelTaskGraph { // graph: the mediapipe graph instance to be updated. absl::StatusOr BuildSingleFaceLandmarksDetectorGraph( - const proto::FaceLandmarksDetectorGraphOptions& subgraph_options, + proto::FaceLandmarksDetectorGraphOptions& subgraph_options, const core::ModelResources& model_resources, Stream image_in, Stream face_rect, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options)); @@ -351,11 +428,26 @@ class SingleFaceLandmarksDetectorGraph : public core::ModelTaskGraph { auto face_rect_next_frame = AllowIf(face_rect_transformation.Out("").Cast(), presence, graph); + + std::optional> face_blendshapes; + if (subgraph_options.has_face_blendshapes_graph_options()) { + auto& face_blendshapes_graph = graph.AddNode( + "mediapipe.tasks.vision.face_landmarker.FaceBlendshapesGraph"); + face_blendshapes_graph.GetOptions() + .Swap(subgraph_options.mutable_face_blendshapes_graph_options()); + projected_landmarks >> face_blendshapes_graph.In(kLandmarksTag); + image_size >> face_blendshapes_graph.In(kImageSizeTag); + face_blendshapes = + std::make_optional(face_blendshapes_graph.Out(kBlendshapesTag) + .Cast()); + } + return {{ /* landmarks= */ projected_landmarks, /* rect_next_frame= */ face_rect_next_frame, /* presence= */ presence, 
/* presence_score= */ presence_score, + /* face_blendshapes= */ face_blendshapes, }}; } }; @@ -390,6 +482,9 @@ REGISTER_MEDIAPIPE_GRAPH( // Vector of boolean value indicates whether the face is present. // PRESENCE_SCORE - std::vector // Vector of float value indicates the probability that the face is present. +// BLENDSHAPES - std::vector @optional +// Vector of face blendshape classification, available when +// face_blendshapes_graph_options is set. // // Example: // node { @@ -401,6 +496,7 @@ REGISTER_MEDIAPIPE_GRAPH( // output_stream: "FACE_RECTS_NEXT_FRAME:face_rects_next_frame" // output_stream: "PRESENCE:presence" // output_stream: "PRESENCE_SCORE:presence_score" +// output_stream: "BLENDSHAPES:blendshapes" // options { // [mediapipe.tasks.vision.face_landmarker.proto.FaceLandmarksDetectorGraphOptions.ext] // { @@ -410,6 +506,13 @@ REGISTER_MEDIAPIPE_GRAPH( // } // } // min_detection_confidence: 0.5 +// face_blendshapes_graph_options { +// base_options { +// model_asset { +// file_name: "face_blendshape.tflite" +// } +// } +// } // } // } // } @@ -421,7 +524,7 @@ class MultiFaceLandmarksDetectorGraph : public core::ModelTaskGraph { ASSIGN_OR_RETURN( auto outs, BuildFaceLandmarksDetectorGraph( - sc->Options(), + *sc->MutableOptions(), graph[Input(kImageTag)], graph[Input>(kNormRectTag)], graph)); outs.landmarks_lists >> graph.Out(kNormLandmarksTag) @@ -431,13 +534,16 @@ class MultiFaceLandmarksDetectorGraph : public core::ModelTaskGraph { outs.presences >> graph.Out(kPresenceTag).Cast>(); outs.presence_scores >> graph.Out(kPresenceScoreTag).Cast>(); - + if (outs.face_blendshapes) { + outs.face_blendshapes.value() >> + graph.Out(kBlendshapesTag).Cast>(); + } return graph.GetConfig(); } private: absl::StatusOr BuildFaceLandmarksDetectorGraph( - const proto::FaceLandmarksDetectorGraphOptions& subgraph_options, + proto::FaceLandmarksDetectorGraphOptions& subgraph_options, Stream image_in, Stream> multi_face_rects, Graph& graph) { auto& face_landmark_subgraph = 
graph.AddNode( @@ -445,7 +551,7 @@ class MultiFaceLandmarksDetectorGraph : public core::ModelTaskGraph { "SingleFaceLandmarksDetectorGraph"); face_landmark_subgraph .GetOptions() - .CopyFrom(subgraph_options); + .Swap(&subgraph_options); auto& begin_loop_multi_face_rects = graph.AddNode("BeginLoopNormalizedRectCalculator"); @@ -490,11 +596,27 @@ class MultiFaceLandmarksDetectorGraph : public core::ModelTaskGraph { auto face_rects_next_frame = end_loop_rects_next_frame.Out(kIterableTag) .Cast>(); + std::optional>> + face_blendshapes_vector; + if (face_landmark_subgraph + .GetOptions() + .has_face_blendshapes_graph_options()) { + auto blendshapes = face_landmark_subgraph.Out(kBlendshapesTag); + auto& end_loop_blendshapes = + graph.AddNode("EndLoopClassificationListCalculator"); + batch_end >> end_loop_blendshapes.In(kBatchEndTag); + blendshapes >> end_loop_blendshapes.In(kItemTag); + face_blendshapes_vector = + std::make_optional(end_loop_blendshapes.Out(kIterableTag) + .Cast>()); + } + return {{ /* landmarks_lists= */ landmark_lists, /* face_rects_next_frame= */ face_rects_next_frame, /* presences= */ presences, /* presence_scores= */ presence_scores, + /* face_blendshapes= */ face_blendshapes_vector, }}; } }; diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_detector_graph_test.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_detector_graph_test.cc index baebfdd3b..9853e1548 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_detector_graph_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "absl/flags/flag.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" @@ -21,6 +23,7 @@ limitations under the License. 
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/rect.pb.h" @@ -33,6 +36,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" @@ -70,6 +74,9 @@ constexpr char kPortraitExpectedFaceLandamrksName[] = "portrait_expected_face_landmarks.pbtxt"; constexpr char kPortraitExpectedFaceLandamrksWithAttentionName[] = "portrait_expected_face_landmarks_with_attention.pbtxt"; +constexpr char kFaceBlendshapesModel[] = "face_blendshapes.tflite"; +constexpr char kPortraitExpectedBlendshapesName[] = + "portrait_expected_blendshapes_with_attention.pbtxt"; constexpr char kImageTag[] = "IMAGE"; constexpr char kImageName[] = "image"; @@ -86,13 +93,17 @@ constexpr char kPresenceTag[] = "PRESENCE"; constexpr char kPresenceName[] = "presence"; constexpr char kPresenceScoreTag[] = "PRESENCE_SCORE"; constexpr char kPresenceScoreName[] = "presence_score"; +constexpr char kBlendshapesTag[] = "BLENDSHAPES"; +constexpr char kBlendshapesName[] = "blendshapes"; constexpr float kFractionDiff = 0.05; // percentage constexpr float kAbsMargin = 0.03; +constexpr float kBlendshapesDiffMargin = 0.1; // Helper function to create a Single Face Landmark TaskRunner. 
absl::StatusOr> CreateSingleFaceLandmarksTaskRunner( - absl::string_view model_name) { + absl::string_view landmarks_model_name, + std::optional blendshapes_model_name) { Graph graph; auto& face_landmark_detection = graph.AddNode( @@ -101,8 +112,17 @@ absl::StatusOr> CreateSingleFaceLandmarksTaskRunner( auto options = std::make_unique(); options->mutable_base_options()->mutable_model_asset()->set_file_name( - JoinPath("./", kTestDataDirectory, model_name)); + JoinPath("./", kTestDataDirectory, landmarks_model_name)); options->set_min_detection_confidence(0.5); + + if (blendshapes_model_name.has_value()) { + options->mutable_face_blendshapes_graph_options() + ->mutable_base_options() + ->mutable_model_asset() + ->set_file_name( + JoinPath("./", kTestDataDirectory, *blendshapes_model_name)); + } + face_landmark_detection.GetOptions() .Swap(options.get()); @@ -120,6 +140,10 @@ absl::StatusOr> CreateSingleFaceLandmarksTaskRunner( face_landmark_detection.Out(kFaceRectNextFrameTag) .SetName(kFaceRectNextFrameName) >> graph[Output(kFaceRectNextFrameTag)]; + if (blendshapes_model_name.has_value()) { + face_landmark_detection.Out(kBlendshapesTag).SetName(kBlendshapesName) >> + graph[Output(kBlendshapesTag)]; + } return TaskRunner::Create( graph.GetConfig(), absl::make_unique()); @@ -127,7 +151,8 @@ absl::StatusOr> CreateSingleFaceLandmarksTaskRunner( // Helper function to create a Multi Face Landmark TaskRunner. 
absl::StatusOr> CreateMultiFaceLandmarksTaskRunner( - absl::string_view model_name) { + absl::string_view landmarks_model_name, + std::optional blendshapes_model_name) { Graph graph; auto& face_landmark_detection = graph.AddNode( @@ -136,8 +161,15 @@ absl::StatusOr> CreateMultiFaceLandmarksTaskRunner( auto options = std::make_unique(); options->mutable_base_options()->mutable_model_asset()->set_file_name( - JoinPath("./", kTestDataDirectory, model_name)); + JoinPath("./", kTestDataDirectory, landmarks_model_name)); options->set_min_detection_confidence(0.5); + if (blendshapes_model_name.has_value()) { + options->mutable_face_blendshapes_graph_options() + ->mutable_base_options() + ->mutable_model_asset() + ->set_file_name( + JoinPath("./", kTestDataDirectory, *blendshapes_model_name)); + } face_landmark_detection.GetOptions() .Swap(options.get()); @@ -156,6 +188,10 @@ absl::StatusOr> CreateMultiFaceLandmarksTaskRunner( face_landmark_detection.Out(kFaceRectsNextFrameTag) .SetName(kFaceRectsNextFrameName) >> graph[Output>(kFaceRectsNextFrameTag)]; + if (blendshapes_model_name.has_value()) { + face_landmark_detection.Out(kBlendshapesTag).SetName(kBlendshapesName) >> + graph[Output(kBlendshapesTag)]; + } return TaskRunner::Create( graph.GetConfig(), absl::make_unique()); @@ -168,6 +204,13 @@ NormalizedLandmarkList GetExpectedLandmarkList(absl::string_view filename) { return expected_landmark_list; } +ClassificationList GetBlendshapes(absl::string_view filename) { + ClassificationList blendshapes; + MP_EXPECT_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, filename), + &blendshapes, Defaults())); + return blendshapes; +} + // Helper function to construct NormalizeRect proto. 
NormalizedRect MakeNormRect(float x_center, float y_center, float width, float height, float rotation) { @@ -185,8 +228,10 @@ NormalizedRect MakeNormRect(float x_center, float y_center, float width, struct SingeFaceTestParams { // The name of this test, for convenience when displaying test results. std::string test_name; - // The filename of the model to test. - std::string input_model_name; + // The filename of the landmarks model. + std::string landmarks_model_name; + // The filename of the blendshapes model (optional). + std::optional blendshape_model_name; // The filename of the test image. std::string test_image_name; // RoI on image to detect hands. @@ -195,15 +240,22 @@ struct SingeFaceTestParams { bool expected_presence; // The expected output landmarks positions. NormalizedLandmarkList expected_landmarks; + // The expected output blendshape classification. + std::optional expected_blendshapes; // The max value difference between expected_positions and detected positions. float landmarks_diff_threshold; + // The max value difference between expected blendshapes and actual + // blendshapes. + float blendshapes_diff_threshold; }; struct MultiFaceTestParams { // The name of this test, for convenience when displaying test results. std::string test_name; - // The filename of the model to test. - std::string input_model_name; + // The filename of the landmarks model. + std::string landmarks_model_name; + // The filename of the blendshapes model (optional). + std::optional blendshape_model_name; // The filename of the test image. std::string test_image_name; // RoI on image to detect hands. @@ -212,8 +264,13 @@ struct MultiFaceTestParams { std::vector expected_presence; // The expected output landmarks positions. std::optional> expected_landmarks_lists; + // The expected output blendshape classification. + std::optional> expected_blendshapes; // The max value difference between expected_positions and detected positions. 
float landmarks_diff_threshold; + // The max value difference between expected blendshapes and actual + // blendshapes. + float blendshapes_diff_threshold; }; class SingleFaceLandmarksDetectionTest @@ -223,8 +280,10 @@ TEST_P(SingleFaceLandmarksDetectionTest, Succeeds) { MP_ASSERT_OK_AND_ASSIGN( Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, GetParam().test_image_name))); - MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateSingleFaceLandmarksTaskRunner( - GetParam().input_model_name)); + MP_ASSERT_OK_AND_ASSIGN( + auto task_runner, + CreateSingleFaceLandmarksTaskRunner(GetParam().landmarks_model_name, + GetParam().blendshape_model_name)); auto output_packets = task_runner->Process( {{kImageName, MakePacket(std::move(image))}, @@ -246,6 +305,15 @@ TEST_P(SingleFaceLandmarksDetectionTest, Succeeds) { Approximately(Partially(EqualsProto(expected_landmarks)), /*margin=*/kAbsMargin, /*fraction=*/GetParam().landmarks_diff_threshold)); + if (GetParam().expected_blendshapes) { + const ClassificationList& actual_blendshapes = + (*output_packets)[kBlendshapesName].Get(); + const ClassificationList& expected_blendshapes = + *GetParam().expected_blendshapes; + EXPECT_THAT(actual_blendshapes, + Approximately(EqualsProto(expected_blendshapes), + GetParam().blendshapes_diff_threshold)); + } } } @@ -256,8 +324,10 @@ TEST_P(MultiFaceLandmarksDetectionTest, Succeeds) { MP_ASSERT_OK_AND_ASSIGN( Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, GetParam().test_image_name))); - MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateMultiFaceLandmarksTaskRunner( - GetParam().input_model_name)); + MP_ASSERT_OK_AND_ASSIGN( + auto task_runner, + CreateMultiFaceLandmarksTaskRunner(GetParam().landmarks_model_name, + GetParam().blendshape_model_name)); auto output_packets = task_runner->Process( {{kImageName, MakePacket(std::move(image))}, @@ -278,29 +348,63 @@ TEST_P(MultiFaceLandmarksDetectionTest, Succeeds) { /*fraction=*/GetParam().landmarks_diff_threshold), 
*GetParam().expected_landmarks_lists)); } + if (GetParam().expected_blendshapes) { + const std::vector& actual_blendshapes = + (*output_packets)[kBlendshapesName] + .Get>(); + const std::vector& expected_blendshapes = + *GetParam().expected_blendshapes; + EXPECT_THAT(actual_blendshapes, + Pointwise(Approximately(EqualsProto(), + GetParam().blendshapes_diff_threshold), + expected_blendshapes)); + } } INSTANTIATE_TEST_SUITE_P( FaceLandmarksDetectionTest, SingleFaceLandmarksDetectionTest, Values(SingeFaceTestParams{ /* test_name= */ "Portrait", - /*input_model_name= */ kFaceLandmarksDetectionModel, - /*test_image_name=*/kPortraitImageName, - /*norm_rect= */ MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0), - /*expected_presence = */ true, - /*expected_landmarks = */ + /* landmarks_model_name= */ kFaceLandmarksDetectionModel, + /* blendshape_model_name= */ std::nullopt, + /* test_image_name=*/kPortraitImageName, + /* norm_rect= */ MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0), + /* expected_presence= */ true, + /* expected_landmarks= */ GetExpectedLandmarkList(kPortraitExpectedFaceLandamrksName), - /*landmarks_diff_threshold = */ kFractionDiff}, + /* expected_blendshapes= */ std::nullopt, + /* landmarks_diff_threshold= */ kFractionDiff, + /* blendshapes_diff_threshold= */ kBlendshapesDiffMargin}, SingeFaceTestParams{ /* test_name= */ "PortraitWithAttention", - /*input_model_name= */ kFaceLandmarksDetectionWithAttentionModel, - /*test_image_name=*/kPortraitImageName, - /*norm_rect= */ MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0), - /*expected_presence = */ true, - /*expected_landmarks = */ + /* landmarks_model_name= */ + kFaceLandmarksDetectionWithAttentionModel, + /* blendshape_model_name= */ std::nullopt, + /* test_image_name= */ kPortraitImageName, + /* norm_rect= */ MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0), + /* expected_presence= */ true, + /* expected_landmarks= */ GetExpectedLandmarkList( kPortraitExpectedFaceLandamrksWithAttentionName), - 
/*landmarks_diff_threshold = */ kFractionDiff}), + /* expected_blendshapes= */ std::nullopt, + /* landmarks_diff_threshold= */ kFractionDiff, + /* blendshapes_diff_threshold= */ kBlendshapesDiffMargin}, + SingeFaceTestParams{ + /* test_name= */ "PortraitWithAttentionWithBlendshapes", + /* landmarks_model_name= */ + kFaceLandmarksDetectionWithAttentionModel, + /* blendshape_model_name= */ kFaceBlendshapesModel, + /* test_image_name= */ kPortraitImageName, + /* norm_rect= */ MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0), + /* expected_presence= */ true, + /* expected_landmarks= */ + GetExpectedLandmarkList( + kPortraitExpectedFaceLandamrksWithAttentionName), + /* expected_blendshapes= */ + GetBlendshapes(kPortraitExpectedBlendshapesName), + /* landmarks_diff_threshold= */ kFractionDiff, + /* blendshapes_diff_threshold= */ kBlendshapesDiffMargin}), + [](const TestParamInfo& info) { return info.param.test_name; }); @@ -310,31 +414,57 @@ INSTANTIATE_TEST_SUITE_P( Values( MultiFaceTestParams{ /* test_name= */ "Portrait", - /*input_model_name= */ kFaceLandmarksDetectionModel, - /*test_image_name=*/kPortraitImageName, - /*norm_rects= */ {MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0)}, - /*expected_presence = */ {true}, - /*expected_landmarks_list = */ + /* landmarks_model_name= */ kFaceLandmarksDetectionModel, + /* blendshape_model_name= */ std::nullopt, + /* test_image_name= */ kPortraitImageName, + /* norm_rects= */ {MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0)}, + /* expected_presence= */ {true}, + /* expected_landmarks_list= */ {{GetExpectedLandmarkList(kPortraitExpectedFaceLandamrksName)}}, - /*landmarks_diff_threshold = */ kFractionDiff}, + /* expected_blendshapes= */ std::nullopt, + /* landmarks_diff_threshold= */ kFractionDiff, + /* blendshapes_diff_threshold= */ kBlendshapesDiffMargin}, MultiFaceTestParams{ /* test_name= */ "PortraitWithAttention", - /*input_model_name= */ kFaceLandmarksDetectionWithAttentionModel, - /*test_image_name=*/kPortraitImageName, 
- /*norm_rects= */ {MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0)}, - /*expected_presence = */ {true}, - /*expected_landmarks_list = */ + /* landmarks_model_name= */ + kFaceLandmarksDetectionWithAttentionModel, + /* blendshape_model_name= */ std::nullopt, + /* test_image_name= */ kPortraitImageName, + /* norm_rects= */ {MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0)}, + /* expected_presence= */ {true}, + /* expected_landmarks_list= */ {{GetExpectedLandmarkList( kPortraitExpectedFaceLandamrksWithAttentionName)}}, - /*landmarks_diff_threshold = */ kFractionDiff}, + /* expected_blendshapes= */ std::nullopt, + /* landmarks_diff_threshold= */ kFractionDiff, + /* blendshapes_diff_threshold= */ kBlendshapesDiffMargin}, + MultiFaceTestParams{ + /* test_name= */ "PortraitWithAttentionWithBlendshapes", + /* landmarks_model_name= */ + kFaceLandmarksDetectionWithAttentionModel, + /* blendshape_model_name= */ kFaceBlendshapesModel, + /* test_image_name= */ kPortraitImageName, + /* norm_rects= */ {MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0)}, + /* expected_presence= */ {true}, + /* expected_landmarks_list= */ + {{GetExpectedLandmarkList( + kPortraitExpectedFaceLandamrksWithAttentionName)}}, + /* expected_blendshapes= */ + {{GetBlendshapes(kPortraitExpectedBlendshapesName)}}, + /* landmarks_diff_threshold= */ kFractionDiff, + /* blendshapes_diff_threshold= */ kBlendshapesDiffMargin}, MultiFaceTestParams{ /* test_name= */ "NoFace", - /*input_model_name= */ kFaceLandmarksDetectionModel, - /*test_image_name=*/kCatImageName, - /*norm_rects= */ {MakeNormRect(0.5, 0.5, 1.0, 1.0, 0)}, - /*expected_presence = */ {false}, - /*expected_landmarks_list = */ std::nullopt, - /*landmarks_diff_threshold = */ kFractionDiff}), + /* landmarks_model_name= */ + kFaceLandmarksDetectionModel, + /* blendshape_model_name= */ std::nullopt, + /* test_image_name= */ kCatImageName, + /* norm_rects= */ {MakeNormRect(0.5, 0.5, 1.0, 1.0, 0)}, + /* expected_presence= */ {false}, + /* 
expected_landmarks_list= */ std::nullopt, + /* expected_blendshapes= */ std::nullopt, + /*landmarks_diff_threshold= */ kFractionDiff, + /*blendshapes_diff_threshold= */ kBlendshapesDiffMargin}), [](const TestParamInfo& info) { return info.param.test_name; }); diff --git a/mediapipe/tasks/cc/vision/face_landmarker/proto/BUILD b/mediapipe/tasks/cc/vision/face_landmarker/proto/BUILD index fe8e8353e..d68207b9e 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/proto/BUILD +++ b/mediapipe/tasks/cc/vision/face_landmarker/proto/BUILD @@ -34,6 +34,7 @@ mediapipe_proto_library( name = "face_landmarks_detector_graph_options_proto", srcs = ["face_landmarks_detector_graph_options.proto"], deps = [ + ":face_blendshapes_graph_options_proto", "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", diff --git a/mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.proto b/mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.proto index 90bfd0087..c2fa49607 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.proto +++ b/mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.proto @@ -20,6 +20,7 @@ package mediapipe.tasks.vision.face_landmarker.proto; import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; +import "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.facelandmarker.proto"; option java_outer_classname = "FaceLandmarksDetectorGraphOptionsProto"; @@ -35,4 +36,8 @@ message FaceLandmarksDetectorGraphOptions { // Minimum confidence value ([0.0, 1.0]) for confidence score to be considered // successfully detecting a face in the image. 
optional float min_detection_confidence = 2 [default = 0.5]; + + // Optional options for FaceBlendshapesGraph. If this option is set, the + // FaceLandmarksDetectorGraph also outputs the face blendshapes. + optional FaceBlendshapesGraphOptions face_blendshapes_graph_options = 3; } diff --git a/mediapipe/tasks/testdata/vision/portrait_expected_blendshapes_with_attention.pbtxt b/mediapipe/tasks/testdata/vision/portrait_expected_blendshapes_with_attention.pbtxt new file mode 100644 index 000000000..777756f6c --- /dev/null +++ b/mediapipe/tasks/testdata/vision/portrait_expected_blendshapes_with_attention.pbtxt @@ -0,0 +1,262 @@ +# proto-file: mediapipe/framework/formats/classification.proto +# proto-message: ClassificationList +classification { + index: 0 + score: 4.9559007e-06 + label: "_neutral" +} +classification { + index: 1 + score: 0.22943014 + label: "browDownLeft" +} +classification { + index: 2 + score: 0.22297752 + label: "browDownRight" +} +classification { + index: 3 + score: 0.015948873 + label: "browInnerUp" +} +classification { + index: 4 + score: 0.006946607 + label: "browOuterUpLeft" +} +classification { + index: 5 + score: 0.0070318673 + label: "browOuterUpRight" +} +classification { + index: 6 + score: 0.0013679645 + label: "cheekPuff" +} +classification { + index: 7 + score: 7.1003383e-06 + label: "cheekSquintLeft" +} +classification { + index: 8 + score: 5.78299e-06 + label: "cheekSquintRight" +} +classification { + index: 9 + score: 0.20132238 + label: "eyeBlinkLeft" +} +classification { + index: 10 + score: 0.16521452 + label: "eyeBlinkRight" +} +classification { + index: 11 + score: 0.03764786 + label: "eyeLookDownLeft" +} +classification { + index: 12 + score: 0.04828824 + label: "eyeLookDownRight" +} +classification { + index: 13 + score: 0.016539993 + label: "eyeLookInLeft" +} +classification { + index: 14 + score: 0.20026363 + label: "eyeLookInRight" +} +classification { + index: 15 + score: 0.21363346 + label: "eyeLookOutLeft" +} 
+classification { + index: 16 + score: 0.024430025 + label: "eyeLookOutRight" +} +classification { + index: 17 + score: 0.30147508 + label: "eyeLookUpLeft" +} +classification { + index: 18 + score: 0.28701693 + label: "eyeLookUpRight" +} +classification { + index: 19 + score: 0.67143106 + label: "eyeSquintLeft" +} +classification { + index: 20 + score: 0.5306328 + label: "eyeSquintRight" +} +classification { + index: 21 + score: 0.0041342233 + label: "eyeWideLeft" +} +classification { + index: 22 + score: 0.005231879 + label: "eyeWideRight" +} +classification { + index: 23 + score: 0.009427094 + label: "jawForward" +} +classification { + index: 24 + score: 0.0015789346 + label: "jawLeft" +} +classification { + index: 25 + score: 0.073719256 + label: "jawOpen" +} +classification { + index: 26 + score: 0.00046979196 + label: "jawRight" +} +classification { + index: 27 + score: 0.0011400756 + label: "mouthClose" +} +classification { + index: 28 + score: 0.0060502808 + label: "mouthDimpleLeft" +} +classification { + index: 29 + score: 0.013351685 + label: "mouthDimpleRight" +} +classification { + index: 30 + score: 0.09859665 + label: "mouthFrownLeft" +} +classification { + index: 31 + score: 0.08897466 + label: "mouthFrownRight" +} +classification { + index: 32 + score: 0.0020718675 + label: "mouthFunnel" +} +classification { + index: 33 + score: 6.42887e-06 + label: "mouthLeft" +} +classification { + index: 34 + score: 0.68950605 + label: "mouthLowerDownLeft" +} +classification { + index: 35 + score: 0.7864029 + label: "mouthLowerDownRight" +} +classification { + index: 36 + score: 0.056456964 + label: "mouthPressLeft" +} +classification { + index: 37 + score: 0.037348792 + label: "mouthPressRight" +} +classification { + index: 38 + score: 0.00067001814 + label: "mouthPucker" +} +classification { + index: 39 + score: 0.005189785 + label: "mouthRight" +} +classification { + index: 40 + score: 0.018723497 + label: "mouthRollLower" +} +classification { + index: 41 + 
score: 0.052819636 + label: "mouthRollUpper" +} +classification { + index: 42 + score: 0.0033772716 + label: "mouthShrugLower" +} +classification { + index: 43 + score: 0.0031609535 + label: "mouthShrugUpper" +} +classification { + index: 44 + score: 0.49639142 + label: "mouthSmileLeft" +} +classification { + index: 45 + score: 0.4014515 + label: "mouthSmileRight" +} +classification { + index: 46 + score: 0.5825701 + label: "mouthStretchLeft" +} +classification { + index: 47 + score: 0.73058575 + label: "mouthStretchRight" +} +classification { + index: 48 + score: 0.13561466 + label: "mouthUpperUpLeft" +} +classification { + index: 49 + score: 0.20078722 + label: "mouthUpperUpRight" +} +classification { + index: 50 + score: 3.3396598e-06 + label: "noseSneerLeft" +} +classification { + index: 51 + score: 1.3096546e-05 + label: "noseSneerRight" +}