Face Detector C++ API

PiperOrigin-RevId: 513959349
MediaPipe Team 2023-03-03 17:19:10 -08:00 committed by Copybara-Service
parent 5b2678a49f
commit 3d41eabc2e
9 changed files with 753 additions and 14 deletions

mediapipe/tasks/cc/vision/face_detector/BUILD

@@ -58,3 +58,25 @@ cc_library(
     ],
     alwayslink = 1,
 )
+
+cc_library(
+    name = "face_detector",
+    srcs = ["face_detector.cc"],
+    hdrs = ["face_detector.h"],
+    deps = [
+        ":face_detector_graph",
+        "//mediapipe/framework/api2:builder",
+        "//mediapipe/framework/formats:detection_cc_proto",
+        "//mediapipe/framework/formats:image",
+        "//mediapipe/tasks/cc/components/containers:detection_result",
+        "//mediapipe/tasks/cc/core:base_options",
+        "//mediapipe/tasks/cc/core:utils",
+        "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
+        "//mediapipe/tasks/cc/vision/core:image_processing_options",
+        "//mediapipe/tasks/cc/vision/core:running_mode",
+        "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
+        "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+    ],
+)

mediapipe/tasks/cc/vision/face_detector/face_detector.cc

@@ -0,0 +1,197 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/vision/face_detector/face_detector.h"

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace face_detector {

namespace {

using FaceDetectorGraphOptionsProto =
    ::mediapipe::tasks::vision::face_detector::proto::FaceDetectorGraphOptions;

constexpr char kFaceDetectorGraphTypeName[] =
    "mediapipe.tasks.vision.face_detector.FaceDetectorGraph";

constexpr char kImageTag[] = "IMAGE";
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kNormRectStreamName[] = "norm_rect_in";
constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kDetectionsStreamName[] = "detections";
constexpr int kMicroSecondsPerMilliSecond = 1000;

// Creates a MediaPipe graph config that contains a subgraph node of
// "mediapipe.tasks.vision.face_detector.FaceDetectorGraph". If the task is
// running in the live stream mode, a "FlowLimiterCalculator" will be added to
// limit the number of frames in flight.
CalculatorGraphConfig CreateGraphConfig(
    std::unique_ptr<FaceDetectorGraphOptionsProto> options,
    bool enable_flow_limiting) {
  api2::builder::Graph graph;
  auto& subgraph = graph.AddNode(kFaceDetectorGraphTypeName);
  subgraph.GetOptions<FaceDetectorGraphOptionsProto>().Swap(options.get());
  graph.In(kImageTag).SetName(kImageInStreamName);
  graph.In(kNormRectTag).SetName(kNormRectStreamName);
  subgraph.Out(kDetectionsTag).SetName(kDetectionsStreamName) >>
      graph.Out(kDetectionsTag);
  subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag);
  if (enable_flow_limiting) {
    return tasks::core::AddFlowLimiterCalculator(
        graph, subgraph, {kImageTag, kNormRectTag}, kDetectionsTag);
  }
  graph.In(kImageTag) >> subgraph.In(kImageTag);
  graph.In(kNormRectTag) >> subgraph.In(kNormRectTag);
  return graph.GetConfig();
}

// Converts the user-facing FaceDetectorOptions struct to the internal
// FaceDetectorGraphOptions proto.
std::unique_ptr<FaceDetectorGraphOptionsProto>
ConvertFaceDetectorGraphOptionsProto(FaceDetectorOptions* options) {
  auto options_proto = std::make_unique<FaceDetectorGraphOptionsProto>();
  auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
      tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
  options_proto->mutable_base_options()->Swap(base_options_proto.get());
  options_proto->mutable_base_options()->set_use_stream_mode(
      options->running_mode != core::RunningMode::IMAGE);
  options_proto->set_min_detection_confidence(
      options->min_detection_confidence);
  options_proto->set_min_suppression_threshold(
      options->min_suppression_threshold);
  return options_proto;
}

}  // namespace

absl::StatusOr<std::unique_ptr<FaceDetector>> FaceDetector::Create(
    std::unique_ptr<FaceDetectorOptions> options) {
  auto options_proto = ConvertFaceDetectorGraphOptionsProto(options.get());
  tasks::core::PacketsCallback packets_callback = nullptr;
  if (options->result_callback) {
    auto result_callback = options->result_callback;
    packets_callback =
        [=](absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
          if (!status_or_packets.ok()) {
            Image image;
            result_callback(status_or_packets.status(), image,
                            Timestamp::Unset().Value());
            return;
          }
          if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
            return;
          }
          Packet image_packet = status_or_packets.value()[kImageOutStreamName];
          if (status_or_packets.value()[kDetectionsStreamName].IsEmpty()) {
            Packet empty_packet =
                status_or_packets.value()[kDetectionsStreamName];
            result_callback(
                {FaceDetectorResult()}, image_packet.Get<Image>(),
                empty_packet.Timestamp().Value() /
                    kMicroSecondsPerMilliSecond);
            return;
          }
          Packet detections_packet =
              status_or_packets.value()[kDetectionsStreamName];
          result_callback(
              components::containers::ConvertToDetectionResult(
                  detections_packet.Get<std::vector<mediapipe::Detection>>()),
              image_packet.Get<Image>(),
              detections_packet.Timestamp().Value() /
                  kMicroSecondsPerMilliSecond);
        };
  }
  return core::VisionTaskApiFactory::Create<FaceDetector,
                                            FaceDetectorGraphOptionsProto>(
      CreateGraphConfig(
          std::move(options_proto),
          options->running_mode == core::RunningMode::LIVE_STREAM),
      std::move(options->base_options.op_resolver), options->running_mode,
      std::move(packets_callback));
}

absl::StatusOr<FaceDetectorResult> FaceDetector::Detect(
    mediapipe::Image image,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  ASSIGN_OR_RETURN(
      NormalizedRect norm_rect,
      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessImageData(
          {{kImageInStreamName, MakePacket<Image>(std::move(image))},
           {kNormRectStreamName,
            MakePacket<NormalizedRect>(std::move(norm_rect))}}));
  if (output_packets[kDetectionsStreamName].IsEmpty()) {
    return {FaceDetectorResult()};
  }
  return components::containers::ConvertToDetectionResult(
      output_packets[kDetectionsStreamName]
          .Get<std::vector<mediapipe::Detection>>());
}

absl::StatusOr<FaceDetectorResult> FaceDetector::DetectForVideo(
    mediapipe::Image image, uint64_t timestamp_ms,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  ASSIGN_OR_RETURN(
      NormalizedRect norm_rect,
      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessVideoData(
          {{kImageInStreamName,
            MakePacket<Image>(std::move(image))
                .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
           {kNormRectStreamName,
            MakePacket<NormalizedRect>(std::move(norm_rect))
                .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
  if (output_packets[kDetectionsStreamName].IsEmpty()) {
    return {FaceDetectorResult()};
  }
  return components::containers::ConvertToDetectionResult(
      output_packets[kDetectionsStreamName]
          .Get<std::vector<mediapipe::Detection>>());
}

absl::Status FaceDetector::DetectAsync(
    mediapipe::Image image, uint64_t timestamp_ms,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  ASSIGN_OR_RETURN(
      NormalizedRect norm_rect,
      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
  return SendLiveStreamData(
      {{kImageInStreamName,
        MakePacket<Image>(std::move(image))
            .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
       {kNormRectStreamName,
        MakePacket<NormalizedRect>(std::move(norm_rect))
            .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
}

}  // namespace face_detector
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe
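The packets_callback above is the bridge between graph output packets and the user-facing result types: it converts detection packets to FaceDetectorResult and divides packet timestamps back down to milliseconds. To make the receiving side of that contract concrete, here is a minimal live-stream usage sketch; it is not part of this commit, and the model path and frame source are placeholders.

#include <cstdint>
#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/vision/face_detector/face_detector.h"

absl::Status RunLiveStream(mediapipe::Image frame) {
  using ::mediapipe::tasks::vision::face_detector::FaceDetector;
  using ::mediapipe::tasks::vision::face_detector::FaceDetectorOptions;
  using ::mediapipe::tasks::vision::face_detector::FaceDetectorResult;

  auto options = std::make_unique<FaceDetectorOptions>();
  options->base_options.model_asset_path =
      "face_detection_short_range.tflite";  // placeholder path
  options->running_mode =
      mediapipe::tasks::vision::core::RunningMode::LIVE_STREAM;
  options->result_callback =
      [](absl::StatusOr<FaceDetectorResult> result,
         const mediapipe::Image& image, uint64_t timestamp_ms) {
        if (!result.ok()) return;
        // The image reference is only valid inside the callback; copy it if
        // it must outlive this scope.
        for (const auto& detection : result->detections) {
          // Consume detection.bounding_box / detection.keypoints here.
        }
      };
  ASSIGN_OR_RETURN(std::unique_ptr<FaceDetector> detector,
                   FaceDetector::Create(std::move(options)));
  // Timestamps are in milliseconds and must be monotonically increasing.
  MP_RETURN_IF_ERROR(detector->DetectAsync(std::move(frame),
                                           /*timestamp_ms=*/0));
  return detector->Close();
}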

mediapipe/tasks/cc/vision/face_detector/face_detector.h

@@ -0,0 +1,163 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_VISION_FACE_DETECTOR_FACE_DETECTOR_H_
#define MEDIAPIPE_TASKS_CC_VISION_FACE_DETECTOR_FACE_DETECTOR_H_

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace face_detector {

using FaceDetectorResult =
    ::mediapipe::tasks::components::containers::DetectionResult;

// The options for configuring a MediaPipe face detector task.
struct FaceDetectorOptions {
  // Base options for configuring MediaPipe Tasks, such as specifying the
  // TfLite model file with metadata, accelerator options, op resolver, etc.
  tasks::core::BaseOptions base_options;

  // The running mode of the task. Defaults to the image mode.
  // The face detector has three running modes:
  // 1) The image mode for detecting faces on single image inputs.
  // 2) The video mode for detecting faces on the decoded frames of a video.
  // 3) The live stream mode for detecting faces on a live stream of input
  //    data, such as from a camera. In this mode, the "result_callback" below
  //    must be specified to receive the detection results asynchronously.
  core::RunningMode running_mode = core::RunningMode::IMAGE;

  // The minimum confidence score for a face detection to be considered
  // successful.
  float min_detection_confidence = 0.5;

  // The minimum non-maximum-suppression threshold for face detections to be
  // considered overlapping.
  float min_suppression_threshold = 0.3;

  // The user-defined result callback for processing live stream data.
  // The result callback should only be specified when the running mode is set
  // to RunningMode::LIVE_STREAM.
  std::function<void(absl::StatusOr<FaceDetectorResult>, const Image&,
                     uint64_t)>
      result_callback = nullptr;
};

class FaceDetector : core::BaseVisionTaskApi {
 public:
  using BaseVisionTaskApi::BaseVisionTaskApi;

  // Creates a FaceDetector from a FaceDetectorOptions to process image data
  // or streaming data. A face detector can be created with one of the
  // following three running modes:
  // 1) Image mode for detecting faces on single image inputs. Users provide
  //    a mediapipe::Image to the `Detect` method and receive the detected
  //    face detection results as the return value.
  // 2) Video mode for detecting faces on the decoded frames of a video.
  //    Users call the `DetectForVideo` method and receive the detected face
  //    detection results as the return value.
  // 3) Live stream mode for detecting faces on a live stream of input data,
  //    such as from a camera. Users call `DetectAsync` to push the image
  //    data into the FaceDetector; the detected results, along with the
  //    input timestamp and the image that the face detector runs on, are
  //    available in the result callback when the face detector finishes the
  //    work.
  static absl::StatusOr<std::unique_ptr<FaceDetector>> Create(
      std::unique_ptr<FaceDetectorOptions> options);

  // Performs face detection on the given image.
  // Only use this method when the FaceDetector is created with the image
  // running mode.
  //
  // The optional 'image_processing_options' parameter can be used to specify
  // the rotation to apply to the image before performing detection, by
  // setting its 'rotation_degrees' field. Note that specifying a
  // region-of-interest using the 'region_of_interest' field is NOT supported
  // and will result in an invalid argument error being returned.
  //
  // The image can be of any size with format RGB or RGBA.
  // TODO: Describe how the input image will be preprocessed once YUV support
  // is implemented.
  absl::StatusOr<FaceDetectorResult> Detect(
      Image image,
      std::optional<core::ImageProcessingOptions> image_processing_options =
          std::nullopt);

  // Performs face detection on the provided video frame.
  // Only use this method when the FaceDetector is created with the video
  // running mode.
  //
  // The optional 'image_processing_options' parameter can be used to specify
  // the rotation to apply to the image before performing detection, by
  // setting its 'rotation_degrees' field. Note that specifying a
  // region-of-interest using the 'region_of_interest' field is NOT supported
  // and will result in an invalid argument error being returned.
  //
  // The image can be of any size with format RGB or RGBA. It's required to
  // provide the video frame's timestamp (in milliseconds). The input
  // timestamps must be monotonically increasing.
  absl::StatusOr<FaceDetectorResult> DetectForVideo(
      Image image, uint64_t timestamp_ms,
      std::optional<core::ImageProcessingOptions> image_processing_options =
          std::nullopt);

  // Sends live image data to perform face detection; the results will be
  // available via the "result_callback" provided in the FaceDetectorOptions.
  // Only use this method when the FaceDetector is created with the live
  // stream running mode.
  //
  // The image can be of any size with format RGB or RGBA. It's required to
  // provide a timestamp (in milliseconds) to indicate when the input image is
  // sent to the face detector. The input timestamps must be monotonically
  // increasing.
  //
  // The optional 'image_processing_options' parameter can be used to specify
  // the rotation to apply to the image before performing detection, by
  // setting its 'rotation_degrees' field. Note that specifying a
  // region-of-interest using the 'region_of_interest' field is NOT supported
  // and will result in an invalid argument error being returned.
  //
  // The "result_callback" provides:
  //   - The face detection results for the input frame.
  //   - A const reference to the corresponding input image that the face
  //     detector runs on. Note that the const reference to the image will no
  //     longer be valid when the callback returns. To access the image data
  //     outside of the callback, callers need to make a copy of the image.
  //   - The input timestamp in milliseconds.
  absl::Status DetectAsync(Image image, uint64_t timestamp_ms,
                           std::optional<core::ImageProcessingOptions>
                               image_processing_options = std::nullopt);

  // Shuts down the FaceDetector when all work is done.
  absl::Status Close() { return runner_->Close(); }
};

}  // namespace face_detector
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe

#endif  // MEDIAPIPE_TASKS_CC_VISION_FACE_DETECTOR_FACE_DETECTOR_H_
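For reference, a minimal image-mode sketch against this header; it is not part of the commit, the file names are placeholders, and DecodeImageFromFile comes from mediapipe/tasks/cc/vision/utils/image_utils.h as used in the test below.

#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/vision/face_detector/face_detector.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"

absl::Status DetectOnImage() {
  using ::mediapipe::tasks::vision::face_detector::FaceDetector;
  using ::mediapipe::tasks::vision::face_detector::FaceDetectorOptions;

  auto options = std::make_unique<FaceDetectorOptions>();
  options->base_options.model_asset_path =
      "face_detection_short_range.tflite";  // placeholder path
  options->running_mode = mediapipe::tasks::vision::core::RunningMode::IMAGE;

  ASSIGN_OR_RETURN(std::unique_ptr<FaceDetector> detector,
                   FaceDetector::Create(std::move(options)));
  ASSIGN_OR_RETURN(mediapipe::Image image,
                   mediapipe::tasks::vision::DecodeImageFromFile(
                       "portrait.jpg"));  // placeholder image
  ASSIGN_OR_RETURN(auto result, detector->Detect(image));
  for (const auto& detection : result.detections) {
    // Each detection carries a pixel-space bounding box and optional
    // normalized keypoints.
  }
  return detector->Close();
}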

mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc

@@ -280,8 +280,10 @@ class FaceDetectorGraph : public core::ModelTaskGraph {
     auto& detection_projection = graph.AddNode("DetectionProjectionCalculator");
     nms_detections >> detection_projection.In(kDetectionsTag);
     matrix >> detection_projection.In(kProjectionMatrixTag);
-    auto face_detections = detection_projection.Out(kDetectionsTag);
+    Source<std::vector<Detection>> face_detections =
+        detection_projection.Out(kDetectionsTag).Cast<std::vector<Detection>>();
+    if (subgraph_options.has_num_faces()) {
       // Clip face detections to maximum number of faces;
       auto& clip_detection_vector_size =
           graph.AddNode("ClipDetectionVectorSizeCalculator");
@@ -289,8 +291,9 @@ class FaceDetectorGraph : public core::ModelTaskGraph {
           .GetOptions<mediapipe::ClipVectorSizeCalculatorOptions>()
           .set_max_vec_size(subgraph_options.num_faces());
       face_detections >> clip_detection_vector_size.In("");
-    auto clipped_face_detections =
-        clip_detection_vector_size.Out("").Cast<std::vector<Detection>>();
+      face_detections =
+          clip_detection_vector_size.Out("").Cast<std::vector<Detection>>();
+    }
 
     // Converts results of face detection into a rectangle (normalized by image
     // size) that encloses the face and is rotated such that the line connecting
@@ -300,7 +303,7 @@ class FaceDetectorGraph : public core::ModelTaskGraph {
         &detections_to_rects
              .GetOptions<mediapipe::DetectionsToRectsCalculatorOptions>());
     image_size >> detections_to_rects.In(kImageSizeTag);
-    clipped_face_detections >> detections_to_rects.In(kDetectionsTag);
+    face_detections >> detections_to_rects.In(kDetectionsTag);
     auto face_rects = detections_to_rects.Out(kNormRectsTag)
                           .Cast<std::vector<NormalizedRect>>();
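With num_faces now declared without a default (see the proto change below), the ClipDetectionVectorSizeCalculator node is only inserted when a caller explicitly sets the field. A hedged sketch of opting in through the generated proto setters (field names are from face_detector_graph_options.proto; the surrounding plumbing is assumed):

#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"

void ConfigureGraphOptions() {
  mediapipe::tasks::vision::face_detector::proto::FaceDetectorGraphOptions
      options;
  options.set_min_detection_confidence(0.5f);
  options.set_min_suppression_threshold(0.3f);
  // Field presence, not the value, is what gates the clipping node above.
  options.set_num_faces(1);
}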

mediapipe/tasks/cc/vision/face_detector/face_detector_test.cc

@@ -0,0 +1,303 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/vision/face_detector/face_detector.h"

#include <vector>

#include "absl/flags/flag.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/detection_result.h"
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace face_detector {
namespace {

using ::file::Defaults;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::NormalizedKeypoint;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::TestParamInfo;
using ::testing::TestWithParam;
using ::testing::Values;

constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kShortRangeBlazeFaceModel[] =
    "face_detection_short_range.tflite";
constexpr char kPortraitImage[] = "portrait.jpg";
constexpr char kPortraitRotatedImage[] = "portrait_rotated.jpg";
constexpr char kPortraitExpectedDetection[] =
    "portrait_expected_detection.pbtxt";
constexpr char kPortraitRotatedExpectedDetection[] =
    "portrait_rotated_expected_detection.pbtxt";
constexpr char kCatImageName[] = "cat.jpg";
constexpr float kKeypointErrorThreshold = 1e-2;

FaceDetectorResult GetExpectedFaceDetectorResult(absl::string_view file_name) {
  mediapipe::Detection detection;
  CHECK_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, file_name),
                        &detection, Defaults()))
      << "Expected face detection result does not exist.";
  return components::containers::ConvertToDetectionResult({detection});
}

void ExpectKeypointsCorrect(
    const std::vector<NormalizedKeypoint>& actual_keypoints,
    const std::vector<NormalizedKeypoint>& expected_keypoints) {
  ASSERT_EQ(actual_keypoints.size(), expected_keypoints.size());
  for (int i = 0; i < actual_keypoints.size(); i++) {
    EXPECT_NEAR(actual_keypoints[i].x, expected_keypoints[i].x,
                kKeypointErrorThreshold);
    EXPECT_NEAR(actual_keypoints[i].y, expected_keypoints[i].y,
                kKeypointErrorThreshold);
  }
}

void ExpectFaceDetectorResultsCorrect(
    const FaceDetectorResult& actual_results,
    const FaceDetectorResult& expected_results) {
  EXPECT_EQ(actual_results.detections.size(),
            expected_results.detections.size());
  for (int i = 0; i < actual_results.detections.size(); i++) {
    const auto& actual_bbox = actual_results.detections[i].bounding_box;
    const auto& expected_bbox = expected_results.detections[i].bounding_box;
    EXPECT_EQ(actual_bbox, expected_bbox);
    ASSERT_TRUE(actual_results.detections[i].keypoints.has_value());
    ExpectKeypointsCorrect(actual_results.detections[i].keypoints.value(),
                           expected_results.detections[i].keypoints.value());
  }
}

struct TestParams {
  // The name of this test, for convenience when displaying test results.
  std::string test_name;
  // The filename of the test image.
  std::string test_image_name;
  // The filename of the face detection model.
  std::string face_detection_model_name;
  // The rotation to apply to the test image before processing, in degrees
  // clockwise.
  int rotation;
  // Expected face detector results.
  FaceDetectorResult expected_result;
};

class ImageModeTest : public testing::TestWithParam<TestParams> {};

TEST_P(ImageModeTest, Succeeds) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                GetParam().test_image_name)));
  auto options = std::make_unique<FaceDetectorOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, GetParam().face_detection_model_name);
  options->running_mode = core::RunningMode::IMAGE;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<FaceDetector> face_detector,
                          FaceDetector::Create(std::move(options)));
  FaceDetectorResult face_detector_result;
  if (GetParam().rotation != 0) {
    ImageProcessingOptions image_processing_options;
    image_processing_options.rotation_degrees = GetParam().rotation;
    MP_ASSERT_OK_AND_ASSIGN(
        face_detector_result,
        face_detector->Detect(image, image_processing_options));
  } else {
    MP_ASSERT_OK_AND_ASSIGN(face_detector_result,
                            face_detector->Detect(image));
  }
  ExpectFaceDetectorResultsCorrect(face_detector_result,
                                   GetParam().expected_result);
  MP_ASSERT_OK(face_detector->Close());
}

INSTANTIATE_TEST_SUITE_P(
    FaceDetectorTest, ImageModeTest,
    Values(
        TestParams{/* test_name= */ "PortraitShortRange",
                   /* test_image_name= */ kPortraitImage,
                   /* face_detection_model_name= */ kShortRangeBlazeFaceModel,
                   /* rotation= */ 0,
                   /* expected_result = */
                   GetExpectedFaceDetectorResult(kPortraitExpectedDetection)},
        TestParams{
            /* test_name= */ "PortraitRotatedShortRange",
            /* test_image_name= */ kPortraitRotatedImage,
            /* face_detection_model_name= */ kShortRangeBlazeFaceModel,
            /* rotation= */ -90,
            /* expected_result = */
            GetExpectedFaceDetectorResult(kPortraitRotatedExpectedDetection)},
        TestParams{/* test_name= */ "NoFace",
                   /* test_image_name= */ kCatImageName,
                   /* face_detection_model_name= */ kShortRangeBlazeFaceModel,
                   /* rotation= */ 0,
                   /* expected_result = */
                   {}}),
    [](const TestParamInfo<ImageModeTest::ParamType>& info) {
      return info.param.test_name;
    });

class VideoModeTest : public testing::TestWithParam<TestParams> {};

TEST_P(VideoModeTest, Succeeds) {
  const int iterations = 100;
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                GetParam().test_image_name)));
  auto options = std::make_unique<FaceDetectorOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, GetParam().face_detection_model_name);
  options->running_mode = core::RunningMode::VIDEO;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<FaceDetector> face_detector,
                          FaceDetector::Create(std::move(options)));
  const FaceDetectorResult& expected_result = GetParam().expected_result;
  for (int i = 0; i < iterations; i++) {
    FaceDetectorResult face_detector_result;
    if (GetParam().rotation != 0) {
      ImageProcessingOptions image_processing_options;
      image_processing_options.rotation_degrees = GetParam().rotation;
      MP_ASSERT_OK_AND_ASSIGN(
          face_detector_result,
          face_detector->DetectForVideo(image, i, image_processing_options));
    } else {
      MP_ASSERT_OK_AND_ASSIGN(face_detector_result,
                              face_detector->DetectForVideo(image, i));
    }
    ExpectFaceDetectorResultsCorrect(face_detector_result, expected_result);
  }
  MP_ASSERT_OK(face_detector->Close());
}

INSTANTIATE_TEST_SUITE_P(
    FaceDetectorTest, VideoModeTest,
    Values(
        TestParams{/* test_name= */ "PortraitShortRange",
                   /* test_image_name= */ kPortraitImage,
                   /* face_detection_model_name= */ kShortRangeBlazeFaceModel,
                   /* rotation= */ 0,
                   /* expected_result = */
                   GetExpectedFaceDetectorResult(kPortraitExpectedDetection)},
        TestParams{
            /* test_name= */ "PortraitRotatedShortRange",
            /* test_image_name= */ kPortraitRotatedImage,
            /* face_detection_model_name= */ kShortRangeBlazeFaceModel,
            /* rotation= */ -90,
            /* expected_result = */
            GetExpectedFaceDetectorResult(kPortraitRotatedExpectedDetection)},
        TestParams{/* test_name= */ "NoFace",
                   /* test_image_name= */ kCatImageName,
                   /* face_detection_model_name= */ kShortRangeBlazeFaceModel,
                   /* rotation= */ 0,
                   /* expected_result = */
                   {}}),
    [](const TestParamInfo<ImageModeTest::ParamType>& info) {
      return info.param.test_name;
    });

class LiveStreamModeTest : public testing::TestWithParam<TestParams> {};

TEST_P(LiveStreamModeTest, Succeeds) {
  const int iterations = 100;
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                GetParam().test_image_name)));
  auto options = std::make_unique<FaceDetectorOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, GetParam().face_detection_model_name);
  options->running_mode = core::RunningMode::LIVE_STREAM;
  std::vector<FaceDetectorResult> face_detector_results;
  std::vector<std::pair<int, int>> image_sizes;
  std::vector<uint64_t> timestamps;
  options->result_callback = [&face_detector_results, &image_sizes,
                              &timestamps](
                                 absl::StatusOr<FaceDetectorResult> results,
                                 const Image& image, uint64_t timestamp_ms) {
    MP_ASSERT_OK(results.status());
    face_detector_results.push_back(std::move(results.value()));
    image_sizes.push_back({image.width(), image.height()});
    timestamps.push_back(timestamp_ms);
  };
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<FaceDetector> face_detector,
                          FaceDetector::Create(std::move(options)));
  for (int i = 0; i < iterations; ++i) {
    if (GetParam().rotation != 0) {
      ImageProcessingOptions image_processing_options;
      image_processing_options.rotation_degrees = GetParam().rotation;
      MP_ASSERT_OK(
          face_detector->DetectAsync(image, i + 1, image_processing_options));
    } else {
      MP_ASSERT_OK(face_detector->DetectAsync(image, i + 1));
    }
  }
  MP_ASSERT_OK(face_detector->Close());
  ASSERT_LE(face_detector_results.size(), iterations);
  ASSERT_GT(face_detector_results.size(), 0);
  const FaceDetectorResult& expected_results = GetParam().expected_result;
  for (int i = 0; i < face_detector_results.size(); ++i) {
    ExpectFaceDetectorResultsCorrect(face_detector_results[i],
                                     expected_results);
  }
  for (const auto& image_size : image_sizes) {
    EXPECT_EQ(image_size.first, image.width());
    EXPECT_EQ(image_size.second, image.height());
  }
  uint64_t timestamp_ms = 0;
  for (const auto& timestamp : timestamps) {
    EXPECT_GT(timestamp, timestamp_ms);
    timestamp_ms = timestamp;
  }
}

INSTANTIATE_TEST_SUITE_P(
    FaceDetectorTest, LiveStreamModeTest,
    Values(
        TestParams{/* test_name= */ "PortraitShortRange",
                   /* test_image_name= */ kPortraitImage,
                   /* face_detection_model_name= */ kShortRangeBlazeFaceModel,
                   /* rotation= */ 0,
                   /* expected_result = */
                   GetExpectedFaceDetectorResult(kPortraitExpectedDetection)},
        TestParams{
            /* test_name= */ "PortraitRotatedShortRange",
            /* test_image_name= */ kPortraitRotatedImage,
            /* face_detection_model_name= */ kShortRangeBlazeFaceModel,
            /* rotation= */ -90,
            /* expected_result = */
            GetExpectedFaceDetectorResult(kPortraitRotatedExpectedDetection)},
        TestParams{/* test_name= */ "NoFace",
                   /* test_image_name= */ kCatImageName,
                   /* face_detection_model_name= */ kShortRangeBlazeFaceModel,
                   /* rotation= */ 0,
                   /* expected_result = */
                   {}}),
    [](const TestParamInfo<ImageModeTest::ParamType>& info) {
      return info.param.test_name;
    });

}  // namespace
}  // namespace face_detector
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe

mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto

@@ -41,5 +41,5 @@ message FaceDetectorGraphOptions {
   optional float min_suppression_threshold = 3 [default = 0.5];
 
   // Maximum number of faces to detect in the image.
-  optional int32 num_faces = 4 [default = 1];
+  optional int32 num_faces = 4;
 }

mediapipe/tasks/testdata/vision/BUILD

@@ -64,6 +64,7 @@ mediapipe_files(srcs = [
     "pointing_up.jpg",
     "pointing_up_rotated.jpg",
     "portrait.jpg",
+    "portrait_rotated.jpg",
     "right_hands.jpg",
     "right_hands_rotated.jpg",
     "segmentation_golden_rotation0.png",
@@ -86,6 +87,7 @@ exports_files(
         "expected_right_up_hand_landmarks.prototxt",
         "gesture_recognizer.task",
         "portrait_expected_detection.pbtxt",
+        "portrait_rotated_expected_detection.pbtxt",
     ],
 )
@@ -114,6 +116,7 @@ filegroup(
         "pointing_up.jpg",
         "pointing_up_rotated.jpg",
         "portrait.jpg",
+        "portrait_rotated.jpg",
         "right_hands.jpg",
        "right_hands_rotated.jpg",
         "segmentation_golden_rotation0.png",
@@ -178,6 +181,7 @@ filegroup(
         "portrait_expected_detection.pbtxt",
         "portrait_expected_face_landmarks.pbtxt",
         "portrait_expected_face_landmarks_with_attention.pbtxt",
+        "portrait_rotated_expected_detection.pbtxt",
         "thumb_up_landmarks.pbtxt",
         "thumb_up_rotated_landmarks.pbtxt",
         "victory_landmarks.pbtxt",

mediapipe/tasks/testdata/vision/portrait_rotated_expected_detection.pbtxt

@@ -0,0 +1,35 @@
# proto-file: mediapipe/framework/formats/detection.proto
# proto-message: Detection
location_data {
  format: BOUNDING_BOX
  bounding_box {
    xmin: 674
    ymin: 283
    width: 236
    height: 236
  }
  relative_keypoints {
    x: 0.8207535
    y: 0.44679973
  }
  relative_keypoints {
    x: 0.8196571
    y: 0.56261915
  }
  relative_keypoints {
    x: 0.76194185
    y: 0.5171923
  }
  relative_keypoints {
    x: 0.7199387
    y: 0.5136067
  }
  relative_keypoints {
    x: 0.8070089
    y: 0.36298043
  }
  relative_keypoints {
    x: 0.8088217
    y: 0.61204016
  }
}
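The six relative_keypoints above follow the keypoint order of the short-range BlazeFace model. That order is not defined anywhere in this commit, so the indices in the sketch below are an assumption for illustration only:

// Assumed BlazeFace short-range keypoint order; illustrative, not defined in
// this commit.
enum BlazeFaceKeypointIndex {
  kRightEye = 0,
  kLeftEye = 1,
  kNoseTip = 2,
  kMouthCenter = 3,
  kRightEarTragion = 4,
  kLeftEarTragion = 5,
};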

third_party/external_files.bzl

@@ -786,8 +786,8 @@ def external_files():
     http_file(
         name = "com_google_mediapipe_portrait_expected_face_landmarks_with_attention_pbtxt",
-        sha256 = "dae959456f001015278f3a1535bd03c9fa0990a3df951135645ce23293be0613",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_landmarks_with_attention.pbtxt?generation=1677522777298874"],
+        sha256 = "f2ccd889654b914996e4aab0d7831a3e73d3b63d6c14f6bac4bec5cd3415bce4",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_landmarks_with_attention.pbtxt?generation=1676415475626542"],
     )
 
     http_file(
@@ -796,6 +796,18 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/portrait.jpg?generation=1674261630039907"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_portrait_rotated_expected_detection_pbtxt",
+        sha256 = "7e680fe0918d1e8409b0e0e4576a982e20afa871e6af9c13b7a626de1d5341a2",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_rotated_expected_detection.pbtxt?generation=1677194677875312"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_portrait_rotated_jpg",
+        sha256 = "f91ca0e4f827b06e9ac037cf58d95f1f3ffbe34238119b7d47eda35456007f33",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_rotated.jpg?generation=1677194680138164"],
+    )
+
     http_file(
         name = "com_google_mediapipe_pose_detection_tflite",
         sha256 = "a63c614bef30d35947f13be361820b1e4e3bec9cfeebf4d11216a18373108e85",