From b7dd4cfe723899dc572e22016e4f4d48636d223b Mon Sep 17 00:00:00 2001
From: dmaletskiy
Date: Mon, 12 Jul 2021 17:52:15 +0300
Subject: [PATCH] feat: Added face mesh DLL example with side models

Change List:
- added graphs for running the face mesh DLL example with the face_detection
  and face_landmark model paths saved in side packets (these paths can be
  configured in the `MPFaceMeshDetector` constructor)
- added the ability to set the maximum number of faces to detect (1 by default)
---
 .../examples/desktop/face_mesh_dll/BUILD      |   6 +-
 .../desktop/face_mesh_dll/face_mesh_cpu.cpp   |  39 ++-
 .../desktop/face_mesh_dll/face_mesh_lib.cpp   | 198 ++++++++++----
 .../desktop/face_mesh_dll/face_mesh_lib.h     |  38 ++-
 mediapipe/modules/face_detection/BUILD        |  12 +
 ...detection_short_range_side_model_cpu.pbtxt |  86 ++++++
 mediapipe/modules/face_landmark/BUILD         |  38 +++
 ...ont_side_model_cpu_with_face_counter.pbtxt | 256 ++++++++++++++++++
 .../face_landmark_side_model_cpu.pbtxt        | 143 ++++++++++
 9 files changed, 741 insertions(+), 75 deletions(-)
 create mode 100644 mediapipe/modules/face_detection/face_detection_short_range_side_model_cpu.pbtxt
 create mode 100644 mediapipe/modules/face_landmark/face_landmark_front_side_model_cpu_with_face_counter.pbtxt
 create mode 100644 mediapipe/modules/face_landmark/face_landmark_side_model_cpu.pbtxt

diff --git a/mediapipe/examples/desktop/face_mesh_dll/BUILD b/mediapipe/examples/desktop/face_mesh_dll/BUILD
index ff5709093..3a20d0f43 100644
--- a/mediapipe/examples/desktop/face_mesh_dll/BUILD
+++ b/mediapipe/examples/desktop/face_mesh_dll/BUILD
@@ -47,9 +47,9 @@ windows_dll_library(
         "//mediapipe/calculators/core:constant_side_packet_calculator",
         "//mediapipe/calculators/core:flow_limiter_calculator",
-        "//mediapipe/modules/face_landmark:face_landmark_front_cpu_with_face_counter",
-
-
+        "//mediapipe/calculators/tflite:tflite_model_calculator",
+        "//mediapipe/calculators/util:local_file_contents_calculator",
+        "//mediapipe/modules/face_landmark:face_landmark_front_side_model_cpu_with_face_counter",
     ]
 )

diff --git a/mediapipe/examples/desktop/face_mesh_dll/face_mesh_cpu.cpp b/mediapipe/examples/desktop/face_mesh_dll/face_mesh_cpu.cpp
index 210d19c07..90462477a 100644
--- a/mediapipe/examples/desktop/face_mesh_dll/face_mesh_cpu.cpp
+++ b/mediapipe/examples/desktop/face_mesh_dll/face_mesh_cpu.cpp
@@ -21,7 +21,22 @@ int main(int argc, char **argv) {
 
   LOG(INFO) << "VideoCapture initialized.";
 
-  MPFaceMeshDetector *faceMeshDetector = FaceMeshDetector_Construct();
+  // Maximum number of faces that can be detected.
+  constexpr int maxNumFaces = 1;
+  constexpr char face_detection_model_path[] =
+      "mediapipe/modules/face_detection/face_detection_short_range.tflite";
+  constexpr char face_landmark_model_path[] =
+      "mediapipe/modules/face_landmark/face_landmark.tflite";
+
+  MPFaceMeshDetector *faceMeshDetector = FaceMeshDetector_Construct(
+      maxNumFaces, face_detection_model_path, face_landmark_model_path);
+
+  // Allocate memory for face landmarks.
+  auto multiFaceLandmarks = new cv::Point2f *[maxNumFaces];
+  constexpr auto mediapipeFaceLandmarksNum = 468;
+  for (int i = 0; i < maxNumFaces; ++i) {
+    multiFaceLandmarks[i] = new cv::Point2f[mediapipeFaceLandmarksNum];
+  }
 
   LOG(INFO) << "FaceMeshDetector constructed.";
 
@@ -36,26 +51,26 @@ int main(int argc, char **argv) {
       LOG(INFO) << "Ignore empty frames from camera.";
       continue;
     }
+
     cv::Mat camera_frame;
     cv::cvtColor(camera_frame_raw, camera_frame, cv::COLOR_BGR2RGB);
     cv::flip(camera_frame, camera_frame, /*flipcode=HORIZONTAL*/ 1);
 
-    std::unique_ptr<std::vector<std::vector<cv::Point2f>>>
-        multi_face_landmarks(
-            reinterpret_cast<std::vector<std::vector<cv::Point2f>> *>(
-                FaceMeshDetector_ProcessFrame2D(faceMeshDetector,
-                                                camera_frame)));
+    int faceCount =
+        FaceMeshDetector_GetFaceCount(faceMeshDetector, camera_frame);
 
-    const auto multi_face_landmarks_num = multi_face_landmarks->size();
+    LOG(INFO) << "Detected faces num: " << faceCount;
 
-    LOG(INFO) << "Got multi_face_landmarks_num: " << multi_face_landmarks_num;
+    if (faceCount > 0) {
 
-    if (multi_face_landmarks_num) {
-      auto &face_landmarks = multi_face_landmarks->operator[](0);
+      FaceMeshDetector_GetFaceLandmarks(faceMeshDetector, multiFaceLandmarks);
+
+      auto &face_landmarks = multiFaceLandmarks[0];
 
       auto &landmark = face_landmarks[0];
 
       LOG(INFO) << "First landmark: x - " << landmark.x << ", y - "
                 << landmark.y;
     }
-
     const int pressed_key = cv::waitKey(5);
     if (pressed_key >= 0 && pressed_key != 255)
       grab_frames = false;
@@ -65,5 +80,11 @@ int main(int argc, char **argv) {
 
   LOG(INFO) << "Shutting down.";
 
+  // Deallocate memory for face landmarks.
+  for (int i = 0; i < maxNumFaces; ++i) {
+    delete[] multiFaceLandmarks[i];
+  }
+  delete[] multiFaceLandmarks;
+
   FaceMeshDetector_Destruct(faceMeshDetector);
 }
\ No newline at end of file
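A side note on the buffer management in face_mesh_cpu.cpp above (not part of the patch): the example owns the landmark buffers through raw new/delete. A leak-safe sketch of the same layout using standard containers; LandmarkBuffer is a hypothetical name:

#include <vector>
#include <opencv2/core.hpp>

// Hypothetical RAII alternative to the raw new/delete in face_mesh_cpu.cpp:
// the inner vectors own the storage, and data() exposes the cv::Point2f**
// view that FaceMeshDetector_GetFaceLandmarks expects.
struct LandmarkBuffer {
  explicit LandmarkBuffer(int maxNumFaces, int landmarksPerFace = 468)
      : rows(maxNumFaces, std::vector<cv::Point2f>(landmarksPerFace)),
        rowPtrs(maxNumFaces) {
    for (int i = 0; i < maxNumFaces; ++i) rowPtrs[i] = rows[i].data();
  }
  cv::Point2f **data() { return rowPtrs.data(); }

  std::vector<std::vector<cv::Point2f>> rows;
  std::vector<cv::Point2f *> rowPtrs;
};

// Usage: LandmarkBuffer buffer(maxNumFaces);
//        FaceMeshDetector_GetFaceLandmarks(faceMeshDetector, buffer.data());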
diff --git a/mediapipe/examples/desktop/face_mesh_dll/face_mesh_lib.cpp b/mediapipe/examples/desktop/face_mesh_dll/face_mesh_lib.cpp
index 54ac3185a..5bba0efea 100644
--- a/mediapipe/examples/desktop/face_mesh_dll/face_mesh_lib.cpp
+++ b/mediapipe/examples/desktop/face_mesh_dll/face_mesh_lib.cpp
@@ -2,20 +2,51 @@
 
 #include "face_mesh_lib.h"
 
-MPFaceMeshDetector::MPFaceMeshDetector() {
-  const auto status = InitFaceMeshDetector();
+#define DEBUG
+
+MPFaceMeshDetector::MPFaceMeshDetector(int numFaces,
+                                       const char *face_detection_model_path,
+                                       const char *face_landmark_model_path) {
+  const auto status = InitFaceMeshDetector(numFaces, face_detection_model_path,
+                                           face_landmark_model_path);
   if (!status.ok()) {
     LOG(INFO) << "Failed constructing FaceMeshDetector.";
+    LOG(INFO) << status.message();
   }
 }
 
-absl::Status MPFaceMeshDetector::InitFaceMeshDetector() {
-  LOG(INFO) << "Get calculator graph config contents: " << graphConfig;
+absl::Status
+MPFaceMeshDetector::InitFaceMeshDetector(int numFaces,
+                                         const char *face_detection_model_path,
+                                         const char *face_landmark_model_path) {
+  if (numFaces <= 0) {
+    numFaces = 1;
+  }
+
+  if (face_detection_model_path == nullptr) {
+    face_detection_model_path =
+        "mediapipe/modules/face_detection/face_detection_short_range.tflite";
+  }
+
+  if (face_landmark_model_path == nullptr) {
+    face_landmark_model_path =
+        "mediapipe/modules/face_landmark/face_landmark.tflite";
+  }
+
+  auto preparedGraphConfig = absl::StrReplaceAll(
+      graphConfig, {{"$numFaces", std::to_string(numFaces)}});
+  preparedGraphConfig = absl::StrReplaceAll(
+      preparedGraphConfig,
+      {{"$faceDetectionModelPath", face_detection_model_path}});
+  preparedGraphConfig = absl::StrReplaceAll(
+      preparedGraphConfig,
+      {{"$faceLandmarkModelPath", face_landmark_model_path}});
+
+  LOG(INFO) << "Get calculator graph config contents: " << preparedGraphConfig;
 
   mediapipe::CalculatorGraphConfig config =
       mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(
-          graphConfig);
-
+          preparedGraphConfig);
   LOG(INFO) << "Initialize the calculator graph.";
 
   MP_RETURN_IF_ERROR(graph.Initialize(config));
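A readability note on InitFaceMeshDetector in the hunk above: absl::StrReplaceAll accepts several substitutions in one call, so the three chained calls could be collapsed into a single pass over graphConfig. A sketch with a hypothetical PrepareGraphConfig helper (same behavior, not part of the patch):

#include <string>
#include "absl/strings/str_replace.h"

// All three placeholders are substituted in a single pass over the config.
std::string PrepareGraphConfig(const std::string &graph_config, int num_faces,
                               const char *detection_model_path,
                               const char *landmark_model_path) {
  return absl::StrReplaceAll(
      graph_config, {{"$numFaces", std::to_string(num_faces)},
                     {"$faceDetectionModelPath", detection_model_path},
                     {"$faceLandmarkModelPath", landmark_model_path}});
}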
@@ -34,13 +65,13 @@ absl::Status MPFaceMeshDetector::InitFaceMeshDetector() {
 
   MP_RETURN_IF_ERROR(graph.StartRun({}));
 
-  return absl::Status();
+  LOG(INFO) << "MPFaceMeshDetector constructed successfully.";
+
+  return absl::OkStatus();
 }
 
-absl::Status
-MPFaceMeshDetector::ProcessFrameWithStatus(
-    const cv::Mat &camera_frame,
-    std::unique_ptr<std::vector<std::vector<cv::Point2f>>>
-        &multi_face_landmarks) {
+absl::Status
+MPFaceMeshDetector::GetFaceCountWithStatus(const cv::Mat &camera_frame) {
   // Wrap Mat into an ImageFrame.
   auto input_frame = absl::make_unique<mediapipe::ImageFrame>(
       mediapipe::ImageFormat::SRGB, camera_frame.cols, camera_frame.rows,
@@ -49,82 +80,99 @@ absl::Status MPFaceMeshDetector::ProcessFrameWithStatus(
   camera_frame.copyTo(input_frame_mat);
 
   // Send image packet into the graph.
-
-  size_t frame_timestamp_us =
-      (double)cv::getTickCount() / (double)cv::getTickFrequency() * 1e6;
+  size_t frame_timestamp_us = static_cast<double>(cv::getTickCount()) /
+                              static_cast<double>(cv::getTickFrequency()) * 1e6;
 
   MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
       kInputStream, mediapipe::Adopt(input_frame.release())
                         .At(mediapipe::Timestamp(frame_timestamp_us))));
 
-  LOG(INFO) << "Pushed new frame.";
-
   mediapipe::Packet face_count_packet;
   if (!face_count_poller_ptr ||
       !face_count_poller_ptr->Next(&face_count_packet)) {
-    LOG(INFO) << "Failed during getting next face_count_packet.";
-
-    return absl::Status();
+    return absl::CancelledError(
+        "Failed during getting next face_count_packet.");
   }
+
   auto &face_count = face_count_packet.Get<int>();
 
-  if (!face_count) {
-    return absl::Status();
-  }
+  faceCount = face_count;
+
+  return absl::OkStatus();
+}
+
+int MPFaceMeshDetector::GetFaceCount(const cv::Mat &camera_frame) {
+  const auto status = GetFaceCountWithStatus(camera_frame);
+  if (!status.ok()) {
+    LOG(INFO) << "Failed GetFaceCount.";
+    LOG(INFO) << status.message();
+  }
+
+  return faceCount;
+}
+
+absl::Status MPFaceMeshDetector::GetFaceLandmarksWithStatus(
+    cv::Point2f **multi_face_landmarks) {
+
+  if (faceCount <= 0) {
+    return absl::CancelledError(
+        "Failed during getting landmarks, because faceCount is <= 0.");
+  }
 
   mediapipe::Packet face_landmarks_packet;
   if (!landmarks_poller_ptr ||
       !landmarks_poller_ptr->Next(&face_landmarks_packet)) {
-    LOG(INFO) << "Failed during getting next landmarks_packet.";
-
-    return absl::Status();
+    return absl::CancelledError(
+        "Failed during getting next landmarks_packet.");
   }
 
   auto &output_landmarks_vector =
       face_landmarks_packet
          .Get<::std::vector<::mediapipe::NormalizedLandmarkList>>();
 
-  multi_face_landmarks->reserve(output_landmarks_vector.size());
-
-  for (const auto &normalizedLandmarkList : output_landmarks_vector) {
-    multi_face_landmarks->emplace_back();
-
-    auto &face_landmarks = multi_face_landmarks->back();
-
+  for (int i = 0; i < faceCount; ++i) {
+    const auto &normalizedLandmarkList = output_landmarks_vector[i];
     const auto landmarks_num = normalizedLandmarkList.landmark_size();
+    auto &face_landmarks = multi_face_landmarks[i];
 
-    face_landmarks.reserve(landmarks_num);
-
-    for (int i = 0; i < landmarks_num; ++i) {
-      auto &landmark = normalizedLandmarkList.landmark(i);
-
-      face_landmarks.emplace_back(landmark.x(), landmark.y());
+    for (int j = 0; j < landmarks_num; ++j) {
+      const auto &landmark = normalizedLandmarkList.landmark(j);
+      face_landmarks[j].x = landmark.x();
+      face_landmarks[j].y = landmark.y();
     }
   }
 
-  return absl::Status();
+  faceCount = -1;
+
+  return absl::OkStatus();
 }
 
-std::vector<std::vector<cv::Point2f>> *
-MPFaceMeshDetector::ProcessFrame2D(const cv::Mat &camera_frame) {
-  auto landmarks = std::make_unique<std::vector<std::vector<cv::Point2f>>>();
-
-  ProcessFrameWithStatus(camera_frame, landmarks);
-
-  return landmarks.release();
+void MPFaceMeshDetector::GetFaceLandmarks(cv::Point2f **multi_face_landmarks) {
+  const auto status = GetFaceLandmarksWithStatus(multi_face_landmarks);
+  if (!status.ok()) {
LOG(INFO) << "Failed GetFaceLandmarks."; + LOG(INFO) << status.message(); + } } extern "C" { -DLLEXPORT MPFaceMeshDetector *FaceMeshDetector_Construct() { - return new MPFaceMeshDetector(); +DLLEXPORT MPFaceMeshDetector * +FaceMeshDetector_Construct(int numFaces, const char *face_detection_model_path, + const char *face_landmark_model_path) { + return new MPFaceMeshDetector(numFaces, face_detection_model_path, + face_landmark_model_path); } DLLEXPORT void FaceMeshDetector_Destruct(MPFaceMeshDetector *detector) { delete detector; } -DLLEXPORT void * -FaceMeshDetector_ProcessFrame2D(MPFaceMeshDetector *detector, - const cv::Mat &camera_frame) { - return reinterpret_cast(detector->ProcessFrame2D(camera_frame)); +DLLEXPORT int FaceMeshDetector_GetFaceCount(MPFaceMeshDetector *detector, + const cv::Mat &camera_frame) { + return detector->GetFaceCount(camera_frame); +} + +DLLEXPORT void +FaceMeshDetector_GetFaceLandmarks(MPFaceMeshDetector *detector, + cv::Point2f **multi_face_landmarks) { + detector->GetFaceLandmarks(multi_face_landmarks); } } @@ -163,16 +211,60 @@ node { output_side_packet: "PACKET:num_faces" node_options: { [type.googleapis.com/mediapipe.ConstantSidePacketCalculatorOptions]: { - packet { int_value: 1 } + packet { int_value: $numFaces } } } } +# Defines side packets for further use in the graph. +node { + calculator: "ConstantSidePacketCalculator" + output_side_packet: "PACKET:face_detection_model_path" + options: { + [mediapipe.ConstantSidePacketCalculatorOptions.ext]: { + packet { string_value: "$faceDetectionModelPath" } + } + } +} + +# Defines side packets for further use in the graph. +node { + calculator: "ConstantSidePacketCalculator" + output_side_packet: "PACKET:face_landmark_model_path" + node_options: { + [type.googleapis.com/mediapipe.ConstantSidePacketCalculatorOptions]: { + packet { string_value: "$faceLandmarkModelPath" } + } + } +} + +node { + calculator: "LocalFileContentsCalculator" + input_side_packet: "FILE_PATH:0:face_detection_model_path" + input_side_packet: "FILE_PATH:1:face_landmark_model_path" + output_side_packet: "CONTENTS:0:face_detection_model_blob" + output_side_packet: "CONTENTS:1:face_landmark_model_blob" +} + +node { + calculator: "TfLiteModelCalculator" + input_side_packet: "MODEL_BLOB:face_detection_model_blob" + output_side_packet: "MODEL:face_detection_model" +} +node { + calculator: "TfLiteModelCalculator" + input_side_packet: "MODEL_BLOB:face_landmark_model_blob" + output_side_packet: "MODEL:face_landmark_model" +} + + # Subgraph that detects faces and corresponding landmarks. 
 node {
-  calculator: "FaceLandmarkFrontCpuWithFaceCounter"
+  calculator: "FaceLandmarkFrontSideModelCpuWithFaceCounter"
   input_stream: "IMAGE:throttled_input_video"
   input_side_packet: "NUM_FACES:num_faces"
+  input_side_packet: "MODEL:0:face_detection_model"
+  input_side_packet: "MODEL:1:face_landmark_model"
   output_stream: "LANDMARKS:multi_face_landmarks"
   output_stream: "ROIS_FROM_LANDMARKS:face_rects_from_landmarks"
   output_stream: "DETECTIONS:face_detections"

diff --git a/mediapipe/examples/desktop/face_mesh_dll/face_mesh_lib.h b/mediapipe/examples/desktop/face_mesh_dll/face_mesh_lib.h
index d6fe713e4..88c3ed680 100644
--- a/mediapipe/examples/desktop/face_mesh_dll/face_mesh_lib.h
+++ b/mediapipe/examples/desktop/face_mesh_dll/face_mesh_lib.h
@@ -13,11 +13,13 @@
 
 #include "absl/flags/flag.h"
 #include "absl/flags/parse.h"
+#include "absl/strings/str_replace.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/calculator_graph.h"
 #include "mediapipe/framework/formats/image_frame.h"
 #include "mediapipe/framework/formats/image_frame_opencv.h"
 #include "mediapipe/framework/formats/landmark.pb.h"
+#include "mediapipe/framework/output_stream_poller.h"
 #include "mediapipe/framework/port/file_helpers.h"
 #include "mediapipe/framework/port/opencv_highgui_inc.h"
 #include "mediapipe/framework/port/opencv_imgproc_inc.h"
@@ -27,15 +29,20 @@
 
 class MPFaceMeshDetector {
 public:
-  MPFaceMeshDetector();
-  std::vector<std::vector<cv::Point2f>> *ProcessFrame2D(const cv::Mat &camera_frame);
+  MPFaceMeshDetector(int numFaces, const char *face_detection_model_path,
+                     const char *face_landmark_model_path);
+  int GetFaceCount(const cv::Mat &camera_frame);
+  void GetFaceLandmarks(cv::Point2f **multi_face_landmarks);
 
 private:
-  absl::Status InitFaceMeshDetector();
-  absl::Status
-  ProcessFrameWithStatus(const cv::Mat &camera_frame,
-                         std::unique_ptr<std::vector<std::vector<cv::Point2f>>>
-                             &multi_face_landmarks);
+  absl::Status InitFaceMeshDetector(int numFaces,
+                                    const char *face_detection_model_path,
+                                    const char *face_landmark_model_path);
+  absl::Status ProcessFrameWithStatus(
+      const cv::Mat &camera_frame,
+      std::vector<std::vector<cv::Point2f>> &multi_face_landmarks);
+  absl::Status GetFaceCountWithStatus(const cv::Mat &camera_frame);
+  absl::Status GetFaceLandmarksWithStatus(cv::Point2f **multi_face_landmarks);
 
   static const char kInputStream[];
   static const char kOutputStream_landmarks[];
@@ -47,18 +54,29 @@ private:
 
   std::unique_ptr<mediapipe::OutputStreamPoller> landmarks_poller_ptr;
   std::unique_ptr<mediapipe::OutputStreamPoller> face_count_poller_ptr;
+
+  int faceCount = -1;
 };
 
 #ifdef __cplusplus
 extern "C" {
 #endif
 
-DLLEXPORT MPFaceMeshDetector *FaceMeshDetector_Construct();
+DLLEXPORT MPFaceMeshDetector *FaceMeshDetector_Construct(
+    int numFaces = 1,
+    const char *face_detection_model_path =
+        "mediapipe/modules/face_detection/face_detection_short_range.tflite",
+    const char *face_landmark_model_path =
+        "mediapipe/modules/face_landmark/face_landmark.tflite");
+
 DLLEXPORT void FaceMeshDetector_Destruct(MPFaceMeshDetector *detector);
 
-DLLEXPORT void *FaceMeshDetector_ProcessFrame2D(MPFaceMeshDetector *detector,
-                                                const cv::Mat &camera_frame);
+DLLEXPORT int FaceMeshDetector_GetFaceCount(MPFaceMeshDetector *detector,
+                                            const cv::Mat &camera_frame);
+
+DLLEXPORT void
+FaceMeshDetector_GetFaceLandmarks(MPFaceMeshDetector *detector,
+                                  cv::Point2f **multi_face_landmarks);
 
 #ifdef __cplusplus
 };
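The exported API in face_mesh_lib.h above is stateful: FaceMeshDetector_GetFaceCount runs the graph on a frame and caches the detected face count, and FaceMeshDetector_GetFaceLandmarks only succeeds afterwards (GetFaceLandmarksWithStatus rejects faceCount <= 0 and resets the count to -1 once landmarks are copied out). A minimal caller sketch, not part of the patch; ProcessOneFrame is a hypothetical name:

#include <opencv2/core.hpp>
#include "face_mesh_lib.h"

// Per-frame call order required by the DLL: count first, then landmarks.
void ProcessOneFrame(MPFaceMeshDetector *detector, const cv::Mat &rgb_frame,
                     cv::Point2f **landmark_buffers) {
  const int face_count = FaceMeshDetector_GetFaceCount(detector, rgb_frame);
  if (face_count > 0) {
    // Fills the first face_count buffers with 468 points each.
    FaceMeshDetector_GetFaceLandmarks(detector, landmark_buffers);
  }
  // Calling GetFaceLandmarks again without a new GetFaceCount would fail,
  // because the detector resets its cached count to -1 after this call.
}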
diff --git a/mediapipe/modules/face_detection/BUILD b/mediapipe/modules/face_detection/BUILD
index 839418c77..4a0b41544 100644
--- a/mediapipe/modules/face_detection/BUILD
+++ b/mediapipe/modules/face_detection/BUILD
@@ -57,6 +57,18 @@ mediapipe_simple_subgraph(
     ],
 )
 
+mediapipe_simple_subgraph(
+    name = "face_detection_short_range_side_model_cpu",
+    graph = "face_detection_short_range_side_model_cpu.pbtxt",
+    register_as = "FaceDetectionShortRangeSideModelCpu",
+    deps = [
+        ":face_detection_short_range_common",
+        "//mediapipe/calculators/tensor:image_to_tensor_calculator",
+        "//mediapipe/calculators/tensor:inference_calculator",
+        "//mediapipe/calculators/util:to_image_calculator",
+    ],
+)
+
 mediapipe_simple_subgraph(
     name = "face_detection_short_range_gpu",
     graph = "face_detection_short_range_gpu.pbtxt",

diff --git a/mediapipe/modules/face_detection/face_detection_short_range_side_model_cpu.pbtxt b/mediapipe/modules/face_detection/face_detection_short_range_side_model_cpu.pbtxt
new file mode 100644
index 000000000..57639bab2
--- /dev/null
+++ b/mediapipe/modules/face_detection/face_detection_short_range_side_model_cpu.pbtxt
@@ -0,0 +1,86 @@
+# MediaPipe graph to detect faces. (CPU input, and inference is executed on
+# CPU.)
+#
+# The face detection model is provided via the MODEL input side packet, e.g.
+# loaded from
+# "mediapipe/modules/face_detection/face_detection_short_range.tflite".
+#
+# EXAMPLE:
+#   node {
+#     calculator: "FaceDetectionShortRangeSideModelCpu"
+#     input_stream: "IMAGE:image"
+#     input_side_packet: "MODEL:face_detection_model"
+#     output_stream: "DETECTIONS:face_detections"
+#   }
+
+type: "FaceDetectionShortRangeSideModelCpu"
+
+# CPU image. (ImageFrame)
+input_stream: "IMAGE:image"
+
+# TfLite model to detect faces.
+# (std::unique_ptr<tflite::FlatBufferModel,
+#     std::function<void(tflite::FlatBufferModel*)>>)
+# NOTE: only the
+# mediapipe/modules/face_detection/face_detection_short_range.tflite model
+# can be passed here; otherwise, results are undefined.
+input_side_packet: "MODEL:face_detection_model"
+
+# Detected faces. (std::vector<Detection>)
+# NOTE: there will not be an output packet in the DETECTIONS stream for this
+# particular timestamp if no faces are detected. However, the MediaPipe
+# framework will internally inform the downstream calculators of the absence
+# of this packet so that they don't wait for it unnecessarily.
+output_stream: "DETECTIONS:detections"
+
+# Converts the input CPU image (ImageFrame) to the multi-backend image type
+# (Image).
+node: {
+  calculator: "ToImageCalculator"
+  input_stream: "IMAGE_CPU:image"
+  output_stream: "IMAGE:multi_backend_image"
+}
+
+# Transforms the input image into a 128x128 tensor while keeping the aspect
+# ratio (what is expected by the corresponding face detection model), resulting
+# in potential letterboxing in the transformed image.
+node: {
+  calculator: "ImageToTensorCalculator"
+  input_stream: "IMAGE:multi_backend_image"
+  output_stream: "TENSORS:input_tensors"
+  output_stream: "MATRIX:transform_matrix"
+  options: {
+    [mediapipe.ImageToTensorCalculatorOptions.ext] {
+      output_tensor_width: 128
+      output_tensor_height: 128
+      keep_aspect_ratio: true
+      output_tensor_float_range {
+        min: -1.0
+        max: 1.0
+      }
+      border_mode: BORDER_ZERO
+    }
+  }
+}
+
+# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a
+# vector of tensors representing, for instance, detection boxes/keypoints and
+# scores.
+node {
+  calculator: "InferenceCalculator"
+  input_stream: "TENSORS:input_tensors"
+  output_stream: "TENSORS:detection_tensors"
+  input_side_packet: "MODEL:face_detection_model"
+  options {
+    [mediapipe.InferenceCalculatorOptions.ext] {
+      delegate { tflite {} }
+    }
+  }
+}
+
+# Performs tensor post processing to generate face detections.
+node {
+  calculator: "FaceDetectionShortRangeCommon"
+  input_stream: "TENSORS:detection_tensors"
+  input_stream: "MATRIX:transform_matrix"
+  output_stream: "DETECTIONS:detections"
+}
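For context on the MODEL side packet consumed by the subgraph above: the LocalFileContentsCalculator + TfLiteModelCalculator pair in the DLL's main graph is conceptually close to reading the model file and wrapping it in a tflite::FlatBufferModel whose deleter keeps the backing buffer alive. A rough sketch of that idea (LoadTfLiteModel is a hypothetical helper, not how the calculators are implemented):

#include <fstream>
#include <functional>
#include <iterator>
#include <memory>
#include <string>

#include "tensorflow/lite/model.h"

// Roughly what LocalFileContentsCalculator + TfLiteModelCalculator produce:
// the file contents are kept alive alongside the FlatBufferModel view.
using TfLiteModelPtr =
    std::unique_ptr<tflite::FlatBufferModel,
                    std::function<void(tflite::FlatBufferModel *)>>;

TfLiteModelPtr LoadTfLiteModel(const std::string &path) {
  std::ifstream in(path, std::ios::binary);
  auto blob = std::make_shared<std::string>(
      (std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());
  auto model =
      tflite::FlatBufferModel::BuildFromBuffer(blob->data(), blob->size());
  // The deleter captures the blob so the buffer outlives the model.
  return TfLiteModelPtr(model.release(),
                        [blob](tflite::FlatBufferModel *m) { delete m; });
}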
diff --git a/mediapipe/modules/face_landmark/BUILD b/mediapipe/modules/face_landmark/BUILD
index 30720c1b0..6e642d7fc 100644
--- a/mediapipe/modules/face_landmark/BUILD
+++ b/mediapipe/modules/face_landmark/BUILD
@@ -37,6 +37,22 @@ mediapipe_simple_subgraph(
     ],
 )
 
+mediapipe_simple_subgraph(
+    name = "face_landmark_side_model_cpu",
+    graph = "face_landmark_side_model_cpu.pbtxt",
+    register_as = "FaceLandmarkSideModelCpu",
+    deps = [
+        "//mediapipe/calculators/core:gate_calculator",
+        "//mediapipe/calculators/core:split_vector_calculator",
+        "//mediapipe/calculators/tensor:image_to_tensor_calculator",
+        "//mediapipe/calculators/tensor:inference_calculator",
+        "//mediapipe/calculators/tensor:tensors_to_floats_calculator",
+        "//mediapipe/calculators/tensor:tensors_to_landmarks_calculator",
+        "//mediapipe/calculators/util:landmark_projection_calculator",
+        "//mediapipe/calculators/util:thresholding_calculator",
+    ],
+)
+
 mediapipe_simple_subgraph(
     name = "face_landmark_gpu",
     graph = "face_landmark_gpu.pbtxt",
@@ -96,6 +112,28 @@ mediapipe_simple_subgraph(
     ],
 )
 
+mediapipe_simple_subgraph(
+    name = "face_landmark_front_side_model_cpu_with_face_counter",
+    graph = "face_landmark_front_side_model_cpu_with_face_counter.pbtxt",
+    register_as = "FaceLandmarkFrontSideModelCpuWithFaceCounter",
+    deps = [
+        ":face_detection_front_detection_to_roi",
+        ":face_landmark_side_model_cpu",
+        ":face_landmark_landmarks_to_roi",
+        "//mediapipe/calculators/core:begin_loop_calculator",
+        "//mediapipe/calculators/core:clip_vector_size_calculator",
+        "//mediapipe/calculators/core:constant_side_packet_calculator",
+        "//mediapipe/calculators/core:end_loop_calculator",
+        "//mediapipe/calculators/core:gate_calculator",
+        "//mediapipe/calculators/core:previous_loopback_calculator",
+        "//mediapipe/calculators/image:image_properties_calculator",
+        "//mediapipe/calculators/util:association_norm_rect_calculator",
+        "//mediapipe/calculators/util:collection_has_min_size_calculator",
+        "//mediapipe/calculators/util:counting_vector_size_calculator",
+        "//mediapipe/modules/face_detection:face_detection_short_range_side_model_cpu",
+    ],
+)
+
 mediapipe_simple_subgraph(
     name = "face_landmark_front_gpu",
     graph = "face_landmark_front_gpu.pbtxt",

diff --git a/mediapipe/modules/face_landmark/face_landmark_front_side_model_cpu_with_face_counter.pbtxt b/mediapipe/modules/face_landmark/face_landmark_front_side_model_cpu_with_face_counter.pbtxt
new file mode 100644
index 000000000..dc83f17b7
--- /dev/null
+++ b/mediapipe/modules/face_landmark/face_landmark_front_side_model_cpu_with_face_counter.pbtxt
@@ -0,0 +1,256 @@
+# MediaPipe graph to detect/predict face landmarks. (CPU input, and inference
+# is executed on CPU.) This graph tries to skip face detection as much as
+# possible by using previously detected/predicted landmarks for new images.
+#
+# EXAMPLE:
+#   node {
+#     calculator: "FaceLandmarkFrontSideModelCpuWithFaceCounter"
+#     input_stream: "IMAGE:image"
+#     input_side_packet: "NUM_FACES:num_faces"
+#     input_side_packet: "MODEL:0:face_detection_model"
+#     input_side_packet: "MODEL:1:face_landmark_model"
+#     output_stream: "LANDMARKS:multi_face_landmarks"
+#   }
+
+type: "FaceLandmarkFrontSideModelCpuWithFaceCounter"
+
+# CPU image. (ImageFrame)
+input_stream: "IMAGE:image"
+
+# Max number of faces to detect/track. (int)
+input_side_packet: "NUM_FACES:num_faces"
+
+# TfLite model to detect faces.
+# (std::unique_ptr<tflite::FlatBufferModel,
+#     std::function<void(tflite::FlatBufferModel*)>>)
+# NOTE: only the
+# mediapipe/modules/face_detection/face_detection_short_range.tflite model
+# can be passed here; otherwise, results are undefined.
+input_side_packet: "MODEL:0:face_detection_model"
+
+# TfLite model to detect face landmarks.
+# (std::unique_ptr<tflite::FlatBufferModel,
+#     std::function<void(tflite::FlatBufferModel*)>>)
+# NOTE: only the mediapipe/modules/face_landmark/face_landmark.tflite model
+# can be passed here; otherwise, results are undefined.
+input_side_packet: "MODEL:1:face_landmark_model"
+
+# Collection of detected/predicted faces, each represented as a list of 468
+# face landmarks. (std::vector<NormalizedLandmarkList>)
+# NOTE: there will not be an output packet in the LANDMARKS stream for this
+# particular timestamp if no faces are detected. However, the MediaPipe
+# framework will internally inform the downstream calculators of the absence
+# of this packet so that they don't wait for it unnecessarily.
+output_stream: "LANDMARKS:multi_face_landmarks"
+
+# Extra outputs (for debugging, for instance).
+# Detected faces. (std::vector<Detection>)
+output_stream: "DETECTIONS:face_detections"
+# Regions of interest calculated based on landmarks.
+# (std::vector<NormalizedRect>)
+output_stream: "ROIS_FROM_LANDMARKS:face_rects_from_landmarks"
+# Regions of interest calculated based on face detections.
+# (std::vector<NormalizedRect>)
+output_stream: "ROIS_FROM_DETECTIONS:face_rects_from_detections"
+
+# Number of detected faces. (int)
+output_stream: "FACE_COUNT_FROM_LANDMARKS:face_count"
+
+
+# Defines whether landmarks on the previous image should be used to help
+# localize landmarks on the current image.
+node {
+  name: "ConstantSidePacketCalculator"
+  calculator: "ConstantSidePacketCalculator"
+  output_side_packet: "PACKET:use_prev_landmarks"
+  options: {
+    [mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
+      packet { bool_value: true }
+    }
+  }
+}
+node {
+  calculator: "GateCalculator"
+  input_side_packet: "ALLOW:use_prev_landmarks"
+  input_stream: "prev_face_rects_from_landmarks"
+  output_stream: "gated_prev_face_rects_from_landmarks"
+}
+
+# Determines if an input vector of NormalizedRect has a size greater than or
+# equal to the provided num_faces.
+node {
+  calculator: "NormalizedRectVectorHasMinSizeCalculator"
+  input_stream: "ITERABLE:prev_face_rects_from_landmarks"
+  input_side_packet: "num_faces"
+  output_stream: "prev_has_enough_faces"
+}
+
+# Drops the incoming image if FaceLandmarkCpu was able to identify face
+# presence in the previous image. Otherwise, passes the incoming image through
+# to trigger a new round of face detection in FaceDetectionShortRangeCpu.
+node {
+  calculator: "GateCalculator"
+  input_stream: "image"
+  input_stream: "DISALLOW:prev_has_enough_faces"
+  output_stream: "gated_image"
+  options: {
+    [mediapipe.GateCalculatorOptions.ext] {
+      empty_packets_as_allow: true
+    }
+  }
+}
+
+# Detects faces.
+node {
+  calculator: "FaceDetectionShortRangeSideModelCpu"
+  input_stream: "IMAGE:gated_image"
+  input_side_packet: "MODEL:face_detection_model"
+  output_stream: "DETECTIONS:all_face_detections"
+}
+
+# Makes sure there are no more detections than the provided num_faces.
+node {
+  calculator: "ClipDetectionVectorSizeCalculator"
+  input_stream: "all_face_detections"
+  output_stream: "face_detections"
+  input_side_packet: "num_faces"
+}
+
+# Calculate size of the image.
+node {
+  calculator: "ImagePropertiesCalculator"
+  input_stream: "IMAGE:gated_image"
+  output_stream: "SIZE:gated_image_size"
+}
+
+# Outputs each element of face_detections at a fake timestamp for the rest of
+# the graph to process. Clones the image size packet for each face_detection
+# at the fake timestamp. At the end of the loop, outputs the BATCH_END
+# timestamp for downstream calculators to inform them that all elements in the
+# vector have been processed.
+node {
+  calculator: "BeginLoopDetectionCalculator"
+  input_stream: "ITERABLE:face_detections"
+  input_stream: "CLONE:gated_image_size"
+  output_stream: "ITEM:face_detection"
+  output_stream: "CLONE:detections_loop_image_size"
+  output_stream: "BATCH_END:detections_loop_end_timestamp"
+}
+
+# Calculates region of interest based on face detections, so that it can be
+# used to detect landmarks.
+node {
+  calculator: "FaceDetectionFrontDetectionToRoi"
+  input_stream: "DETECTION:face_detection"
+  input_stream: "IMAGE_SIZE:detections_loop_image_size"
+  output_stream: "ROI:face_rect_from_detection"
+}
+
+# Counts the size of the multi_face_landmarks vector. The image stream is only
+# used to make the calculator work even when there is no input vector.
+node {
+  calculator: "CountingNormalizedLandmarkListVectorSizeCalculator"
+  input_stream: "CLOCK:image"
+  input_stream: "VECTOR:multi_face_landmarks"
+  output_stream: "COUNT:face_count"
+}
+
+# Collects a NormalizedRect for each face into a vector. Upon receiving the
+# BATCH_END timestamp, outputs the vector of NormalizedRect at the BATCH_END
+# timestamp.
+node {
+  calculator: "EndLoopNormalizedRectCalculator"
+  input_stream: "ITEM:face_rect_from_detection"
+  input_stream: "BATCH_END:detections_loop_end_timestamp"
+  output_stream: "ITERABLE:face_rects_from_detections"
+}
+
+# Performs association between NormalizedRect vector elements from previous
+# image and rects based on face detections from the current image. This
+# calculator ensures that the output face_rects vector doesn't contain
+# overlapping regions based on the specified min_similarity_threshold.
+node {
+  calculator: "AssociationNormRectCalculator"
+  input_stream: "face_rects_from_detections"
+  input_stream: "prev_face_rects_from_landmarks"
+  output_stream: "face_rects"
+  options: {
+    [mediapipe.AssociationCalculatorOptions.ext] {
+      min_similarity_threshold: 0.5
+    }
+  }
+}
+
+# Calculate size of the image.
+node {
+  calculator: "ImagePropertiesCalculator"
+  input_stream: "IMAGE:image"
+  output_stream: "SIZE:image_size"
+}
+
+# Outputs each element of face_rects at a fake timestamp for the rest of the
+# graph to process. Clones image and image size packets for each
+# single_face_rect at the fake timestamp. At the end of the loop, outputs the
+# BATCH_END timestamp for downstream calculators to inform them that all
+# elements in the vector have been processed.
+node {
+  calculator: "BeginLoopNormalizedRectCalculator"
+  input_stream: "ITERABLE:face_rects"
+  input_stream: "CLONE:0:image"
+  input_stream: "CLONE:1:image_size"
+  output_stream: "ITEM:face_rect"
+  output_stream: "CLONE:0:landmarks_loop_image"
+  output_stream: "CLONE:1:landmarks_loop_image_size"
+  output_stream: "BATCH_END:landmarks_loop_end_timestamp"
+}
+
+# Detects face landmarks within specified region of interest of the image.
+node {
+  calculator: "FaceLandmarkSideModelCpu"
+  input_stream: "IMAGE:landmarks_loop_image"
+  input_stream: "ROI:face_rect"
+  input_side_packet: "MODEL:face_landmark_model"
+  output_stream: "LANDMARKS:face_landmarks"
+}
+
+# Calculates region of interest based on face landmarks, so that it can be
+# reused for the subsequent image.
+node {
+  calculator: "FaceLandmarkLandmarksToRoi"
+  input_stream: "LANDMARKS:face_landmarks"
+  input_stream: "IMAGE_SIZE:landmarks_loop_image_size"
+  output_stream: "ROI:face_rect_from_landmarks"
+}
+
+# Collects a set of landmarks for each face into a vector. Upon receiving the
+# BATCH_END timestamp, outputs the vector of landmarks at the BATCH_END
+# timestamp.
+node {
+  calculator: "EndLoopNormalizedLandmarkListVectorCalculator"
+  input_stream: "ITEM:face_landmarks"
+  input_stream: "BATCH_END:landmarks_loop_end_timestamp"
+  output_stream: "ITERABLE:multi_face_landmarks"
+}
+
+# Collects a NormalizedRect for each face into a vector. Upon receiving the
+# BATCH_END timestamp, outputs the vector of NormalizedRect at the BATCH_END
+# timestamp.
+node {
+  calculator: "EndLoopNormalizedRectCalculator"
+  input_stream: "ITEM:face_rect_from_landmarks"
+  input_stream: "BATCH_END:landmarks_loop_end_timestamp"
+  output_stream: "ITERABLE:face_rects_from_landmarks"
+}
+
+# Caches face rects calculated from landmarks, and upon the arrival of the
+# next input image, sends out the cached rects with timestamps replaced by
+# that of the input image, essentially generating a packet that carries the
+# previous face rects. Note that upon the arrival of the very first input
+# image, a timestamp bound update occurs to jump start the feedback loop.
+node {
+  calculator: "PreviousLoopbackCalculator"
+  input_stream: "MAIN:image"
+  input_stream: "LOOP:face_rects_from_landmarks"
+  input_stream_info: {
+    tag_index: "LOOP"
+    back_edge: true
+  }
+  output_stream: "PREV_LOOP:prev_face_rects_from_landmarks"
+}

diff --git a/mediapipe/modules/face_landmark/face_landmark_side_model_cpu.pbtxt b/mediapipe/modules/face_landmark/face_landmark_side_model_cpu.pbtxt
new file mode 100644
index 000000000..d8537fd82
--- /dev/null
+++ b/mediapipe/modules/face_landmark/face_landmark_side_model_cpu.pbtxt
@@ -0,0 +1,143 @@
+# MediaPipe graph to detect/predict face landmarks. (CPU input, and inference
+# is executed on CPU.)
+#
+# The face landmark model is provided via the MODEL input side packet, e.g.
+# loaded from "mediapipe/modules/face_landmark/face_landmark.tflite".
+#
+# EXAMPLE:
+#   node {
+#     calculator: "FaceLandmarkSideModelCpu"
+#     input_stream: "IMAGE:image"
+#     input_stream: "ROI:face_roi"
+#     input_side_packet: "MODEL:face_landmark_model"
+#     output_stream: "LANDMARKS:face_landmarks"
+#   }
+
+type: "FaceLandmarkSideModelCpu"
+
+# CPU image. (ImageFrame)
+input_stream: "IMAGE:image"
+# ROI (region of interest) within the given image where a face is located.
+# (NormalizedRect)
+input_stream: "ROI:roi"
+
+# TfLite model to detect face landmarks.
+# (std::unique_ptr<tflite::FlatBufferModel,
+#     std::function<void(tflite::FlatBufferModel*)>>)
+# NOTE: only the mediapipe/modules/face_landmark/face_landmark.tflite model
+# can be passed here; otherwise, results are undefined.
+input_side_packet: "MODEL:face_landmark_model"
+
+
+# 468 face landmarks within the given ROI. (NormalizedLandmarkList)
+# NOTE: if a face is not present within the given ROI, for this particular
+# timestamp there will not be an output packet in the LANDMARKS stream.
+# However, the MediaPipe framework will internally inform the downstream
+# calculators of the absence of this packet so that they don't wait for it
+# unnecessarily.
+output_stream: "LANDMARKS:face_landmarks" + +# Transforms the input image into a 192x192 tensor. +node: { + calculator: "ImageToTensorCalculator" + input_stream: "IMAGE:image" + input_stream: "NORM_RECT:roi" + output_stream: "TENSORS:input_tensors" + options: { + [mediapipe.ImageToTensorCalculatorOptions.ext] { + output_tensor_width: 192 + output_tensor_height: 192 + output_tensor_float_range { + min: 0.0 + max: 1.0 + } + } + } +} + +# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a +# vector of tensors representing, for instance, detection boxes/keypoints and +# scores. +node { + calculator: "InferenceCalculator" + input_stream: "TENSORS:input_tensors" + output_stream: "TENSORS:output_tensors" + input_side_packet: "MODEL:face_landmark_model" + options { + [mediapipe.InferenceCalculatorOptions.ext] { + delegate { tflite {} } + } + } +} + +# Splits a vector of tensors into multiple vectors. +node { + calculator: "SplitTensorVectorCalculator" + input_stream: "output_tensors" + output_stream: "landmark_tensors" + output_stream: "face_flag_tensor" + options: { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges: { begin: 0 end: 1 } + ranges: { begin: 1 end: 2 } + } + } +} + +# Converts the face-flag tensor into a float that represents the confidence +# score of face presence. +node { + calculator: "TensorsToFloatsCalculator" + input_stream: "TENSORS:face_flag_tensor" + output_stream: "FLOAT:face_presence_score" + options { + [mediapipe.TensorsToFloatsCalculatorOptions.ext] { + activation: SIGMOID + } + } +} + +# Applies a threshold to the confidence score to determine whether a face is +# present. +node { + calculator: "ThresholdingCalculator" + input_stream: "FLOAT:face_presence_score" + output_stream: "FLAG:face_presence" + options: { + [mediapipe.ThresholdingCalculatorOptions.ext] { + threshold: 0.5 + } + } +} + +# Drop landmarks tensors if face is not present. +node { + calculator: "GateCalculator" + input_stream: "landmark_tensors" + input_stream: "ALLOW:face_presence" + output_stream: "ensured_landmark_tensors" +} + +# Decodes the landmark tensors into a vector of landmarks, where the landmark +# coordinates are normalized by the size of the input image to the model. +node { + calculator: "TensorsToLandmarksCalculator" + input_stream: "TENSORS:ensured_landmark_tensors" + output_stream: "NORM_LANDMARKS:landmarks" + options: { + [mediapipe.TensorsToLandmarksCalculatorOptions.ext] { + num_landmarks: 468 + input_image_width: 192 + input_image_height: 192 + } + } +} + +# Projects the landmarks from the cropped face image to the corresponding +# locations on the full image before cropping (input to the graph). +node { + calculator: "LandmarkProjectionCalculator" + input_stream: "NORM_LANDMARKS:landmarks" + input_stream: "NORM_RECT:roi" + output_stream: "NORM_LANDMARKS:face_landmarks" +}