Added the ability to retrieve bounding boxes of detected faces

This commit is contained in:
dmaletskiy 2021-08-12 17:40:07 +03:00
parent 6d89ef3e9e
commit 393ad8ffcc
4 changed files with 213 additions and 41 deletions

View File

@ -36,6 +36,7 @@ windows_dll_library(
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:opencv_highgui", "//mediapipe/framework/port:opencv_highgui",
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",

View File

@ -36,7 +36,8 @@ int main(int argc, char **argv) {
for (int i = 0; i < maxNumFaces; ++i) { for (int i = 0; i < maxNumFaces; ++i) {
multiFaceLandmarks[i] = new cv::Point2f[MPFaceMeshDetectorLandmarksNum]; multiFaceLandmarks[i] = new cv::Point2f[MPFaceMeshDetectorLandmarksNum];
} }
const auto faceCount = std::make_unique<int>();
std::vector<cv::Rect> multiFaceBoundingBoxes(maxNumFaces);
LOG(INFO) << "FaceMeshDetector constructed."; LOG(INFO) << "FaceMeshDetector constructed.";
@ -54,14 +55,21 @@ int main(int argc, char **argv) {
cv::Mat camera_frame; cv::Mat camera_frame;
cv::cvtColor(camera_frame_raw, camera_frame, cv::COLOR_BGR2RGB); cv::cvtColor(camera_frame_raw, camera_frame, cv::COLOR_BGR2RGB);
cv::flip(camera_frame, camera_frame, /*flipcode=HORIZONTAL*/ 1);
MPFaceMeshDetectorProcessFrame2D(faceMeshDetector, camera_frame, int faceCount = 0;
faceCount.get(), multiFaceLandmarks);
LOG(INFO) << "Detected faces num: " << *faceCount; MPFaceMeshDetectorDetectFaces(faceMeshDetector, camera_frame,
multiFaceBoundingBoxes.data(), &faceCount);
if (*faceCount > 0) { if (faceCount > 0) {
auto &face_bounding_box = multiFaceBoundingBoxes[0];
cv::rectangle(camera_frame_raw, face_bounding_box, cv::Scalar(0, 255, 0),
3);
int landmarksNum = 0;
MPFaceMeshDetectorDetect2DLandmarks(faceMeshDetector, multiFaceLandmarks,
&landmarksNum);
auto &face_landmarks = multiFaceLandmarks[0]; auto &face_landmarks = multiFaceLandmarks[0];
auto &landmark = face_landmarks[0]; auto &landmark = face_landmarks[0];

View File

@ -52,11 +52,17 @@ MPFaceMeshDetector::InitFaceMeshDetector(int numFaces,
graph.AddOutputStreamPoller(kOutputStream_landmarks)); graph.AddOutputStreamPoller(kOutputStream_landmarks));
ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller face_count_poller, ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller face_count_poller,
graph.AddOutputStreamPoller(kOutputStream_faceCount)); graph.AddOutputStreamPoller(kOutputStream_faceCount));
ASSIGN_OR_RETURN(
mediapipe::OutputStreamPoller face_rects_from_landmarks_poller,
graph.AddOutputStreamPoller(kOutputStream_face_rects_from_landmarks));
landmarks_poller_ptr = std::make_unique<mediapipe::OutputStreamPoller>( landmarks_poller_ptr = std::make_unique<mediapipe::OutputStreamPoller>(
std::move(landmarks_poller)); std::move(landmarks_poller));
face_count_poller_ptr = std::make_unique<mediapipe::OutputStreamPoller>( face_count_poller_ptr = std::make_unique<mediapipe::OutputStreamPoller>(
std::move(face_count_poller)); std::move(face_count_poller));
face_rects_from_landmarks_poller_ptr =
std::make_unique<mediapipe::OutputStreamPoller>(
std::move(face_rects_from_landmarks_poller));
MP_RETURN_IF_ERROR(graph.StartRun({})); MP_RETURN_IF_ERROR(graph.StartRun({}));
@ -65,10 +71,19 @@ MPFaceMeshDetector::InitFaceMeshDetector(int numFaces,
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status MPFaceMeshDetector::ProcessFrame2DWithStatus( absl::Status
const cv::Mat &camera_frame, int *numFaces, MPFaceMeshDetector::DetectFacesWithStatus(const cv::Mat &camera_frame,
cv::Point2f **multi_face_landmarks) { cv::Rect *multi_face_bounding_boxes,
int *numFaces) {
if (!numFaces || !multi_face_bounding_boxes) {
return absl::InvalidArgumentError(
"MPFaceMeshDetector::DetectFacesWithStatus requires notnull pointer to "
"save results data.");
}
// Reset face counts.
*numFaces = 0; *numFaces = 0;
face_count = 0;
// Wrap Mat into an ImageFrame. // Wrap Mat into an ImageFrame.
auto input_frame = absl::make_unique<mediapipe::ImageFrame>( auto input_frame = absl::make_unique<mediapipe::ImageFrame>(
@ -92,26 +107,89 @@ absl::Status MPFaceMeshDetector::ProcessFrame2DWithStatus(
"Failed during getting next face_count_packet."); "Failed during getting next face_count_packet.");
} }
auto &face_count = face_count_packet.Get<int>(); auto &face_count_val = face_count_packet.Get<int>();
if (face_count <= 0) { if (face_count_val <= 0) {
return absl::OkStatus(); return absl::OkStatus();
} }
// Get face bounding boxes.
mediapipe::Packet face_rects_from_landmarks_packet;
if (!face_rects_from_landmarks_poller_ptr ||
!face_rects_from_landmarks_poller_ptr->Next(
&face_rects_from_landmarks_packet)) {
return absl::CancelledError(
"Failed during getting next face_rects_from_landmarks_packet.");
}
auto &face_bounding_boxes =
face_rects_from_landmarks_packet
.Get<::std::vector<::mediapipe::NormalizedRect>>();
image_width = camera_frame.cols;
image_height = camera_frame.rows;
const auto image_width_f = static_cast<float>(image_width);
const auto image_height_f = static_cast<float>(image_height);
// Convert vector<NormalizedRect> (center based Rects) to cv::Rect*
// (leftTop based Rects).
for (int i = 0; i < face_count_val; ++i) {
const auto &normalized_bounding_box = face_bounding_boxes[i];
auto &bounding_box = multi_face_bounding_boxes[i];
const auto width =
static_cast<int>(normalized_bounding_box.width() * image_width_f);
const auto height =
static_cast<int>(normalized_bounding_box.height() * image_height_f);
bounding_box.x =
static_cast<int>(normalized_bounding_box.x_center() * image_width_f) -
(width >> 1);
bounding_box.y =
static_cast<int>(normalized_bounding_box.y_center() * image_height_f) -
(height >> 1);
bounding_box.width = width;
bounding_box.height = height;
}
// Get face landmarks. // Get face landmarks.
mediapipe::Packet face_landmarks_packet;
if (!landmarks_poller_ptr || if (!landmarks_poller_ptr ||
!landmarks_poller_ptr->Next(&face_landmarks_packet)) { !landmarks_poller_ptr->Next(&face_landmarks_packet)) {
return absl::CancelledError("Failed during getting next landmarks_packet."); return absl::CancelledError("Failed during getting next landmarks_packet.");
} }
auto &output_landmarks_vector = *numFaces = face_count_val;
face_count = face_count_val;
return absl::OkStatus();
}
// Non-status front end for DetectFacesWithStatus.
//
// Forwards all three arguments unchanged to DetectFacesWithStatus. When the
// returned status is not OK, its message is written to the INFO log; the
// status itself is not propagated to the caller.
void MPFaceMeshDetector::DetectFaces(const cv::Mat &camera_frame,
                                     cv::Rect *multi_face_bounding_boxes,
                                     int *numFaces) {
  const auto result =
      DetectFacesWithStatus(camera_frame, multi_face_bounding_boxes, numFaces);
  if (result.ok()) {
    return;
  }
  LOG(INFO) << "MPFaceMeshDetector::DetectFaces failed: " << result.message();
}
absl::Status MPFaceMeshDetector::DetectLandmarksWithStatus(
cv::Point2f **multi_face_landmarks) {
if (face_landmarks_packet.IsEmpty()) {
return absl::CancelledError("Face landmarks packet is empty.");
}
auto &face_landmarks =
face_landmarks_packet face_landmarks_packet
.Get<::std::vector<::mediapipe::NormalizedLandmarkList>>(); .Get<::std::vector<::mediapipe::NormalizedLandmarkList>>();
const auto image_width_f = static_cast<float>(image_width);
const auto image_height_f = static_cast<float>(image_height);
// Convert landmarks to cv::Point2f**. // Convert landmarks to cv::Point2f**.
for (int i = 0; i < face_count; ++i) { for (int i = 0; i < face_count; ++i) {
const auto &normalizedLandmarkList = output_landmarks_vector[i]; const auto &normalizedLandmarkList = face_landmarks[i];
const auto landmarks_num = normalizedLandmarkList.landmark_size(); const auto landmarks_num = normalizedLandmarkList.landmark_size();
if (landmarks_num != kLandmarksNum) { if (landmarks_num != kLandmarksNum) {
@ -122,25 +200,70 @@ absl::Status MPFaceMeshDetector::ProcessFrame2DWithStatus(
for (int j = 0; j < landmarks_num; ++j) { for (int j = 0; j < landmarks_num; ++j) {
const auto &landmark = normalizedLandmarkList.landmark(j); const auto &landmark = normalizedLandmarkList.landmark(j);
face_landmarks[j].x = landmark.x(); face_landmarks[j].x = landmark.x() * image_width_f;
face_landmarks[j].y = landmark.y(); face_landmarks[j].y = landmark.y() * image_height_f;
} }
} }
*numFaces = face_count;
return absl::OkStatus(); return absl::OkStatus();
} }
void MPFaceMeshDetector::ProcessFrame2D(const cv::Mat &camera_frame, absl::Status MPFaceMeshDetector::DetectLandmarksWithStatus(
int *numFaces, cv::Point3f **multi_face_landmarks) {
cv::Point2f **multi_face_landmarks) {
const auto status = if (face_landmarks_packet.IsEmpty()) {
ProcessFrame2DWithStatus(camera_frame, numFaces, multi_face_landmarks); return absl::CancelledError("Face landmarks packet is empty.");
if (!status.ok()) {
LOG(INFO) << "Failed ProcessFrame2D.";
LOG(INFO) << status.message();
} }
auto &face_landmarks =
face_landmarks_packet
.Get<::std::vector<::mediapipe::NormalizedLandmarkList>>();
const auto image_width_f = static_cast<float>(image_width);
const auto image_height_f = static_cast<float>(image_height);
// Convert landmarks to cv::Point3f**.
for (int i = 0; i < face_count; ++i) {
const auto &normalized_landmark_list = face_landmarks[i];
const auto landmarks_num = normalized_landmark_list.landmark_size();
if (landmarks_num != kLandmarksNum) {
return absl::CancelledError("Detected unexpected landmarks number.");
}
auto &face_landmarks = multi_face_landmarks[i];
for (int j = 0; j < landmarks_num; ++j) {
const auto &landmark = normalized_landmark_list.landmark(j);
face_landmarks[j].x = landmark.x() * image_width_f;
face_landmarks[j].y = landmark.y() * image_height_f;
face_landmarks[j].z = landmark.z();
}
}
return absl::OkStatus();
}
// Converts the landmarks packet stored by the last DetectFaces call into 2D
// points scaled by the stored image width/height.
//
// multi_face_landmarks: caller-allocated output, indexed as [face][landmark]
//   by DetectLandmarksWithStatus. Must not be null.
// numFaces: output. Receives the stored face count only when the conversion
//   succeeded; on any failure it is left at 0 so callers never iterate over
//   partially written (or never written) landmark arrays. If non-null it is
//   always written.
void MPFaceMeshDetector::DetectLandmarks(cv::Point2f **multi_face_landmarks,
                                         int *numFaces) {
  if (numFaces) {
    *numFaces = 0;
  }
  if (!numFaces || !multi_face_landmarks) {
    LOG(INFO) << "MPFaceMeshDetector::DetectLandmarks failed: "
              << "output pointers must not be null.";
    return;
  }

  const auto status = DetectLandmarksWithStatus(multi_face_landmarks);
  if (!status.ok()) {
    LOG(INFO) << "MPFaceMeshDetector::DetectLandmarks failed: "
              << status.message();
    // Keep *numFaces at 0: face_count must not be reported for a failed
    // conversion (and may not have been set yet if DetectFaces never ran).
    return;
  }
  *numFaces = face_count;
}
// Converts the landmarks packet stored by the last DetectFaces call into 3D
// points: x/y are scaled by the stored image width/height, z is copied as
// provided by the landmark.
//
// multi_face_landmarks: caller-allocated output, indexed as [face][landmark]
//   by DetectLandmarksWithStatus. Must not be null.
// numFaces: output. Receives the stored face count only when the conversion
//   succeeded; on any failure it is left at 0 so callers never iterate over
//   partially written (or never written) landmark arrays. If non-null it is
//   always written.
void MPFaceMeshDetector::DetectLandmarks(cv::Point3f **multi_face_landmarks,
                                         int *numFaces) {
  if (numFaces) {
    *numFaces = 0;
  }
  if (!numFaces || !multi_face_landmarks) {
    LOG(INFO) << "MPFaceMeshDetector::DetectLandmarks failed: "
              << "output pointers must not be null.";
    return;
  }

  const auto status = DetectLandmarksWithStatus(multi_face_landmarks);
  if (!status.ok()) {
    LOG(INFO) << "MPFaceMeshDetector::DetectLandmarks failed: "
              << status.message();
    // Keep *numFaces at 0: face_count must not be reported for a failed
    // conversion (and may not have been set yet if DetectFaces never ran).
    return;
  }
  *numFaces = face_count;
}
extern "C" { extern "C" {
@ -155,14 +278,26 @@ DLLEXPORT void MPFaceMeshDetectorDestruct(MPFaceMeshDetector *detector) {
delete detector; delete detector;
} }
// Exported C entry point that forwards to MPFaceMeshDetector::DetectFaces.
//
// A null `detector` is treated as "nothing detected": *numFaces (when
// non-null) is set to 0 and the call returns without touching the other
// arguments, instead of dereferencing the null handle.
DLLEXPORT void MPFaceMeshDetectorDetectFaces(
    MPFaceMeshDetector *detector, const cv::Mat &camera_frame,
    cv::Rect *multi_face_bounding_boxes, int *numFaces) {
  if (!detector) {
    if (numFaces) {
      *numFaces = 0;
    }
    return;
  }
  detector->DetectFaces(camera_frame, multi_face_bounding_boxes, numFaces);
}
DLLEXPORT void DLLEXPORT void
MPFaceMeshDetectorProcessFrame2D(MPFaceMeshDetector *detector, MPFaceMeshDetectorDetect2DLandmarks(MPFaceMeshDetector *detector,
const cv::Mat &camera_frame, int *numFaces, cv::Point2f **multi_face_landmarks,
cv::Point2f **multi_face_landmarks) { int *numFaces) {
detector->ProcessFrame2D(camera_frame, numFaces, multi_face_landmarks); detector->DetectLandmarks(multi_face_landmarks, numFaces);
}
// Exported C entry point that forwards to the cv::Point3f overload of
// MPFaceMeshDetector::DetectLandmarks.
//
// A null `detector` is treated as "nothing detected": *numFaces (when
// non-null) is set to 0 and the call returns without touching
// `multi_face_landmarks`, instead of dereferencing the null handle.
DLLEXPORT void
MPFaceMeshDetectorDetect3DLandmarks(MPFaceMeshDetector *detector,
                                    cv::Point3f **multi_face_landmarks,
                                    int *numFaces) {
  if (!detector) {
    if (numFaces) {
      *numFaces = 0;
    }
    return;
  }
  detector->DetectLandmarks(multi_face_landmarks, numFaces);
}
DLLEXPORT const int MPFaceMeshDetectorLandmarksNum = MPFaceMeshDetector::kLandmarksNum; DLLEXPORT const int MPFaceMeshDetectorLandmarksNum =
MPFaceMeshDetector::kLandmarksNum;
} }
const std::string MPFaceMeshDetector::graphConfig = R"pb( const std::string MPFaceMeshDetector::graphConfig = R"pb(
@ -178,6 +313,10 @@ output_stream: "multi_face_landmarks"
# Detected faces count. (int) # Detected faces count. (int)
output_stream: "face_count" output_stream: "face_count"
# Regions of interest calculated based on landmarks.
# (std::vector<NormalizedRect>)
output_stream: "face_rects_from_landmarks"
node { node {
calculator: "FlowLimiterCalculator" calculator: "FlowLimiterCalculator"
input_stream: "input_video" input_stream: "input_video"

View File

@ -20,6 +20,7 @@
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/output_stream_poller.h" #include "mediapipe/framework/output_stream_poller.h"
#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/opencv_highgui_inc.h" #include "mediapipe/framework/port/opencv_highgui_inc.h"
@ -30,25 +31,33 @@
class MPFaceMeshDetector { class MPFaceMeshDetector {
public: public:
static constexpr auto kLandmarksNum = 468;
MPFaceMeshDetector(int numFaces, const char *face_detection_model_path, MPFaceMeshDetector(int numFaces, const char *face_detection_model_path,
const char *face_landmark_model_path); const char *face_landmark_model_path);
void ProcessFrame2D(const cv::Mat &camera_frame, int *numFaces, void DetectFaces(const cv::Mat &camera_frame,
cv::Point2f **multi_face_landmarks); cv::Rect *multi_face_bounding_boxes, int *numFaces);
void DetectLandmarks(cv::Point2f **multi_face_landmarks, int *numFaces);
void DetectLandmarks(cv::Point3f **multi_face_landmarks, int *numFaces);
static constexpr auto kLandmarksNum = 468;
private: private:
absl::Status InitFaceMeshDetector(int numFaces, absl::Status InitFaceMeshDetector(int numFaces,
const char *face_detection_model_path, const char *face_detection_model_path,
const char *face_landmark_model_path); const char *face_landmark_model_path);
absl::Status ProcessFrame2DWithStatus(const cv::Mat &camera_frame, absl::Status DetectFacesWithStatus(const cv::Mat &camera_frame,
int *numFaces, cv::Rect *multi_face_bounding_boxes,
cv::Point2f **multi_face_landmarks); int *numFaces);
absl::Status DetectLandmarksWithStatus(cv::Point2f **multi_face_landmarks);
absl::Status DetectLandmarksWithStatus(cv::Point3f **multi_face_landmarks);
static constexpr auto kInputStream = "input_video"; static constexpr auto kInputStream = "input_video";
static constexpr auto kOutputStream_landmarks = "multi_face_landmarks"; static constexpr auto kOutputStream_landmarks = "multi_face_landmarks";
static constexpr auto kOutputStream_faceCount = "face_count"; static constexpr auto kOutputStream_faceCount = "face_count";
static constexpr auto kOutputStream_face_rects_from_landmarks =
"face_rects_from_landmarks";
static const std::string graphConfig; static const std::string graphConfig;
@ -56,6 +65,13 @@ private:
std::unique_ptr<mediapipe::OutputStreamPoller> landmarks_poller_ptr; std::unique_ptr<mediapipe::OutputStreamPoller> landmarks_poller_ptr;
std::unique_ptr<mediapipe::OutputStreamPoller> face_count_poller_ptr; std::unique_ptr<mediapipe::OutputStreamPoller> face_count_poller_ptr;
std::unique_ptr<mediapipe::OutputStreamPoller>
face_rects_from_landmarks_poller_ptr;
int face_count;
int image_width;
int image_height;
mediapipe::Packet face_landmarks_packet;
}; };
#ifdef __cplusplus #ifdef __cplusplus
@ -68,10 +84,18 @@ MPFaceMeshDetectorConstruct(int numFaces, const char *face_detection_model_path,
DLLEXPORT void MPFaceMeshDetectorDestruct(MPFaceMeshDetector *detector); DLLEXPORT void MPFaceMeshDetectorDestruct(MPFaceMeshDetector *detector);
DLLEXPORT void MPFaceMeshDetectorDetectFaces(
MPFaceMeshDetector *detector, const cv::Mat &camera_frame,
cv::Rect *multi_face_bounding_boxes, int *numFaces);
DLLEXPORT void DLLEXPORT void
MPFaceMeshDetectorProcessFrame2D(MPFaceMeshDetector *detector, MPFaceMeshDetectorDetect2DLandmarks(MPFaceMeshDetector *detector,
const cv::Mat &camera_frame, int *numFaces, cv::Point2f **multi_face_landmarks,
cv::Point2f **multi_face_landmarks); int *numFaces);
DLLEXPORT void
MPFaceMeshDetectorDetect3DLandmarks(MPFaceMeshDetector *detector,
cv::Point3f **multi_face_landmarks,
int *numFaces);
DLLEXPORT extern const int MPFaceMeshDetectorLandmarksNum; DLLEXPORT extern const int MPFaceMeshDetectorLandmarksNum;