face landmarks detector graph

PiperOrigin-RevId: 509630430
This commit is contained in:
MediaPipe Team 2023-02-14 13:58:46 -08:00 committed by Copybara-Service
parent d6fd2c52a7
commit 5f2261ff59
8 changed files with 2828 additions and 2 deletions

View File

@ -0,0 +1,60 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = [
"//mediapipe/tasks:internal",
])
licenses(["notice"])
cc_library(
name = "face_landmarks_detector_graph",
srcs = ["face_landmarks_detector_graph.cc"],
deps = [
"//mediapipe/calculators/core:begin_loop_calculator",
"//mediapipe/calculators/core:end_loop_calculator",
"//mediapipe/calculators/core:split_vector_calculator",
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
"//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/calculators/tensor:inference_calculator",
"//mediapipe/calculators/tensor:tensors_to_floats_calculator",
"//mediapipe/calculators/tensor:tensors_to_floats_calculator_cc_proto",
"//mediapipe/calculators/tensor:tensors_to_landmarks_calculator",
"//mediapipe/calculators/tensor:tensors_to_landmarks_calculator_cc_proto",
"//mediapipe/calculators/util:detections_to_rects_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto",
"//mediapipe/calculators/util:landmark_letterbox_removal_calculator",
"//mediapipe/calculators/util:landmark_projection_calculator",
"//mediapipe/calculators/util:landmarks_to_detection_calculator",
"//mediapipe/calculators/util:rect_transformation_calculator",
"//mediapipe/calculators/util:rect_transformation_calculator_cc_proto",
"//mediapipe/calculators/util:thresholding_calculator",
"//mediapipe/calculators/util:thresholding_calculator_cc_proto",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
"//mediapipe/tasks/cc/components/utils:gate",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
],
alwayslink = 1,
)

View File

@ -0,0 +1,489 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_floats_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_landmarks_calculator.pb.h"
#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h"
#include "mediapipe/calculators/util/rect_transformation_calculator.pb.h"
#include "mediapipe/calculators/util/thresholding_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/components/utils/gate.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace face_landmarker {
namespace {
using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::components::utils::AllowIf;
constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kFaceRectNextFrameTag[] = "FACE_RECT_NEXT_FRAME";
constexpr char kFaceRectsNextFrameTag[] = "FACE_RECTS_NEXT_FRAME";
constexpr char kPresenceTag[] = "PRESENCE";
constexpr char kPresenceScoreTag[] = "PRESENCE_SCORE";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS";
constexpr char kFloatTag[] = "FLOAT";
constexpr char kFlagTag[] = "FLAG";
constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING";
constexpr char kCloneTag[] = "CLONE";
constexpr char kIterableTag[] = "ITERABLE";
constexpr char kBatchEndTag[] = "BATCH_END";
constexpr char kItemTag[] = "ITEM";
constexpr char kDetectionTag[] = "DETECTION";
constexpr int kLandmarksNum = 468;
constexpr int kModelOutputTensorSplitNum = 2;
struct SingleFaceLandmarksOutputs {
Stream<NormalizedLandmarkList> landmarks;
Stream<NormalizedRect> rect_next_frame;
Stream<bool> presence;
Stream<float> presence_score;
};
struct MultiFaceLandmarksOutputs {
Stream<std::vector<NormalizedLandmarkList>> landmarks_lists;
Stream<std::vector<NormalizedRect>> rects_next_frame;
Stream<std::vector<bool>> presences;
Stream<std::vector<float>> presence_scores;
};
absl::Status SanityCheckOptions(
const proto::FaceLandmarksDetectorGraphOptions& options) {
if (options.min_detection_confidence() < 0 ||
options.min_detection_confidence() > 1) {
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
"Invalid `min_detection_confidence` option: "
"value must be in the range [0.0, 1.0]",
MediaPipeTasksStatus::kInvalidArgumentError);
}
return absl::OkStatus();
}
// Split face landmark detection model output tensor into two parts,
// representing landmarks and face presence scores.
void ConfigureSplitTensorVectorCalculator(
mediapipe::SplitVectorCalculatorOptions* options) {
for (int i = 0; i < kModelOutputTensorSplitNum; ++i) {
auto* range = options->add_ranges();
range->set_begin(i);
range->set_end(i + 1);
}
}
void ConfigureTensorsToLandmarksCalculator(
const ImageTensorSpecs& input_image_tensor_spec,
mediapipe::TensorsToLandmarksCalculatorOptions* options) {
options->set_num_landmarks(kLandmarksNum);
options->set_input_image_height(input_image_tensor_spec.image_height);
options->set_input_image_width(input_image_tensor_spec.image_width);
}
void ConfigureFaceDetectionsToRectsCalculator(
mediapipe::DetectionsToRectsCalculatorOptions* options) {
// Left side of left eye.
options->set_rotation_vector_start_keypoint_index(33);
// Right side of right eye.
options->set_rotation_vector_end_keypoint_index(263);
options->set_rotation_vector_target_angle_degrees(0);
}
void ConfigureFaceRectTransformationCalculator(
mediapipe::RectTransformationCalculatorOptions* options) {
// TODO: make rect transformation configurable, e.g. from
// Metadata or configuration options.
options->set_scale_x(1.5f);
options->set_scale_y(1.5f);
options->set_square_long(true);
}
} // namespace
// A "mediapipe.tasks.vision.face_landmarker.SingleFaceLandmarksDetectorGraph"
// performs face landmarks detection.
//
// Inputs:
// IMAGE - Image
// Image to perform detection on.
// NORM_RECT - NormalizedRect @Optional
// Rect enclosing the RoI to perform detection on. If not set, the detection
// RoI is the whole image.
//
//
// Outputs:
// NORM_LANDMARKS: - NormalizedLandmarkList
// Detected face landmarks.
// FACE_RECT_NEXT_FRAME - NormalizedRect
// The predicted Rect enclosing the face RoI for landmark detection on the
// next frame.
// PRESENCE - bool
// Boolean value indicates whether the face is present.
// PRESENCE_SCORE - float
// Float value indicates the probability that the face is present.
//
// Example:
// node {
// calculator:
// "mediapipe.tasks.vision.face_landmarker.SingleFaceLandmarksDetectorGraph"
// input_stream: "IMAGE:input_image"
// input_stream: "FACE_RECT:face_rect"
// output_stream: "LANDMARKS:face_landmarks"
// output_stream: "FACE_RECT_NEXT_FRAME:face_rect_next_frame"
// output_stream: "PRESENCE:presence"
// output_stream: "PRESENCE_SCORE:presence_score"
// options {
// [mediapipe.tasks.vision.face_landmarker.proto.FaceLandmarksDetectorGraphOptions.ext]
// {
// base_options {
// model_asset {
// file_name: "face_landmark_lite.tflite"
// }
// }
// min_detection_confidence: 0.5
// }
// }
// }
class SingleFaceLandmarksDetectorGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
ASSIGN_OR_RETURN(
const auto* model_resources,
CreateModelResources<proto::FaceLandmarksDetectorGraphOptions>(sc));
Graph graph;
ASSIGN_OR_RETURN(
auto outs,
BuildSingleFaceLandmarksDetectorGraph(
sc->Options<proto::FaceLandmarksDetectorGraphOptions>(),
*model_resources, graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
outs.landmarks >>
graph.Out(kNormLandmarksTag).Cast<NormalizedLandmarkList>();
outs.rect_next_frame >>
graph.Out(kFaceRectNextFrameTag).Cast<NormalizedRect>();
outs.presence >> graph.Out(kPresenceTag).Cast<bool>();
outs.presence_score >> graph.Out(kPresenceScoreTag).Cast<float>();
return graph.GetConfig();
}
private:
// Adds a mediapipe face landmark detection graph into the provided
// builder::Graph instance.
//
// subgraph_options: the mediapipe tasks module
// FaceLandmarksDetectorGraphOptions.
// model_resources: the ModelSources object initialized from a face landmark
// detection model file with model metadata.
// image_in: (mediapipe::Image) stream to run face landmark detection on.
// face_rect: (NormalizedRect) stream to run on the RoI of image.
// graph: the mediapipe graph instance to be updated.
absl::StatusOr<SingleFaceLandmarksOutputs>
BuildSingleFaceLandmarksDetectorGraph(
const proto::FaceLandmarksDetectorGraphOptions& subgraph_options,
const core::ModelResources& model_resources, Stream<Image> image_in,
Stream<NormalizedRect> face_rect, Graph& graph) {
MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options));
auto& preprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.ImagePreprocessingGraph");
bool use_gpu =
components::processors::DetermineImagePreprocessingGpuBackend(
subgraph_options.base_options().acceleration());
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
model_resources, use_gpu,
&preprocessing.GetOptions<tasks::components::processors::proto::
ImagePreprocessingGraphOptions>()));
image_in >> preprocessing.In(kImageTag);
face_rect >> preprocessing.In(kNormRectTag);
auto image_size = preprocessing.Out(kImageSizeTag);
auto letterbox_padding = preprocessing.Out(kLetterboxPaddingTag);
auto input_tensors = preprocessing.Out(kTensorsTag);
auto& inference = AddInference(
model_resources, subgraph_options.base_options().acceleration(), graph);
input_tensors >> inference.In(kTensorsTag);
auto output_tensors = inference.Out(kTensorsTag);
// Split model output tensors to multiple streams.
auto& split_tensors_vector = graph.AddNode("SplitTensorVectorCalculator");
ConfigureSplitTensorVectorCalculator(
&split_tensors_vector
.GetOptions<mediapipe::SplitVectorCalculatorOptions>());
output_tensors >> split_tensors_vector.In("");
auto landmark_tensors = split_tensors_vector.Out(0);
auto presence_flag_tensors = split_tensors_vector.Out(1);
// Decodes the landmark tensors into a list of landmarks, where the landmark
// coordinates are normalized by the size of the input image to the model.
auto& tensors_to_landmarks = graph.AddNode("TensorsToLandmarksCalculator");
ASSIGN_OR_RETURN(auto image_tensor_specs,
vision::BuildInputImageTensorSpecs(model_resources));
ConfigureTensorsToLandmarksCalculator(
image_tensor_specs,
&tensors_to_landmarks
.GetOptions<mediapipe::TensorsToLandmarksCalculatorOptions>());
landmark_tensors >> tensors_to_landmarks.In(kTensorsTag);
auto landmarks = tensors_to_landmarks.Out(kNormLandmarksTag);
// Converts the presence flag tensor into a float that represents the
// confidence score of face presence.
auto& tensors_to_presence = graph.AddNode("TensorsToFloatsCalculator");
tensors_to_presence
.GetOptions<mediapipe::TensorsToFloatsCalculatorOptions>()
.set_activation(mediapipe::TensorsToFloatsCalculatorOptions::SIGMOID);
presence_flag_tensors >> tensors_to_presence.In(kTensorsTag);
auto presence_score = tensors_to_presence.Out(kFloatTag).Cast<float>();
// Applies a threshold to the confidence score to determine whether a
// face is present.
auto& presence_thresholding = graph.AddNode("ThresholdingCalculator");
presence_thresholding.GetOptions<mediapipe::ThresholdingCalculatorOptions>()
.set_threshold(subgraph_options.min_detection_confidence());
presence_score >> presence_thresholding.In(kFloatTag);
auto presence = presence_thresholding.Out(kFlagTag).Cast<bool>();
// Adjusts landmarks (already normalized to [0.f, 1.f]) on the letterboxed
// face image (after image transformation with the FIT scale mode) to the
// corresponding locations on the same image with the letterbox removed
// (face image before image transformation).
auto& landmark_letterbox_removal =
graph.AddNode("LandmarkLetterboxRemovalCalculator");
letterbox_padding >> landmark_letterbox_removal.In(kLetterboxPaddingTag);
landmarks >> landmark_letterbox_removal.In(kLandmarksTag);
auto landmarks_letterbox_removed =
landmark_letterbox_removal.Out(kLandmarksTag);
// Projects the landmarks from the cropped face image to the corresponding
// locations on the full image before cropping (input to the graph).
auto& landmark_projection = graph.AddNode("LandmarkProjectionCalculator");
landmarks_letterbox_removed >> landmark_projection.In(kNormLandmarksTag);
face_rect >> landmark_projection.In(kNormRectTag);
auto projected_landmarks = AllowIf(
landmark_projection[Output<NormalizedLandmarkList>(kNormLandmarksTag)],
presence, graph);
// Converts the face landmarks into a rectangle (normalized by image size)
// that encloses the face.
auto& landmarks_to_detection =
graph.AddNode("LandmarksToDetectionCalculator");
projected_landmarks >> landmarks_to_detection.In(kNormLandmarksTag);
auto face_landmarks_detection = landmarks_to_detection.Out(kDetectionTag);
auto& detection_to_rect = graph.AddNode("DetectionsToRectsCalculator");
ConfigureFaceDetectionsToRectsCalculator(
&detection_to_rect
.GetOptions<mediapipe::DetectionsToRectsCalculatorOptions>());
face_landmarks_detection >> detection_to_rect.In(kDetectionTag);
image_size >> detection_to_rect.In(kImageSizeTag);
auto face_landmarks_rect = detection_to_rect.Out(kNormRectTag);
// Expands the face rectangle so that in the next video frame it's likely to
// still contain the face even with some motion.
auto& face_rect_transformation =
graph.AddNode("RectTransformationCalculator");
ConfigureFaceRectTransformationCalculator(
&face_rect_transformation
.GetOptions<mediapipe::RectTransformationCalculatorOptions>());
image_size >> face_rect_transformation.In(kImageSizeTag);
face_landmarks_rect >> face_rect_transformation.In(kNormRectTag);
auto face_rect_next_frame =
AllowIf(face_rect_transformation.Out("").Cast<NormalizedRect>(),
presence, graph);
return {{
/* landmarks= */ projected_landmarks,
/* rect_next_frame= */ face_rect_next_frame,
/* presence= */ presence,
/* presence_score= */ presence_score,
}};
}
};
// clang-format off
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::face_landmarker::SingleFaceLandmarksDetectorGraph); // NOLINT
// clang-format on
// A "mediapipe.tasks.vision.face_landmarker.MultiFaceLandmarksDetectorGraph"
// performs multi face landmark detection.
// - Accepts an input image and a vector of face rect RoIs to detect the
// multiple face landmarks enclosed by the RoIs. Output vectors of
// face landmarks related results, where each element in the vectors
// corrresponds to the result of the same face.
//
// Inputs:
// IMAGE - Image
// Image to perform detection on.
// NORM_RECT - std::vector<NormalizedRect>
// A vector of multiple norm rects enclosing the face RoI to perform
// landmarks detection on.
//
//
// Outputs:
// LANDMARKS: - std::vector<NormalizedLandmarkList>
// Vector of detected face landmarks.
// FACE_RECTS_NEXT_FRAME - std::vector<NormalizedRect>
// Vector of the predicted rects enclosing the same face RoI for landmark
// detection on the next frame.
// PRESENCE - std::vector<bool>
// Vector of boolean value indicates whether the face is present.
// PRESENCE_SCORE - std::vector<float>
// Vector of float value indicates the probability that the face is present.
//
// Example:
// node {
// calculator:
// "mediapipe.tasks.vision.face_landmarker.MultiFaceLandmarksDetectorGraph"
// input_stream: "IMAGE:input_image"
// input_stream: "NORM_RECT:norm_rect"
// output_stream: "LANDMARKS:landmarks"
// output_stream: "FACE_RECTS_NEXT_FRAME:face_rects_next_frame"
// output_stream: "PRESENCE:presence"
// output_stream: "PRESENCE_SCORE:presence_score"
// options {
// [mediapipe.tasks.vision.face_landmarker.proto.FaceLandmarksDetectorGraphOptions.ext]
// {
// base_options {
// model_asset {
// file_name: "face_landmark_lite.tflite"
// }
// }
// min_detection_confidence: 0.5
// }
// }
// }
class MultiFaceLandmarksDetectorGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
Graph graph;
ASSIGN_OR_RETURN(
auto outs,
BuildFaceLandmarksDetectorGraph(
sc->Options<proto::FaceLandmarksDetectorGraphOptions>(),
graph[Input<Image>(kImageTag)],
graph[Input<std::vector<NormalizedRect>>(kNormRectTag)], graph));
outs.landmarks_lists >> graph.Out(kNormLandmarksTag)
.Cast<std::vector<NormalizedLandmarkList>>();
outs.rects_next_frame >>
graph.Out(kFaceRectsNextFrameTag).Cast<std::vector<NormalizedRect>>();
outs.presences >> graph.Out(kPresenceTag).Cast<std::vector<bool>>();
outs.presence_scores >>
graph.Out(kPresenceScoreTag).Cast<std::vector<float>>();
return graph.GetConfig();
}
private:
absl::StatusOr<MultiFaceLandmarksOutputs> BuildFaceLandmarksDetectorGraph(
const proto::FaceLandmarksDetectorGraphOptions& subgraph_options,
Stream<Image> image_in,
Stream<std::vector<NormalizedRect>> multi_face_rects, Graph& graph) {
auto& face_landmark_subgraph = graph.AddNode(
"mediapipe.tasks.vision.face_landmarker."
"SingleFaceLandmarksDetectorGraph");
face_landmark_subgraph
.GetOptions<proto::FaceLandmarksDetectorGraphOptions>()
.CopyFrom(subgraph_options);
auto& begin_loop_multi_face_rects =
graph.AddNode("BeginLoopNormalizedRectCalculator");
image_in >> begin_loop_multi_face_rects.In(kCloneTag);
multi_face_rects >> begin_loop_multi_face_rects.In(kIterableTag);
auto batch_end = begin_loop_multi_face_rects.Out(kBatchEndTag);
auto image = begin_loop_multi_face_rects.Out(kCloneTag);
auto face_rect = begin_loop_multi_face_rects.Out(kItemTag);
image >> face_landmark_subgraph.In(kImageTag);
face_rect >> face_landmark_subgraph.In(kNormRectTag);
auto presence = face_landmark_subgraph.Out(kPresenceTag);
auto presence_score = face_landmark_subgraph.Out(kPresenceScoreTag);
auto face_rect_next_frame =
face_landmark_subgraph.Out(kFaceRectNextFrameTag);
auto landmarks = face_landmark_subgraph.Out(kNormLandmarksTag);
auto& end_loop_presence = graph.AddNode("EndLoopBooleanCalculator");
batch_end >> end_loop_presence.In(kBatchEndTag);
presence >> end_loop_presence.In(kItemTag);
auto presences =
end_loop_presence.Out(kIterableTag).Cast<std::vector<bool>>();
auto& end_loop_presence_score = graph.AddNode("EndLoopFloatCalculator");
batch_end >> end_loop_presence_score.In(kBatchEndTag);
presence_score >> end_loop_presence_score.In(kItemTag);
auto presence_scores =
end_loop_presence_score.Out(kIterableTag).Cast<std::vector<float>>();
auto& end_loop_landmarks =
graph.AddNode("EndLoopNormalizedLandmarkListVectorCalculator");
batch_end >> end_loop_landmarks.In(kBatchEndTag);
landmarks >> end_loop_landmarks.In(kItemTag);
auto landmark_lists = end_loop_landmarks.Out(kIterableTag)
.Cast<std::vector<NormalizedLandmarkList>>();
auto& end_loop_rects_next_frame =
graph.AddNode("EndLoopNormalizedRectCalculator");
batch_end >> end_loop_rects_next_frame.In(kBatchEndTag);
face_rect_next_frame >> end_loop_rects_next_frame.In(kItemTag);
auto face_rects_next_frame = end_loop_rects_next_frame.Out(kIterableTag)
.Cast<std::vector<NormalizedRect>>();
return {{
/* landmarks_lists= */ landmark_lists,
/* face_rects_next_frame= */ face_rects_next_frame,
/* presences= */ presences,
/* presence_scores= */ presence_scores,
}};
}
};
// clang-format off
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::face_landmarker::MultiFaceLandmarksDetectorGraph);
// NOLINT
// clang-format on
} // namespace face_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,324 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace face_landmarker {
namespace {
using ::file::Defaults;
using ::file::GetTextProto;
using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::core::TaskRunner;
using ::mediapipe::tasks::vision::DecodeImageFromFile;
using ::testing::ElementsAreArray;
using ::testing::EqualsProto;
using ::testing::Pointwise;
using ::testing::TestParamInfo;
using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kFaceLandmarksDetectionModel[] = "face_landmark.tflite";
constexpr char kPortraitImageName[] = "portrait.jpg";
constexpr char kCatImageName[] = "cat.jpg";
constexpr char kPortraitExpectedFaceLandamrksName[] =
"portrait_expected_face_landmarks.pbtxt";
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageName[] = "image";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kNormRectName[] = "norm_rect";
constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS";
constexpr char kNormLandmarksName[] = "norm_landmarks";
constexpr char kFaceRectNextFrameTag[] = "FACE_RECT_NEXT_FRAME";
constexpr char kFaceRectNextFrameName[] = "face_rect_next_frame";
constexpr char kFaceRectsNextFrameTag[] = "FACE_RECTS_NEXT_FRAME";
constexpr char kFaceRectsNextFrameName[] = "face_rects_next_frame";
constexpr char kPresenceTag[] = "PRESENCE";
constexpr char kPresenceName[] = "presence";
constexpr char kPresenceScoreTag[] = "PRESENCE_SCORE";
constexpr char kPresenceScoreName[] = "presence_score";
constexpr float kFractionDiff = 0.05; // percentage
constexpr float kAbsMargin = 0.03;
// Helper function to create a Single Face Landmark TaskRunner.
absl::StatusOr<std::unique_ptr<TaskRunner>> CreateSingleFaceLandmarksTaskRunner(
absl::string_view model_name) {
Graph graph;
auto& face_landmark_detection = graph.AddNode(
"mediapipe.tasks.vision.face_landmarker."
"SingleFaceLandmarksDetectorGraph");
auto options = std::make_unique<proto::FaceLandmarksDetectorGraphOptions>();
options->mutable_base_options()->mutable_model_asset()->set_file_name(
JoinPath("./", kTestDataDirectory, model_name));
options->set_min_detection_confidence(0.5);
face_landmark_detection.GetOptions<proto::FaceLandmarksDetectorGraphOptions>()
.Swap(options.get());
graph[Input<Image>(kImageTag)].SetName(kImageName) >>
face_landmark_detection.In(kImageTag);
graph[Input<NormalizedRect>(kNormRectTag)].SetName(kNormRectName) >>
face_landmark_detection.In(kNormRectTag);
face_landmark_detection.Out(kNormLandmarksTag).SetName(kNormLandmarksName) >>
graph[Output<NormalizedLandmarkList>(kNormLandmarksTag)];
face_landmark_detection.Out(kPresenceTag).SetName(kPresenceName) >>
graph[Output<bool>(kPresenceTag)];
face_landmark_detection.Out(kPresenceScoreTag).SetName(kPresenceScoreName) >>
graph[Output<float>(kPresenceScoreTag)];
face_landmark_detection.Out(kFaceRectNextFrameTag)
.SetName(kFaceRectNextFrameName) >>
graph[Output<NormalizedRect>(kFaceRectNextFrameTag)];
return TaskRunner::Create(
graph.GetConfig(),
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
}
// Helper function to create a Multi Face Landmark TaskRunner.
absl::StatusOr<std::unique_ptr<TaskRunner>> CreateMultiFaceLandmarksTaskRunner(
absl::string_view model_name) {
Graph graph;
auto& face_landmark_detection = graph.AddNode(
"mediapipe.tasks.vision.face_landmarker."
"MultiFaceLandmarksDetectorGraph");
auto options = std::make_unique<proto::FaceLandmarksDetectorGraphOptions>();
options->mutable_base_options()->mutable_model_asset()->set_file_name(
JoinPath("./", kTestDataDirectory, model_name));
options->set_min_detection_confidence(0.5);
face_landmark_detection.GetOptions<proto::FaceLandmarksDetectorGraphOptions>()
.Swap(options.get());
graph[Input<Image>(kImageTag)].SetName(kImageName) >>
face_landmark_detection.In(kImageTag);
graph[Input<std::vector<NormalizedRect>>(kNormRectTag)].SetName(
kNormRectName) >>
face_landmark_detection.In(kNormRectTag);
face_landmark_detection.Out(kNormLandmarksTag).SetName(kNormLandmarksName) >>
graph[Output<std::vector<NormalizedLandmarkList>>(kNormLandmarksTag)];
face_landmark_detection.Out(kPresenceTag).SetName(kPresenceName) >>
graph[Output<std::vector<bool>>(kPresenceTag)];
face_landmark_detection.Out(kPresenceScoreTag).SetName(kPresenceScoreName) >>
graph[Output<std::vector<float>>(kPresenceScoreTag)];
face_landmark_detection.Out(kFaceRectsNextFrameTag)
.SetName(kFaceRectsNextFrameName) >>
graph[Output<std::vector<NormalizedRect>>(kFaceRectsNextFrameTag)];
return TaskRunner::Create(
graph.GetConfig(),
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
}
NormalizedLandmarkList GetExpectedLandmarkList(absl::string_view filename) {
NormalizedLandmarkList expected_landmark_list;
MP_EXPECT_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, filename),
&expected_landmark_list, Defaults()));
return expected_landmark_list;
}
// Helper function to construct NormalizeRect proto.
NormalizedRect MakeNormRect(float x_center, float y_center, float width,
float height, float rotation) {
NormalizedRect hand_rect;
hand_rect.set_x_center(x_center);
hand_rect.set_y_center(y_center);
hand_rect.set_width(width);
hand_rect.set_height(height);
hand_rect.set_rotation(rotation);
return hand_rect;
}
// Struct holding the parameters for parameterized FaceLandmarksDetectionTest
// class.
struct SingeFaceTestParams {
// The name of this test, for convenience when displaying test results.
std::string test_name;
// The filename of the model to test.
std::string input_model_name;
// The filename of the test image.
std::string test_image_name;
// RoI on image to detect hands.
NormalizedRect norm_rect;
// Expected hand presence value.
bool expected_presence;
// The expected output landmarks positions.
NormalizedLandmarkList expected_landmarks;
// The max value difference between expected_positions and detected positions.
float landmarks_diff_threshold;
};
struct MultiFaceTestParams {
// The name of this test, for convenience when displaying test results.
std::string test_name;
// The filename of the model to test.
std::string input_model_name;
// The filename of the test image.
std::string test_image_name;
// RoI on image to detect hands.
std::vector<NormalizedRect> norm_rects;
// Expected hand presence value.
std::vector<bool> expected_presence;
// The expected output landmarks positions.
std::optional<std::vector<NormalizedLandmarkList>> expected_landmarks_lists;
// The max value difference between expected_positions and detected positions.
float landmarks_diff_threshold;
};
class SingleFaceLandmarksDetectionTest
: public testing::TestWithParam<SingeFaceTestParams> {};
TEST_P(SingleFaceLandmarksDetectionTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
GetParam().test_image_name)));
MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateSingleFaceLandmarksTaskRunner(
GetParam().input_model_name));
auto output_packets = task_runner->Process(
{{kImageName, MakePacket<Image>(std::move(image))},
{kNormRectName,
MakePacket<NormalizedRect>(std::move(GetParam().norm_rect))}});
MP_ASSERT_OK(output_packets);
const bool presence = (*output_packets)[kPresenceName].Get<bool>();
ASSERT_EQ(presence, GetParam().expected_presence);
if (presence) {
const NormalizedLandmarkList landmarks =
(*output_packets)[kNormLandmarksName].Get<NormalizedLandmarkList>();
const NormalizedLandmarkList& expected_landmarks =
GetParam().expected_landmarks;
EXPECT_THAT(
landmarks,
Approximately(Partially(EqualsProto(expected_landmarks)),
/*margin=*/kAbsMargin,
/*fraction=*/GetParam().landmarks_diff_threshold));
}
}
class MultiFaceLandmarksDetectionTest
: public testing::TestWithParam<MultiFaceTestParams> {};
TEST_P(MultiFaceLandmarksDetectionTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
GetParam().test_image_name)));
MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateMultiFaceLandmarksTaskRunner(
GetParam().input_model_name));
auto output_packets = task_runner->Process(
{{kImageName, MakePacket<Image>(std::move(image))},
{kNormRectName, MakePacket<std::vector<NormalizedRect>>(
std::move(GetParam().norm_rects))}});
MP_ASSERT_OK(output_packets);
const std::vector<bool>& presences =
(*output_packets)[kPresenceName].Get<std::vector<bool>>();
EXPECT_THAT(presences, ElementsAreArray(GetParam().expected_presence));
if (GetParam().expected_landmarks_lists) {
const std::vector<NormalizedLandmarkList>& landmarks_lists =
(*output_packets)[kNormLandmarksName]
.Get<std::vector<NormalizedLandmarkList>>();
EXPECT_THAT(landmarks_lists,
Pointwise(Approximately(
Partially(EqualsProto()), /*margin=*/kAbsMargin,
/*fraction=*/GetParam().landmarks_diff_threshold),
*GetParam().expected_landmarks_lists));
}
}
INSTANTIATE_TEST_SUITE_P(
FaceLandmarksDetectionTest, SingleFaceLandmarksDetectionTest,
Values(SingeFaceTestParams{
/* test_name= */ "Portrait",
/*input_model_name= */ kFaceLandmarksDetectionModel,
/*test_image_name=*/kPortraitImageName,
/*norm_rect= */ MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0),
/*expected_presence = */ true,
/*expected_landmarks = */
GetExpectedLandmarkList(kPortraitExpectedFaceLandamrksName),
/*landmarks_diff_threshold = */ kFractionDiff}),
[](const TestParamInfo<SingleFaceLandmarksDetectionTest::ParamType>& info) {
return info.param.test_name;
});
INSTANTIATE_TEST_SUITE_P(
FaceLandmarksDetectionTest, MultiFaceLandmarksDetectionTest,
Values(
MultiFaceTestParams{
/* test_name= */ "Portrait",
/*input_model_name= */ kFaceLandmarksDetectionModel,
/*test_image_name=*/kPortraitImageName,
/*norm_rects= */ {MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0)},
/*expected_presence = */ {true},
/*expected_landmarks_list = */
{{GetExpectedLandmarkList(kPortraitExpectedFaceLandamrksName)}},
/*landmarks_diff_threshold = */ kFractionDiff},
MultiFaceTestParams{
/* test_name= */ "NoFace",
/*input_model_name= */ kFaceLandmarksDetectionModel,
/*test_image_name=*/kCatImageName,
/*norm_rects= */ {MakeNormRect(0.5, 0.5, 1.0, 1.0, 0)},
/*expected_presence = */ {false},
/*expected_landmarks_list = */ std::nullopt,
/*landmarks_diff_threshold = */ kFractionDiff}),
[](const TestParamInfo<MultiFaceLandmarksDetectionTest::ParamType>& info) {
return info.param.test_name;
});
} // namespace
} // namespace face_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,31 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
package(default_visibility = [
"//mediapipe/tasks:internal",
])
licenses(["notice"])
mediapipe_proto_library(
name = "face_landmarks_detector_graph_options_proto",
srcs = ["face_landmarks_detector_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",
],
)

View File

@ -0,0 +1,38 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe.tasks.vision.face_landmarker.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.facelandmarker.proto";
option java_outer_classname = "FaceLandmarksDetectorGraphOptionsProto";
message FaceLandmarksDetectorGraphOptions {
extend mediapipe.CalculatorOptions {
optional FaceLandmarksDetectorGraphOptions ext = 508968149;
}
// Base options for configuring Task library, such as specifying the TfLite
// model file with metadata, accelerator options, etc.
optional core.proto.BaseOptions base_options = 1;
// Minimum confidence value ([0.0, 1.0]) for confidence score to be considered
// successfully detecting a face in the image.
optional float min_detection_confidence = 2 [default = 0.5];
}

View File

@ -39,6 +39,7 @@ mediapipe_files(srcs = [
"deeplabv3.tflite",
"face_detection_full_range.tflite",
"face_detection_full_range_sparse.tflite",
"face_landmark.tflite",
"fist.jpg",
"fist.png",
"hand_landmark_full.tflite",
@ -136,6 +137,7 @@ filegroup(
"deeplabv3.tflite",
"face_detection_full_range.tflite",
"face_detection_full_range_sparse.tflite",
"face_landmark.tflite",
"hand_landmark_full.tflite",
"hand_landmark_lite.tflite",
"hand_landmarker.task",
@ -148,6 +150,7 @@ filegroup(
"mobilenet_v2_1.0_224.tflite",
"mobilenet_v3_small_100_224_embedder.tflite",
"palm_detection_full.tflite",
"portrait_expected_face_landmarks.pbtxt",
"selfie_segm_128_128_3.tflite",
"selfie_segm_144_256_3.tflite",
],
@ -169,6 +172,7 @@ filegroup(
"pointing_up_landmarks.pbtxt",
"pointing_up_rotated_landmarks.pbtxt",
"portrait_expected_detection.pbtxt",
"portrait_expected_face_landmarks.pbtxt",
"thumb_up_landmarks.pbtxt",
"thumb_up_rotated_landmarks.pbtxt",
"victory_landmarks.pbtxt",

File diff suppressed because it is too large Load Diff

View File

@ -258,8 +258,8 @@ def external_files():
http_file(
name = "com_google_mediapipe_face_landmark_tflite",
sha256 = "c603fa6149219a3e9487dc9abd7a0c24474c77263273d24868378cdf40aa26d1",
urls = ["https://storage.googleapis.com/mediapipe-assets/face_landmark.tflite?generation=1662063817995673"],
sha256 = "1055cb9d4a9ca8b8c688902a3a5194311138ba256bcc94e336d8373a5f30c814",
urls = ["https://storage.googleapis.com/mediapipe-assets/face_landmark.tflite?generation=1676316347980492"],
)
http_file(
@ -718,6 +718,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_detection.pbtxt?generation=1674261627835475"],
)
http_file(
name = "com_google_mediapipe_portrait_expected_face_landmarks_pbtxt",
sha256 = "4ac8587379bd072c36cda0d7345f5e592fae51b30522475e0b49c18aab108ce7",
urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_landmarks.pbtxt?generation=1676316357333369"],
)
http_file(
name = "com_google_mediapipe_portrait_jpg",
sha256 = "a6f11efaa834706db23f275b6115058fa87fc7f14362681e6abe14e82749de3e",