Update face detector graph for downstream face landmarks graph.

PiperOrigin-RevId: 511566984
This commit is contained in:
MediaPipe Team 2023-02-22 12:29:51 -08:00 committed by Copybara-Service
parent fbbc13d756
commit 000aeeb036
7 changed files with 188 additions and 52 deletions

View File

@ -31,9 +31,8 @@ cc_library(
"//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto", "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto",
"//mediapipe/calculators/tflite:ssd_anchors_calculator", "//mediapipe/calculators/tflite:ssd_anchors_calculator",
"//mediapipe/calculators/tflite:ssd_anchors_calculator_cc_proto", "//mediapipe/calculators/tflite:ssd_anchors_calculator_cc_proto",
"//mediapipe/calculators/util:detection_label_id_to_text_calculator",
"//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto",
"//mediapipe/calculators/util:detection_projection_calculator", "//mediapipe/calculators/util:detection_projection_calculator",
"//mediapipe/calculators/util:detection_transformation_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator", "//mediapipe/calculators/util:detections_to_rects_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto", "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto",
"//mediapipe/calculators/util:non_max_suppression_calculator", "//mediapipe/calculators/util:non_max_suppression_calculator",

View File

@ -21,7 +21,6 @@ limitations under the License.
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" #include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h" #include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h"
#include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h" #include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h"
#include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h"
#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h" #include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h"
#include "mediapipe/calculators/util/non_max_suppression_calculator.pb.h" #include "mediapipe/calculators/util/non_max_suppression_calculator.pb.h"
#include "mediapipe/calculators/util/rect_transformation_calculator.pb.h" #include "mediapipe/calculators/util/rect_transformation_calculator.pb.h"
@ -58,21 +57,40 @@ namespace {
constexpr char kImageTag[] = "IMAGE"; constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kDetectionsTag[] = "DETECTIONS"; constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kAnchorsTag[] = "ANCHORS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kNormRectsTag[] = "NORM_RECTS";
constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kMatrixTag[] = "MATRIX";
constexpr char kFaceRectsTag[] = "FACE_RECTS";
constexpr char kExpandedFaceRectsTag[] = "EXPANDED_FACE_RECTS";
constexpr char kPixelDetectionsTag[] = "PIXEL_DETECTIONS";
struct FaceDetectionOuts {
Source<std::vector<Detection>> face_detections;
Source<std::vector<NormalizedRect>> face_rects;
Source<std::vector<NormalizedRect>> expanded_face_rects;
Source<Image> image;
};
void ConfigureSsdAnchorsCalculator( void ConfigureSsdAnchorsCalculator(
mediapipe::SsdAnchorsCalculatorOptions* options) { mediapipe::SsdAnchorsCalculatorOptions* options) {
// TODO config SSD anchors parameters from metadata. // TODO config SSD anchors parameters from metadata.
options->set_num_layers(1); options->set_num_layers(4);
options->set_min_scale(0.1484375); options->set_min_scale(0.1484375);
options->set_max_scale(0.75); options->set_max_scale(0.75);
options->set_input_size_height(192); options->set_input_size_height(128);
options->set_input_size_width(192); options->set_input_size_width(128);
options->set_anchor_offset_x(0.5); options->set_anchor_offset_x(0.5);
options->set_anchor_offset_y(0.5); options->set_anchor_offset_y(0.5);
options->add_strides(4); options->add_strides(8);
options->add_strides(16);
options->add_strides(16);
options->add_strides(16);
options->add_aspect_ratios(1.0); options->add_aspect_ratios(1.0);
options->set_fixed_anchor_size(true); options->set_fixed_anchor_size(true);
options->set_interpolated_scale_aspect_ratio(0.0); options->set_interpolated_scale_aspect_ratio(1.0);
} }
void ConfigureTensorsToDetectionsCalculator( void ConfigureTensorsToDetectionsCalculator(
@ -80,7 +98,7 @@ void ConfigureTensorsToDetectionsCalculator(
mediapipe::TensorsToDetectionsCalculatorOptions* options) { mediapipe::TensorsToDetectionsCalculatorOptions* options) {
// TODO use metadata to configure these fields. // TODO use metadata to configure these fields.
options->set_num_classes(1); options->set_num_classes(1);
options->set_num_boxes(2304); options->set_num_boxes(896);
options->set_num_coords(16); options->set_num_coords(16);
options->set_box_coord_offset(0); options->set_box_coord_offset(0);
options->set_keypoint_coord_offset(4); options->set_keypoint_coord_offset(4);
@ -90,10 +108,10 @@ void ConfigureTensorsToDetectionsCalculator(
options->set_score_clipping_thresh(100.0); options->set_score_clipping_thresh(100.0);
options->set_reverse_output_order(true); options->set_reverse_output_order(true);
options->set_min_score_thresh(tasks_options.min_detection_confidence()); options->set_min_score_thresh(tasks_options.min_detection_confidence());
options->set_x_scale(192.0); options->set_x_scale(128.0);
options->set_y_scale(192.0); options->set_y_scale(128.0);
options->set_w_scale(192.0); options->set_w_scale(128.0);
options->set_h_scale(192.0); options->set_h_scale(128.0);
} }
void ConfigureNonMaxSuppressionCalculator( void ConfigureNonMaxSuppressionCalculator(
@ -107,8 +125,70 @@ void ConfigureNonMaxSuppressionCalculator(
mediapipe::NonMaxSuppressionCalculatorOptions::WEIGHTED); mediapipe::NonMaxSuppressionCalculatorOptions::WEIGHTED);
} }
void ConfigureDetectionsToRectsCalculator(
mediapipe::DetectionsToRectsCalculatorOptions* options) {
// Left eye.
options->set_rotation_vector_start_keypoint_index(0);
// Right ete.
options->set_rotation_vector_end_keypoint_index(1);
options->set_rotation_vector_target_angle_degrees(0);
}
void ConfigureRectTransformationCalculator(
mediapipe::RectTransformationCalculatorOptions* options) {
options->set_scale_x(1.5);
options->set_scale_y(1.5);
}
} // namespace } // namespace
// A "mediapipe.tasks.vision.face_detector.FaceDetectorGraph" performs face
// detection.
//
// Inputs:
// IMAGE - Image
// Image to perform detection on.
// NORM_RECT - NormalizedRect @Optional
// Describes image rotation and region of image to perform detection on. If
// not provided, whole image is used for face detection.
//
// Outputs:
// DETECTIONS - std::vector<Detection>
// Detected face with maximum `num_faces` specified in options.
// FACE_RECTS - std::vector<NormalizedRect>
// Detected face bounding boxes in normalized coordinates.
// EXPANDED_FACE_RECTS - std::vector<NormalizedRect>
// Expanded face bounding boxes in normalized coordinates so that bounding
// boxes likely contain the whole face. This is usually used as RoI for face
// landmarks detection to run on.
// IMAGE - Image
// The input image that the face detector runs on and has the pixel data
// stored on the target storage (CPU vs GPU).
// All returned coordinates are in the unrotated and uncropped input image
// coordinates system.
//
// Example:
// node {
// calculator: "mediapipe.tasks.vision.face_detector.FaceDetectorGraph"
// input_stream: "IMAGE:image"
// input_stream: "NORM_RECT:norm_rect"
// output_stream: "DETECTIONS:palm_detections"
// output_stream: "FACE_RECTS:face_rects"
// output_stream: "EXPANDED_FACE_RECTS:expanded_face_rects"
// output_stream: "IMAGE:image_out"
// options {
// [mediapipe.tasks.vision.face_detector.proto.FaceDetectorGraphOptions.ext]
// {
// base_options {
// model_asset {
// file_name: "face_detection.tflite"
// }
// }
// min_detection_confidence: 0.5
// num_faces: 2
// }
// }
// }
class FaceDetectorGraph : public core::ModelTaskGraph { class FaceDetectorGraph : public core::ModelTaskGraph {
public: public:
absl::StatusOr<CalculatorGraphConfig> GetConfig( absl::StatusOr<CalculatorGraphConfig> GetConfig(
@ -116,17 +196,24 @@ class FaceDetectorGraph : public core::ModelTaskGraph {
ASSIGN_OR_RETURN(const auto* model_resources, ASSIGN_OR_RETURN(const auto* model_resources,
CreateModelResources<FaceDetectorGraphOptions>(sc)); CreateModelResources<FaceDetectorGraphOptions>(sc));
Graph graph; Graph graph;
ASSIGN_OR_RETURN(auto face_detections, ASSIGN_OR_RETURN(auto outs,
BuildFaceDetectionSubgraph( BuildFaceDetectionSubgraph(
sc->Options<FaceDetectorGraphOptions>(), sc->Options<FaceDetectorGraphOptions>(),
*model_resources, graph[Input<Image>(kImageTag)], *model_resources, graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>(kNormRectTag)], graph)); graph[Input<NormalizedRect>(kNormRectTag)], graph));
face_detections >> graph[Output<std::vector<Detection>>(kDetectionsTag)]; outs.face_detections >>
graph.Out(kDetectionsTag).Cast<std::vector<Detection>>();
outs.face_rects >>
graph.Out(kFaceRectsTag).Cast<std::vector<NormalizedRect>>();
outs.expanded_face_rects >>
graph.Out(kExpandedFaceRectsTag).Cast<std::vector<NormalizedRect>>();
outs.image >> graph.Out(kImageTag).Cast<Image>();
return graph.GetConfig(); return graph.GetConfig();
} }
private: private:
absl::StatusOr<Source<std::vector<Detection>>> BuildFaceDetectionSubgraph( absl::StatusOr<FaceDetectionOuts> BuildFaceDetectionSubgraph(
const FaceDetectorGraphOptions& subgraph_options, const FaceDetectorGraphOptions& subgraph_options,
const core::ModelResources& model_resources, Source<Image> image_in, const core::ModelResources& model_resources, Source<Image> image_in,
Source<NormalizedRect> norm_rect_in, Graph& graph) { Source<NormalizedRect> norm_rect_in, Graph& graph) {
@ -149,17 +236,18 @@ class FaceDetectorGraph : public core::ModelTaskGraph {
image_to_tensor_options.set_keep_aspect_ratio(true); image_to_tensor_options.set_keep_aspect_ratio(true);
image_to_tensor_options.set_border_mode( image_to_tensor_options.set_border_mode(
mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO); mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO);
image_in >> preprocessing.In("IMAGE"); image_in >> preprocessing.In(kImageTag);
norm_rect_in >> preprocessing.In("NORM_RECT"); norm_rect_in >> preprocessing.In(kNormRectTag);
auto preprocessed_tensors = preprocessing.Out("TENSORS"); auto preprocessed_tensors = preprocessing.Out(kTensorsTag);
auto matrix = preprocessing.Out("MATRIX"); auto matrix = preprocessing.Out(kMatrixTag);
auto image_size = preprocessing.Out(kImageSizeTag);
// Face detection model inferece. // Face detection model inferece.
auto& inference = AddInference( auto& inference = AddInference(
model_resources, subgraph_options.base_options().acceleration(), graph); model_resources, subgraph_options.base_options().acceleration(), graph);
preprocessed_tensors >> inference.In("TENSORS"); preprocessed_tensors >> inference.In(kTensorsTag);
auto model_output_tensors = auto model_output_tensors =
inference.Out("TENSORS").Cast<std::vector<Tensor>>(); inference.Out(kTensorsTag).Cast<std::vector<Tensor>>();
// Generates a single side packet containing a vector of SSD anchors. // Generates a single side packet containing a vector of SSD anchors.
auto& ssd_anchor = graph.AddNode("SsdAnchorsCalculator"); auto& ssd_anchor = graph.AddNode("SsdAnchorsCalculator");
@ -174,9 +262,9 @@ class FaceDetectorGraph : public core::ModelTaskGraph {
subgraph_options, subgraph_options,
&tensors_to_detections &tensors_to_detections
.GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()); .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>());
model_output_tensors >> tensors_to_detections.In("TENSORS"); model_output_tensors >> tensors_to_detections.In(kTensorsTag);
anchors >> tensors_to_detections.SideIn("ANCHORS"); anchors >> tensors_to_detections.SideIn(kAnchorsTag);
auto detections = tensors_to_detections.Out("DETECTIONS"); auto detections = tensors_to_detections.Out(kDetectionsTag);
// Non maximum suppression removes redundant face detections. // Non maximum suppression removes redundant face detections.
auto& non_maximum_suppression = auto& non_maximum_suppression =
@ -190,12 +278,60 @@ class FaceDetectorGraph : public core::ModelTaskGraph {
// Projects detections back into the input image coordinates system. // Projects detections back into the input image coordinates system.
auto& detection_projection = graph.AddNode("DetectionProjectionCalculator"); auto& detection_projection = graph.AddNode("DetectionProjectionCalculator");
nms_detections >> detection_projection.In("DETECTIONS"); nms_detections >> detection_projection.In(kDetectionsTag);
matrix >> detection_projection.In("PROJECTION_MATRIX"); matrix >> detection_projection.In(kProjectionMatrixTag);
auto face_detections = auto face_detections = detection_projection.Out(kDetectionsTag);
detection_projection[Output<std::vector<Detection>>("DETECTIONS")];
return {face_detections}; // Clip face detections to maximum number of faces;
auto& clip_detection_vector_size =
graph.AddNode("ClipDetectionVectorSizeCalculator");
clip_detection_vector_size
.GetOptions<mediapipe::ClipVectorSizeCalculatorOptions>()
.set_max_vec_size(subgraph_options.num_faces());
face_detections >> clip_detection_vector_size.In("");
auto clipped_face_detections =
clip_detection_vector_size.Out("").Cast<std::vector<Detection>>();
// Converts results of face detection into a rectangle (normalized by image
// size) that encloses the face and is rotated such that the line connecting
// left eye and right eye is aligned with the X-axis of the rectangle.
auto& detections_to_rects = graph.AddNode("DetectionsToRectsCalculator");
ConfigureDetectionsToRectsCalculator(
&detections_to_rects
.GetOptions<mediapipe::DetectionsToRectsCalculatorOptions>());
image_size >> detections_to_rects.In(kImageSizeTag);
clipped_face_detections >> detections_to_rects.In(kDetectionsTag);
auto face_rects = detections_to_rects.Out(kNormRectsTag)
.Cast<std::vector<NormalizedRect>>();
// Expands and shifts the rectangle that contains the face so that it's
// likely to cover the entire face.
auto& rect_transformation = graph.AddNode("RectTransformationCalculator");
ConfigureRectTransformationCalculator(
&rect_transformation
.GetOptions<mediapipe::RectTransformationCalculatorOptions>());
face_rects >> rect_transformation.In(kNormRectsTag);
image_size >> rect_transformation.In(kImageSizeTag);
auto expanded_face_rects =
rect_transformation.Out("").Cast<std::vector<NormalizedRect>>();
// Calculator to convert relative detection bounding boxes to pixel
// detection bounding boxes.
auto& detection_transformation =
graph.AddNode("DetectionTransformationCalculator");
detection_projection.Out(kDetectionsTag) >>
detection_transformation.In(kDetectionsTag);
preprocessing.Out(kImageSizeTag) >>
detection_transformation.In(kImageSizeTag);
auto face_pixel_detections =
detection_transformation.Out(kPixelDetectionsTag)
.Cast<std::vector<Detection>>();
return FaceDetectionOuts{
/* face_detections= */ face_pixel_detections,
/* face_rects= */ face_rects,
/* expanded_face_rects= */ expanded_face_rects,
/* image= */ preprocessing.Out("IMAGE").Cast<Image>()};
} }
}; };

View File

@ -74,6 +74,8 @@ constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kFullRangeBlazeFaceModel[] = "face_detection_full_range.tflite"; constexpr char kFullRangeBlazeFaceModel[] = "face_detection_full_range.tflite";
constexpr char kFullRangeSparseBlazeFaceModel[] = constexpr char kFullRangeSparseBlazeFaceModel[] =
"face_detection_full_range_sparse.tflite"; "face_detection_full_range_sparse.tflite";
constexpr char kShortRangeBlazeFaceModel[] =
"face_detection_short_range.tflite";
constexpr char kPortraitImage[] = "portrait.jpg"; constexpr char kPortraitImage[] = "portrait.jpg";
constexpr char kPortraitExpectedDetection[] = constexpr char kPortraitExpectedDetection[] =
"portrait_expected_detection.pbtxt"; "portrait_expected_detection.pbtxt";
@ -161,14 +163,8 @@ TEST_P(FaceDetectorGraphTest, Succeed) {
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
FaceDetectorGraphTest, FaceDetectorGraphTest, FaceDetectorGraphTest, FaceDetectorGraphTest,
Values(TestParams{.test_name = "FullRange", Values(TestParams{.test_name = "ShortRange",
.face_detection_model_name = kFullRangeBlazeFaceModel, .face_detection_model_name = kShortRangeBlazeFaceModel,
.test_image_name = kPortraitImage,
.expected_result = {GetExpectedFaceDetectionResult(
kPortraitExpectedDetection)}},
TestParams{
.test_name = "FullRangeSparse",
.face_detection_model_name = kFullRangeSparseBlazeFaceModel,
.test_image_name = kPortraitImage, .test_image_name = kPortraitImage,
.expected_result = {GetExpectedFaceDetectionResult( .expected_result = {GetExpectedFaceDetectionResult(
kPortraitExpectedDetection)}}), kPortraitExpectedDetection)}}),

View File

@ -39,4 +39,7 @@ message FaceDetectorGraphOptions {
// IoU threshold ([0,0, 1.0]) for non-maximu-suppression to be considered // IoU threshold ([0,0, 1.0]) for non-maximu-suppression to be considered
// duplicate detetions. // duplicate detetions.
optional float min_suppression_threshold = 3 [default = 0.5]; optional float min_suppression_threshold = 3 [default = 0.5];
// Maximum number of faces to detect in the image.
optional int32 num_faces = 4 [default = 1];
} }

View File

@ -39,6 +39,7 @@ mediapipe_files(srcs = [
"deeplabv3.tflite", "deeplabv3.tflite",
"face_detection_full_range.tflite", "face_detection_full_range.tflite",
"face_detection_full_range_sparse.tflite", "face_detection_full_range_sparse.tflite",
"face_detection_short_range.tflite",
"face_landmark.tflite", "face_landmark.tflite",
"fist.jpg", "fist.jpg",
"fist.png", "fist.png",
@ -137,6 +138,7 @@ filegroup(
"deeplabv3.tflite", "deeplabv3.tflite",
"face_detection_full_range.tflite", "face_detection_full_range.tflite",
"face_detection_full_range_sparse.tflite", "face_detection_full_range_sparse.tflite",
"face_detection_short_range.tflite",
"face_landmark.tflite", "face_landmark.tflite",
"hand_landmark_full.tflite", "hand_landmark_full.tflite",
"hand_landmark_lite.tflite", "hand_landmark_lite.tflite",

View File

@ -1,12 +1,12 @@
# proto-file: mediapipe/framework/formats/detection.proto # proto-file: mediapipe/framework/formats/detection.proto
# proto-message: Detection # proto-message: Detection
location_data { location_data {
format: RELATIVE_BOUNDING_BOX format: BOUNDING_BOX
relative_bounding_box { bounding_box {
xmin: 0.35494408 xmin: 283
ymin: 0.1059662 ymin: 115
width: 0.28768203 width: 234
height: 0.23037356 height: 234
} }
relative_keypoints { relative_keypoints {
x: 0.44416338 x: 0.44416338
@ -25,7 +25,7 @@ location_data {
y: 0.2719954 y: 0.2719954
} }
relative_keypoints { relative_keypoints {
x: 0.37245658 x: 0.36063305
y: 0.20143759 y: 0.20143759
} }
relative_keypoints { relative_keypoints {

View File

@ -252,8 +252,8 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_face_detection_short_range_tflite", name = "com_google_mediapipe_face_detection_short_range_tflite",
sha256 = "3bc182eb9f33925d9e58b5c8d59308a760f4adea8f282370e428c51212c26633", sha256 = "bbff11cebd1eb27a1e004cae0b0e63ec8c551cbf34a4451148b4908b8db3eca8",
urls = ["https://storage.googleapis.com/mediapipe-assets/face_detection_short_range.tflite?generation=1661875748538815"], urls = ["https://storage.googleapis.com/mediapipe-assets/face_detection_short_range.tflite?generation=1677044301978921"],
) )
http_file( http_file(
@ -264,8 +264,8 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_face_landmark_with_attention_tflite", name = "com_google_mediapipe_face_landmark_with_attention_tflite",
sha256 = "883b7411747bac657c30c462d305d312e9dec6adbf8b85e2f5d8d722fca9455d", sha256 = "e06a804e0144f9929eda782122916b35d60c697c3c9344013ca2bbe76a6ce2b4",
urls = ["https://storage.googleapis.com/mediapipe-assets/face_landmark_with_attention.tflite?generation=1661875751615925"], urls = ["https://storage.googleapis.com/mediapipe-assets/face_landmark_with_attention.tflite?generation=1676415468821650"],
) )
http_file( http_file(
@ -714,8 +714,8 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_portrait_expected_detection_pbtxt", name = "com_google_mediapipe_portrait_expected_detection_pbtxt",
sha256 = "bb54e08e87844ef14bb185d5cb808908eb6011bfa6db48bd22d9650f6fda338b", sha256 = "ace755f0fd0ba3b2d75e4f8bb1b08d2f65975fd5570898004540dfef735c1c3d",
urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_detection.pbtxt?generation=1674261627835475"], urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_detection.pbtxt?generation=1677044311581104"],
) )
http_file( http_file(