Merge branch 'master' into ios-gesture-recognizer-files

Prianka Liz Kariat 2023-05-25 19:17:18 +05:30
commit b16905e362
35 changed files with 1761 additions and 96 deletions


@@ -194,6 +194,7 @@ cc_library(
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework:packet",
         "//mediapipe/framework/formats:detection_cc_proto",
+        "//mediapipe/framework/formats:image",
         "//mediapipe/framework/formats:image_frame",
         "//mediapipe/framework/formats:landmark_cc_proto",
         "//mediapipe/framework/formats:matrix",
@@ -225,10 +226,8 @@ cc_library(
         "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/framework/formats:tensor",
         "//mediapipe/framework/port:ret_check",
-        "//mediapipe/framework/port:status",
         "//mediapipe/gpu:gpu_buffer",
         "//mediapipe/util:render_data_cc_proto",
-        "@com_google_absl//absl/memory",
         "@com_google_absl//absl/status",
         "@org_tensorflow//tensorflow/lite:framework",
     ],
@@ -907,6 +906,7 @@ cc_library(
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/formats:classification_cc_proto",
         "//mediapipe/framework/formats:detection_cc_proto",
+        "//mediapipe/framework/formats:image",
         "//mediapipe/framework/formats:landmark_cc_proto",
         "//mediapipe/framework/formats:matrix",
         "//mediapipe/framework/formats:rect_cc_proto",


@@ -17,6 +17,7 @@
 #include <vector>
 #include "mediapipe/framework/formats/detection.pb.h"
+#include "mediapipe/framework/formats/image.h"
 #include "mediapipe/framework/formats/image_frame.h"
 #include "mediapipe/framework/formats/landmark.pb.h"
 #include "mediapipe/framework/formats/matrix.h"
@@ -72,4 +73,7 @@ typedef BeginLoopCalculator<std::vector<GpuBuffer>>
     BeginLoopGpuBufferCalculator;
 REGISTER_CALCULATOR(BeginLoopGpuBufferCalculator);
+// A calculator to process std::vector<mediapipe::Image>.
+typedef BeginLoopCalculator<std::vector<Image>> BeginLoopImageCalculator;
+REGISTER_CALCULATOR(BeginLoopImageCalculator);
 }  // namespace mediapipe


@@ -80,4 +80,8 @@ typedef EndLoopCalculator<std::vector<::mediapipe::Image>>
     EndLoopImageCalculator;
 REGISTER_CALCULATOR(EndLoopImageCalculator);
+typedef EndLoopCalculator<std::vector<std::array<float, 16>>>
+    EndLoopAffineMatrixCalculator;
+REGISTER_CALCULATOR(EndLoopAffineMatrixCalculator);
 }  // namespace mediapipe


@@ -18,6 +18,7 @@
 #include "mediapipe/framework/formats/classification.pb.h"
 #include "mediapipe/framework/formats/detection.pb.h"
+#include "mediapipe/framework/formats/image.h"
 #include "mediapipe/framework/formats/landmark.pb.h"
 #include "mediapipe/framework/formats/matrix.h"
 #include "mediapipe/framework/formats/rect.pb.h"
@@ -86,4 +87,12 @@ REGISTER_CALCULATOR(SplitUint64tVectorCalculator);
 typedef SplitVectorCalculator<float, false> SplitFloatVectorCalculator;
 REGISTER_CALCULATOR(SplitFloatVectorCalculator);
+typedef SplitVectorCalculator<mediapipe::Image, false>
+    SplitImageVectorCalculator;
+REGISTER_CALCULATOR(SplitImageVectorCalculator);
+typedef SplitVectorCalculator<std::array<float, 16>, false>
+    SplitAffineMatrixVectorCalculator;
+REGISTER_CALCULATOR(SplitAffineMatrixVectorCalculator);
 }  // namespace mediapipe


@@ -231,18 +231,26 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
       SubgraphContext* sc) override {
     bool output_stylized = HasOutput(sc->OriginalNode(), kStylizedImageTag);
     bool output_alignment = HasOutput(sc->OriginalNode(), kFaceAlignmentTag);
-    ASSIGN_OR_RETURN(
-        const auto* model_asset_bundle_resources,
-        CreateModelAssetBundleResources<FaceStylizerGraphOptions>(sc));
-    // Copies the file content instead of passing the pointer of file in
-    // memory if the subgraph model resource service is not available.
-    auto face_stylizer_external_file = absl::make_unique<ExternalFile>();
-    MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
-        *model_asset_bundle_resources,
-        sc->MutableOptions<FaceStylizerGraphOptions>(),
-        output_stylized ? face_stylizer_external_file.get() : nullptr,
-        !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
-             .IsAvailable()));
+    auto face_stylizer_external_file = absl::make_unique<ExternalFile>();
+    if (sc->Options<FaceStylizerGraphOptions>().has_base_options()) {
+      ASSIGN_OR_RETURN(
+          const auto* model_asset_bundle_resources,
+          CreateModelAssetBundleResources<FaceStylizerGraphOptions>(sc));
+      // Copies the file content instead of passing the pointer of file in
+      // memory if the subgraph model resource service is not available.
+      MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
+          *model_asset_bundle_resources,
+          sc->MutableOptions<FaceStylizerGraphOptions>(),
+          output_stylized ? face_stylizer_external_file.get() : nullptr,
+          !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
+               .IsAvailable()));
+    } else if (output_stylized) {
+      return CreateStatusWithPayload(
+          absl::StatusCode::kInvalidArgument,
+          "Face stylizer must specify its base options when the "
+          "\"STYLIZED_IMAGE\" output stream is connected.",
+          MediaPipeTasksStatus::kInvalidArgumentError);
+    }
     Graph graph;
     ASSIGN_OR_RETURN(
         auto face_landmark_lists,
@@ -347,7 +355,7 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
     auto& image_to_tensor = graph.AddNode("ImageToTensorCalculator");
     auto& image_to_tensor_options =
         image_to_tensor.GetOptions<ImageToTensorCalculatorOptions>();
-    image_to_tensor_options.mutable_output_tensor_float_range()->set_min(-1);
+    image_to_tensor_options.mutable_output_tensor_float_range()->set_min(0);
     image_to_tensor_options.mutable_output_tensor_float_range()->set_max(1);
     image_to_tensor_options.set_output_tensor_width(kFaceAlignmentOutputSize);
     image_to_tensor_options.set_output_tensor_height(
@@ -363,7 +371,7 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
     graph.AddNode("mediapipe.tasks.TensorsToImageCalculator");
     auto& tensors_to_image_options =
         tensors_to_image.GetOptions<TensorsToImageCalculatorOptions>();
-    tensors_to_image_options.mutable_input_tensor_float_range()->set_min(-1);
+    tensors_to_image_options.mutable_input_tensor_float_range()->set_min(0);
     tensors_to_image_options.mutable_input_tensor_float_range()->set_max(1);
     face_alignment_image >> tensors_to_image.In(kTensorsTag);
     face_alignment = tensors_to_image.Out(kImageTag).Cast<Image>();


@@ -49,11 +49,12 @@ OBJC_TASK_COMMON_DEPS = [
 ]
 
 CALCULATORS_AND_GRAPHS = [
-    "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
-    "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
+    "//mediapipe/calculators/core:flow_limiter_calculator",
     "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
     "//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph",
-    "//mediapipe/calculators/core:flow_limiter_calculator",
+    "//mediapipe/tasks/cc/vision/face_detector:face_detector_graph",
+    "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
+    "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
 ]
 
 strip_api_include_path_prefix(
@@ -76,6 +77,9 @@ strip_api_include_path_prefix(
         "//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedderResult.h",
         "//mediapipe/tasks/ios/vision/core:sources/MPPRunningMode.h",
         "//mediapipe/tasks/ios/vision/core:sources/MPPImage.h",
+        "//mediapipe/tasks/ios/vision/face_detector:sources/MPPFaceDetector.h",
+        "//mediapipe/tasks/ios/vision/face_detector:sources/MPPFaceDetectorOptions.h",
+        "//mediapipe/tasks/ios/vision/face_detector:sources/MPPFaceDetectorResult.h",
         "//mediapipe/tasks/ios/vision/image_classifier:sources/MPPImageClassifier.h",
         "//mediapipe/tasks/ios/vision/image_classifier:sources/MPPImageClassifierOptions.h",
         "//mediapipe/tasks/ios/vision/image_classifier:sources/MPPImageClassifierResult.h",
@@ -157,6 +161,9 @@ apple_static_xcframework(
         ":MPPTaskResult.h",
         ":MPPImage.h",
         ":MPPRunningMode.h",
+        ":MPPFaceDetector.h",
+        ":MPPFaceDetectorOptions.h",
+        ":MPPFaceDetectorResult.h",
         ":MPPImageClassifier.h",
         ":MPPImageClassifierOptions.h",
         ":MPPImageClassifierResult.h",
@@ -165,6 +172,7 @@ apple_static_xcframework(
         ":MPPObjectDetectorResult.h",
     ],
     deps = [
+        "//mediapipe/tasks/ios/vision/face_detector:MPPFaceDetector",
         "//mediapipe/tasks/ios/vision/image_classifier:MPPImageClassifier",
         "//mediapipe/tasks/ios/vision/object_detector:MPPObjectDetector",
     ],


@@ -0,0 +1,64 @@
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
load(
"//mediapipe/framework/tool:ios.bzl",
"MPP_TASK_MINIMUM_OS_VERSION",
)
load(
"@org_tensorflow//tensorflow/lite:special_rules.bzl",
"tflite_ios_lab_runner",
)
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
TFL_DEFAULT_TAGS = [
"apple",
]
# The following sanitizer tests are not supported by iOS test targets.
TFL_DISABLED_SANITIZER_TAGS = [
"noasan",
"nomsan",
"notsan",
]
objc_library(
name = "MPPFaceDetectorObjcTestLibrary",
testonly = 1,
srcs = ["MPPFaceDetectorTests.mm"],
copts = [
"-ObjC++",
"-std=c++17",
"-x objective-c++",
],
data = [
"//mediapipe/tasks/testdata/vision:test_images",
"//mediapipe/tasks/testdata/vision:test_models",
"//mediapipe/tasks/testdata/vision:test_protos",
],
deps = [
"//mediapipe/tasks/ios/common:MPPCommon",
"//mediapipe/tasks/ios/components/containers/utils:MPPDetectionHelpers",
"//mediapipe/tasks/ios/test/vision/utils:MPPImageTestUtils",
"//mediapipe/tasks/ios/vision/face_detector:MPPFaceDetector",
"//mediapipe/tasks/ios/vision/face_detector:MPPFaceDetectorResult",
"//third_party/apple_frameworks:UIKit",
] + select({
"//third_party:opencv_ios_sim_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
"//third_party:opencv_ios_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
"//third_party:opencv_ios_x86_64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
"//conditions:default": ["@ios_opencv//:OpencvFramework"],
}),
)
ios_unit_test(
name = "MPPFaceDetectorObjcTest",
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
deps = [
":MPPFaceDetectorObjcTestLibrary",
],
)


@@ -0,0 +1,522 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#import <UIKit/UIKit.h>
#import <XCTest/XCTest.h>
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.h"
#import "mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h"
#import "mediapipe/tasks/ios/vision/face_detector/sources/MPPFaceDetector.h"
#import "mediapipe/tasks/ios/vision/face_detector/sources/MPPFaceDetectorResult.h"
static NSDictionary *const kPortraitImage =
@{@"name" : @"portrait", @"type" : @"jpg", @"orientation" : @(UIImageOrientationUp)};
static NSDictionary *const kPortraitRotatedImage =
@{@"name" : @"portrait_rotated", @"type" : @"jpg", @"orientation" : @(UIImageOrientationRight)};
static NSDictionary *const kCatImage = @{@"name" : @"cat", @"type" : @"jpg"};
static NSString *const kShortRangeBlazeFaceModel = @"face_detection_short_range";
static NSArray<NSArray *> *const kPortraitExpectedKeypoints = @[
@[ @0.44416f, @0.17643f ], @[ @0.55514f, @0.17731f ], @[ @0.50467f, @0.22657f ],
@[ @0.50227f, @0.27199f ], @[ @0.36063f, @0.20143f ], @[ @0.60841f, @0.20409f ]
];
static NSArray<NSArray *> *const kPortraitRotatedExpectedKeypoints = @[
@[ @0.82075f, @0.44679f ], @[ @0.81965f, @0.56261f ], @[ @0.76194f, @0.51719f ],
@[ @0.71993f, @0.51719f ], @[ @0.80700f, @0.36298f ], @[ @0.80882f, @0.61204f ]
];
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
static NSString *const kLiveStreamTestsDictFaceDetectorKey = @"face_detector";
static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
static const float kKeypointErrorThreshold = 1e-2;
#define AssertEqualErrors(error, expectedError) \
XCTAssertNotNil(error); \
XCTAssertEqualObjects(error.domain, expectedError.domain); \
XCTAssertEqual(error.code, expectedError.code); \
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
@interface MPPFaceDetectorTests : XCTestCase <MPPFaceDetectorLiveStreamDelegate> {
NSDictionary *liveStreamSucceedsTestDict;
NSDictionary *outOfOrderTimestampTestDict;
}
@end
@implementation MPPFaceDetectorTests
#pragma mark General Tests
- (void)testCreateFaceDetectorWithMissingModelPathFails {
NSString *modelPath = [MPPFaceDetectorTests filePathWithName:@"" extension:@""];
NSError *error = nil;
MPPFaceDetector *faceDetector = [[MPPFaceDetector alloc] initWithModelPath:modelPath
error:&error];
XCTAssertNil(faceDetector);
NSError *expectedError = [NSError
errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey :
@"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', "
@"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."
}];
AssertEqualErrors(error, expectedError);
}
#pragma mark Image Mode Tests
- (void)testDetectWithImageModeAndPotraitSucceeds {
NSString *modelPath = [MPPFaceDetectorTests filePathWithName:kShortRangeBlazeFaceModel
extension:@"tflite"];
MPPFaceDetector *faceDetector = [[MPPFaceDetector alloc] initWithModelPath:modelPath error:nil];
[self assertResultsOfDetectInImageWithFileInfo:kPortraitImage
usingFaceDetector:faceDetector
containsExpectedKeypoints:kPortraitExpectedKeypoints];
}
- (void)testDetectWithImageModeAndRotatedPotraitSucceeds {
NSString *modelPath = [MPPFaceDetectorTests filePathWithName:kShortRangeBlazeFaceModel
extension:@"tflite"];
MPPFaceDetector *faceDetector = [[MPPFaceDetector alloc] initWithModelPath:modelPath error:nil];
XCTAssertNotNil(faceDetector);
MPPImage *image = [self imageWithFileInfo:kPortraitRotatedImage];
[self assertResultsOfDetectInImage:image
usingFaceDetector:faceDetector
containsExpectedKeypoints:kPortraitRotatedExpectedKeypoints];
}
- (void)testDetectWithImageModeAndNoFaceSucceeds {
NSString *modelPath = [MPPFaceDetectorTests filePathWithName:kShortRangeBlazeFaceModel
extension:@"tflite"];
MPPFaceDetector *faceDetector = [[MPPFaceDetector alloc] initWithModelPath:modelPath error:nil];
XCTAssertNotNil(faceDetector);
NSError *error;
MPPImage *mppImage = [self imageWithFileInfo:kCatImage];
MPPFaceDetectorResult *faceDetectorResult = [faceDetector detectInImage:mppImage error:&error];
XCTAssertNil(error);
XCTAssertNotNil(faceDetectorResult);
XCTAssertEqual(faceDetectorResult.detections.count, 0);
}
#pragma mark Video Mode Tests
- (void)testDetectWithVideoModeAndPotraitSucceeds {
MPPFaceDetectorOptions *options =
[self faceDetectorOptionsWithModelName:kShortRangeBlazeFaceModel];
options.runningMode = MPPRunningModeVideo;
MPPFaceDetector *faceDetector = [self faceDetectorWithOptionsSucceeds:options];
MPPImage *image = [self imageWithFileInfo:kPortraitImage];
for (int i = 0; i < 3; i++) {
MPPFaceDetectorResult *faceDetectorResult = [faceDetector detectInVideoFrame:image
timestampInMilliseconds:i
error:nil];
[self assertFaceDetectorResult:faceDetectorResult
containsExpectedKeypoints:kPortraitExpectedKeypoints];
}
}
- (void)testDetectWithVideoModeAndRotatedPotraitSucceeds {
MPPFaceDetectorOptions *options =
[self faceDetectorOptionsWithModelName:kShortRangeBlazeFaceModel];
options.runningMode = MPPRunningModeVideo;
MPPFaceDetector *faceDetector = [self faceDetectorWithOptionsSucceeds:options];
MPPImage *image = [self imageWithFileInfo:kPortraitRotatedImage];
for (int i = 0; i < 3; i++) {
MPPFaceDetectorResult *faceDetectorResult = [faceDetector detectInVideoFrame:image
timestampInMilliseconds:i
error:nil];
[self assertFaceDetectorResult:faceDetectorResult
containsExpectedKeypoints:kPortraitRotatedExpectedKeypoints];
}
}
#pragma mark Live Stream Mode Tests
- (void)testDetectWithLiveStreamModeAndPotraitSucceeds {
NSInteger iterationCount = 100;
// Because of flow limiting, the callback might be invoked fewer than `iterationCount` times. A
// normal expectation will fail if expectation.fulfill() is not called
// `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
// only succeed if the expectation is not fulfilled for the specified `expectedFulfillmentCount`.
// Since it is not possible to predict how many times the expectation is supposed to be
// fulfilled, setting `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
// `expectation.isInverted = true` ensures that the test succeeds if the expectation is fulfilled
// at most `iterationCount` times.
XCTestExpectation *expectation = [[XCTestExpectation alloc]
initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
expectation.expectedFulfillmentCount = iterationCount + 1;
expectation.inverted = YES;
MPPFaceDetectorOptions *options =
[self faceDetectorOptionsWithModelName:kShortRangeBlazeFaceModel];
options.runningMode = MPPRunningModeLiveStream;
options.faceDetectorLiveStreamDelegate = self;
MPPFaceDetector *faceDetector = [self faceDetectorWithOptionsSucceeds:options];
MPPImage *image = [self imageWithFileInfo:kPortraitImage];
liveStreamSucceedsTestDict = @{
kLiveStreamTestsDictFaceDetectorKey : faceDetector,
kLiveStreamTestsDictExpectationKey : expectation
};
for (int i = 0; i < iterationCount; i++) {
XCTAssertTrue([faceDetector detectAsyncInImage:image timestampInMilliseconds:i error:nil]);
}
NSTimeInterval timeout = 0.5f;
[self waitForExpectations:@[ expectation ] timeout:timeout];
}
- (void)testDetectWithOutOfOrderTimestampsAndLiveStreamModeFails {
MPPFaceDetectorOptions *options =
[self faceDetectorOptionsWithModelName:kShortRangeBlazeFaceModel];
options.runningMode = MPPRunningModeLiveStream;
options.faceDetectorLiveStreamDelegate = self;
XCTestExpectation *expectation = [[XCTestExpectation alloc]
initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
expectation.expectedFulfillmentCount = 1;
MPPFaceDetector *faceDetector = [self faceDetectorWithOptionsSucceeds:options];
liveStreamSucceedsTestDict = @{
kLiveStreamTestsDictFaceDetectorKey : faceDetector,
kLiveStreamTestsDictExpectationKey : expectation
};
MPPImage *image = [self imageWithFileInfo:kPortraitImage];
XCTAssertTrue([faceDetector detectAsyncInImage:image timestampInMilliseconds:1 error:nil]);
NSError *error;
XCTAssertFalse([faceDetector detectAsyncInImage:image timestampInMilliseconds:0 error:&error]);
NSError *expectedError =
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey :
@"INVALID_ARGUMENT: Input timestamp must be monotonically increasing."
}];
AssertEqualErrors(error, expectedError);
NSTimeInterval timeout = 0.5f;
[self waitForExpectations:@[ expectation ] timeout:timeout];
}
#pragma mark Running Mode Tests
- (void)testCreateFaceDetectorFailsWithDelegateInNonLiveStreamMode {
MPPRunningMode runningModesToTest[] = {MPPRunningModeImage, MPPRunningModeVideo};
for (int i = 0; i < sizeof(runningModesToTest) / sizeof(runningModesToTest[0]); i++) {
MPPFaceDetectorOptions *options =
[self faceDetectorOptionsWithModelName:kShortRangeBlazeFaceModel];
options.runningMode = runningModesToTest[i];
options.faceDetectorLiveStreamDelegate = self;
[self assertCreateFaceDetectorWithOptions:options
failsWithExpectedError:
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey :
@"The vision task is in image or video mode. The "
@"delegate must not be set in the task's options."
}]];
}
}
- (void)testCreateFaceDetectorFailsWithMissingDelegateInLiveStreamMode {
MPPFaceDetectorOptions *options =
[self faceDetectorOptionsWithModelName:kShortRangeBlazeFaceModel];
options.runningMode = MPPRunningModeLiveStream;
[self assertCreateFaceDetectorWithOptions:options
failsWithExpectedError:
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey :
@"The vision task is in live stream mode. An "
@"object must be set as the delegate of the task "
@"in its options to ensure asynchronous delivery "
@"of results."
}]];
}
- (void)testDetectFailsWithCallingWrongApiInImageMode {
MPPFaceDetectorOptions *options =
[self faceDetectorOptionsWithModelName:kShortRangeBlazeFaceModel];
MPPFaceDetector *faceDetector = [self faceDetectorWithOptionsSucceeds:options];
MPPImage *image = [self imageWithFileInfo:kPortraitImage];
NSError *liveStreamApiCallError;
XCTAssertFalse([faceDetector detectAsyncInImage:image
timestampInMilliseconds:0
error:&liveStreamApiCallError]);
NSError *expectedLiveStreamApiCallError =
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey : @"The vision task is not initialized with live "
@"stream mode. Current Running Mode: Image"
}];
AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);
NSError *videoApiCallError;
XCTAssertFalse([faceDetector detectInVideoFrame:image
timestampInMilliseconds:0
error:&videoApiCallError]);
NSError *expectedVideoApiCallError =
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey : @"The vision task is not initialized with "
@"video mode. Current Running Mode: Image"
}];
AssertEqualErrors(videoApiCallError, expectedVideoApiCallError);
}
- (void)testDetectFailsWithCallingWrongApiInVideoMode {
MPPFaceDetectorOptions *options =
[self faceDetectorOptionsWithModelName:kShortRangeBlazeFaceModel];
options.runningMode = MPPRunningModeVideo;
MPPFaceDetector *faceDetector = [self faceDetectorWithOptionsSucceeds:options];
MPPImage *image = [self imageWithFileInfo:kPortraitImage];
NSError *liveStreamApiCallError;
XCTAssertFalse([faceDetector detectAsyncInImage:image
timestampInMilliseconds:0
error:&liveStreamApiCallError]);
NSError *expectedLiveStreamApiCallError =
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey : @"The vision task is not initialized with live "
@"stream mode. Current Running Mode: Video"
}];
AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);
NSError *imageApiCallError;
XCTAssertFalse([faceDetector detectInImage:image error:&imageApiCallError]);
NSError *expectedImageApiCallError =
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey : @"The vision task is not initialized with "
@"image mode. Current Running Mode: Video"
}];
AssertEqualErrors(imageApiCallError, expectedImageApiCallError);
}
- (void)testDetectFailsWithCallingWrongApiInLiveStreamMode {
MPPFaceDetectorOptions *options =
[self faceDetectorOptionsWithModelName:kShortRangeBlazeFaceModel];
options.runningMode = MPPRunningModeLiveStream;
options.faceDetectorLiveStreamDelegate = self;
MPPFaceDetector *faceDetector = [self faceDetectorWithOptionsSucceeds:options];
MPPImage *image = [self imageWithFileInfo:kPortraitImage];
NSError *imageApiCallError;
XCTAssertFalse([faceDetector detectInImage:image error:&imageApiCallError]);
NSError *expectedImageApiCallError =
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey : @"The vision task is not initialized with "
@"image mode. Current Running Mode: Live Stream"
}];
AssertEqualErrors(imageApiCallError, expectedImageApiCallError);
NSError *videoApiCallError;
XCTAssertFalse([faceDetector detectInVideoFrame:image
timestampInMilliseconds:0
error:&videoApiCallError]);
NSError *expectedVideoApiCallError =
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey : @"The vision task is not initialized with "
@"video mode. Current Running Mode: Live Stream"
}];
AssertEqualErrors(videoApiCallError, expectedVideoApiCallError);
}
- (void)testDetectWithLiveStreamModeSucceeds {
MPPFaceDetectorOptions *options =
[self faceDetectorOptionsWithModelName:kShortRangeBlazeFaceModel];
options.runningMode = MPPRunningModeLiveStream;
options.faceDetectorLiveStreamDelegate = self;
NSInteger iterationCount = 100;
// Because of flow limiting, the callback might be invoked fewer than `iterationCount` times. A
// normal expectation will fail if expectation.fulfill() is not called
// `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
// only succeed if the expectation is not fulfilled for the specified `expectedFulfillmentCount`.
// Since it is not possible to determine how many times the expectation is supposed to be
// fulfilled, setting `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
// `expectation.isInverted = true` ensures that the test succeeds if the expectation is fulfilled
// at most `iterationCount` times.
XCTestExpectation *expectation = [[XCTestExpectation alloc]
initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
expectation.expectedFulfillmentCount = iterationCount + 1;
expectation.inverted = YES;
MPPFaceDetector *faceDetector = [self faceDetectorWithOptionsSucceeds:options];
liveStreamSucceedsTestDict = @{
kLiveStreamTestsDictFaceDetectorKey : faceDetector,
kLiveStreamTestsDictExpectationKey : expectation
};
MPPImage *image = [self imageWithFileInfo:kPortraitImage];
for (int i = 0; i < iterationCount; i++) {
XCTAssertTrue([faceDetector detectAsyncInImage:image timestampInMilliseconds:i error:nil]);
}
NSTimeInterval timeout = 0.5f;
[self waitForExpectations:@[ expectation ] timeout:timeout];
}
#pragma mark MPPFaceDetectorLiveStreamDelegate Methods
- (void)faceDetector:(MPPFaceDetector *)faceDetector
didFinishDetectionWithResult:(MPPFaceDetectorResult *)faceDetectorResult
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError *)error {
[self assertFaceDetectorResult:faceDetectorResult
containsExpectedKeypoints:kPortraitExpectedKeypoints];
if (faceDetector == outOfOrderTimestampTestDict[kLiveStreamTestsDictFaceDetectorKey]) {
[outOfOrderTimestampTestDict[kLiveStreamTestsDictExpectationKey] fulfill];
} else if (faceDetector == liveStreamSucceedsTestDict[kLiveStreamTestsDictFaceDetectorKey]) {
[liveStreamSucceedsTestDict[kLiveStreamTestsDictExpectationKey] fulfill];
}
}
+ (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension {
NSString *filePath =
[[NSBundle bundleForClass:[MPPFaceDetectorTests class]] pathForResource:fileName
ofType:extension];
return filePath;
}
- (void)assertKeypoints:(NSArray<MPPNormalizedKeypoint *> *)keypoints
areEqualToExpectedKeypoints:(NSArray<NSArray *> *)expectedKeypoint {
XCTAssertEqual(keypoints.count, expectedKeypoint.count);
for (int i = 0; i < keypoints.count; ++i) {
XCTAssertEqualWithAccuracy(keypoints[i].location.x, [expectedKeypoint[i][0] floatValue],
kKeypointErrorThreshold, @"index i = %d", i);
XCTAssertEqualWithAccuracy(keypoints[i].location.y, [expectedKeypoint[i][1] floatValue],
kKeypointErrorThreshold, @"index i = %d", i);
}
}
- (void)assertDetections:(NSArray<MPPDetection *> *)detections
containExpectedKeypoints:(NSArray<NSArray *> *)expectedKeypoints {
XCTAssertEqual(detections.count, 1);
MPPDetection *detection = detections[0];
XCTAssertNotNil(detection);
[self assertKeypoints:detections[0].keypoints areEqualToExpectedKeypoints:expectedKeypoints];
}
- (void)assertFaceDetectorResult:(MPPFaceDetectorResult *)faceDetectorResult
containsExpectedKeypoints:(NSArray<NSArray *> *)expectedKeypoints {
[self assertDetections:faceDetectorResult.detections containExpectedKeypoints:expectedKeypoints];
}
#pragma mark Face Detector Initializers
- (MPPFaceDetectorOptions *)faceDetectorOptionsWithModelName:(NSString *)modelName {
NSString *modelPath = [MPPFaceDetectorTests filePathWithName:modelName extension:@"tflite"];
MPPFaceDetectorOptions *faceDetectorOptions = [[MPPFaceDetectorOptions alloc] init];
faceDetectorOptions.baseOptions.modelAssetPath = modelPath;
return faceDetectorOptions;
}
- (void)assertCreateFaceDetectorWithOptions:(MPPFaceDetectorOptions *)faceDetectorOptions
failsWithExpectedError:(NSError *)expectedError {
NSError *error = nil;
MPPFaceDetector *faceDetector = [[MPPFaceDetector alloc] initWithOptions:faceDetectorOptions
error:&error];
XCTAssertNil(faceDetector);
AssertEqualErrors(error, expectedError);
}
- (MPPFaceDetector *)faceDetectorWithOptionsSucceeds:(MPPFaceDetectorOptions *)faceDetectorOptions {
MPPFaceDetector *faceDetector = [[MPPFaceDetector alloc] initWithOptions:faceDetectorOptions
error:nil];
XCTAssertNotNil(faceDetector);
return faceDetector;
}
#pragma mark Assert Detection Results
- (MPPImage *)imageWithFileInfo:(NSDictionary *)fileInfo {
UIImageOrientation orientation = (UIImageOrientation)[fileInfo[@"orientation"] intValue];
MPPImage *image = [MPPImage imageFromBundleWithClass:[MPPFaceDetectorTests class]
fileName:fileInfo[@"name"]
ofType:fileInfo[@"type"]
orientation:orientation];
XCTAssertNotNil(image);
return image;
}
- (void)assertResultsOfDetectInImage:(MPPImage *)mppImage
usingFaceDetector:(MPPFaceDetector *)faceDetector
containsExpectedKeypoints:(NSArray<NSArray *> *)expectedKeypoints {
NSError *error;
MPPFaceDetectorResult *faceDetectorResult = [faceDetector detectInImage:mppImage error:&error];
XCTAssertNil(error);
XCTAssertNotNil(faceDetectorResult);
[self assertFaceDetectorResult:faceDetectorResult containsExpectedKeypoints:expectedKeypoints];
}
- (void)assertResultsOfDetectInImageWithFileInfo:(NSDictionary *)fileInfo
usingFaceDetector:(MPPFaceDetector *)faceDetector
containsExpectedKeypoints:(NSArray<NSArray *> *)expectedKeypoints {
MPPImage *mppImage = [self imageWithFileInfo:fileInfo];
[self assertResultsOfDetectInImage:mppImage
usingFaceDetector:faceDetector
containsExpectedKeypoints:expectedKeypoints];
}
@end


@@ -0,0 +1,62 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
objc_library(
name = "MPPFaceDetectorResult",
srcs = ["sources/MPPFaceDetectorResult.m"],
hdrs = ["sources/MPPFaceDetectorResult.h"],
deps = [
"//mediapipe/tasks/ios/components/containers:MPPDetection",
"//mediapipe/tasks/ios/core:MPPTaskResult",
],
)
objc_library(
name = "MPPFaceDetectorOptions",
srcs = ["sources/MPPFaceDetectorOptions.m"],
hdrs = ["sources/MPPFaceDetectorOptions.h"],
deps = [
":MPPFaceDetectorResult",
"//mediapipe/tasks/ios/core:MPPTaskOptions",
"//mediapipe/tasks/ios/vision/core:MPPRunningMode",
],
)
objc_library(
name = "MPPFaceDetector",
srcs = ["sources/MPPFaceDetector.mm"],
hdrs = ["sources/MPPFaceDetector.h"],
copts = [
"-ObjC++",
"-std=c++17",
"-x objective-c++",
],
deps = [
":MPPFaceDetectorOptions",
":MPPFaceDetectorResult",
"//mediapipe/tasks/cc/vision/face_detector:face_detector_graph",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
"//mediapipe/tasks/ios/core:MPPTaskInfo",
"//mediapipe/tasks/ios/vision/core:MPPImage",
"//mediapipe/tasks/ios/vision/core:MPPVisionPacketCreator",
"//mediapipe/tasks/ios/vision/core:MPPVisionTaskRunner",
"//mediapipe/tasks/ios/vision/face_detector/utils:MPPFaceDetectorOptionsHelpers",
"//mediapipe/tasks/ios/vision/face_detector/utils:MPPFaceDetectorResultHelpers",
],
)


@@ -0,0 +1,190 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h"
#import "mediapipe/tasks/ios/vision/face_detector/sources/MPPFaceDetectorOptions.h"
#import "mediapipe/tasks/ios/vision/face_detector/sources/MPPFaceDetectorResult.h"
NS_ASSUME_NONNULL_BEGIN
/**
* @brief Class that performs face detection on images.
*
* The API expects a TFLite model with mandatory TFLite Model Metadata.
*
* The API supports models with one image input tensor and one or more output tensors. To be more
* specific, here are the requirements:
*
* Input tensor
* (kTfLiteUInt8/kTfLiteFloat32)
* - image input of size `[batch x height x width x channels]`.
* - batch inference is not supported (`batch` is required to be 1).
* - only RGB inputs are supported (`channels` is required to be 3).
* - if type is kTfLiteFloat32, NormalizationOptions are required to be attached to the metadata
* for input normalization.
*
* Output tensors must be the 4 outputs of a `DetectionPostProcess` op, i.e.:
* (kTfLiteFloat32)
* - locations tensor of size `[num_results x 4]`, the inner array representing bounding boxes
* in the form [top, left, right, bottom].
* - BoundingBoxProperties are required to be attached to the metadata and must specify
* type=BOUNDARIES and coordinate_type=RATIO.
* (kTfLiteFloat32)
* - classes tensor of size `[num_results]`, each value representing the integer index of a
* class.
* - scores tensor of size `[num_results]`, each value representing the score of the detected
* face.
* - optional score calibration can be attached using ScoreCalibrationOptions and an
* AssociatedFile with type TENSOR_AXIS_SCORE_CALIBRATION. See metadata_schema.fbs [1] for more
* details.
* (kTfLiteFloat32)
* - integer num_results as a tensor of size `[1]`
*/
NS_SWIFT_NAME(FaceDetector)
@interface MPPFaceDetector : NSObject
/**
* Creates a new instance of `MPPFaceDetector` from an absolute path to a TensorFlow Lite model
* file stored locally on the device and the default `MPPFaceDetectorOptions`.
*
* @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
* @param error An optional error parameter populated when there is an error in initializing the
* face detector.
*
* @return A new instance of `MPPFaceDetector` with the given model path. `nil` if there is an
* error in initializing the face detector.
*/
- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error;
/**
* Creates a new instance of `MPPFaceDetector` from the given `MPPFaceDetectorOptions`.
*
* @param options The options of type `MPPFaceDetectorOptions` to use for configuring the
* `MPPFaceDetector`.
* @param error An optional error parameter populated when there is an error in initializing the
* face detector.
*
* @return A new instance of `MPPFaceDetector` with the given options. `nil` if there is an error
* in initializing the face detector.
*/
- (nullable instancetype)initWithOptions:(MPPFaceDetectorOptions *)options
error:(NSError **)error NS_DESIGNATED_INITIALIZER;
/**
* Performs face detection on the provided MPPImage using the whole image as region of
* interest. Rotation will be applied according to the `orientation` property of the provided
* `MPPImage`. Only use this method when the `MPPFaceDetector` is created with
* `MPPRunningModeImage`.
*
* This method supports performing face detection on RGBA images. If your `MPPImage` has a source type of
* `MPPImageSourceTypePixelBuffer` or `MPPImageSourceTypeSampleBuffer`, the underlying pixel buffer
* must have one of the following pixel format types:
* 1. kCVPixelFormatType_32BGRA
* 2. kCVPixelFormatType_32RGBA
*
* If your `MPPImage` has a source type of `MPPImageSourceTypeImage` ensure that the color space is
* RGB with an Alpha channel.
*
* @param image The `MPPImage` on which face detection is to be performed.
* @param error An optional error parameter populated when there is an error in performing face
* detection on the input image.
*
* @return An `MPPFaceDetectorResult` that contains a list of detections, each of which
* has a bounding box that is expressed in the unrotated input frame of reference coordinates
* system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the underlying
* image data.
*/
- (nullable MPPFaceDetectorResult *)detectInImage:(MPPImage *)image
error:(NSError **)error NS_SWIFT_NAME(detect(image:));
/**
* Performs face detection on the provided video frame of type `MPPImage` using the whole
* image as region of interest. Rotation will be applied according to the `orientation` property of
* the provided `MPPImage`. Only use this method when the `MPPFaceDetector` is created with
* `MPPRunningModeVideo`.
*
* This method supports performing face detection on RGBA images. If your `MPPImage` has a source type of
* `MPPImageSourceTypePixelBuffer` or `MPPImageSourceTypeSampleBuffer`, the underlying pixel buffer
* must have one of the following pixel format types:
* 1. kCVPixelFormatType_32BGRA
* 2. kCVPixelFormatType_32RGBA
*
* If your `MPPImage` has a source type of `MPPImageSourceTypeImage` ensure that the color space is
* RGB with an Alpha channel.
*
* @param image The `MPPImage` on which face detection is to be performed.
* @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input
* timestamps must be monotonically increasing.
* @param error An optional error parameter populated when there is an error in performing face
* detection on the input image.
*
* @return An `MPPFaceDetectorResult` that contains a list of detections, each of which
* has a bounding box that is expressed in the unrotated input frame of reference coordinates
* system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the underlying
* image data.
*/
- (nullable MPPFaceDetectorResult *)detectInVideoFrame:(MPPImage *)image
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error
NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:));
/**
* Sends live stream image data of type `MPPImage` to perform face detection using the whole
* image as region of interest. Rotation will be applied according to the `orientation` property of
* the provided `MPPImage`. Only use this method when the `MPPFaceDetector` is created with
* `MPPRunningModeLiveStream`.
*
* The object which needs to be continuously notified of the available results of face
* detection must conform to the `MPPFaceDetectorLiveStreamDelegate` protocol and implement the
* `faceDetector:didFinishDetectionWithResult:timestampInMilliseconds:error:` delegate method.
*
* It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent
* to the face detector. The input timestamps must be monotonically increasing.
*
* This method supports performing face detection on RGBA images. If your `MPPImage` has a source type of
* `MPPImageSourceTypePixelBuffer` or `MPPImageSourceTypeSampleBuffer`, the underlying pixel buffer
* must have one of the following pixel format types:
* 1. kCVPixelFormatType_32BGRA
* 2. kCVPixelFormatType_32RGBA
*
* If the input `MPPImage` has a source type of `MPPImageSourceTypeImage` ensure that the color
* space is RGB with an Alpha channel.
*
* If this method is used for detecting faces in live camera frames using `AVFoundation`, ensure that you
* request `AVCaptureVideoDataOutput` to output frames in `kCMPixelFormat_32RGBA` using its
* `videoSettings` property.
*
* @param image Live stream image data of type `MPPImage` on which face detection is to be
* performed.
* @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input
* image is sent to the face detector. The input timestamps must be monotonically increasing.
* @param error An optional error parameter populated when there is an error in performing face
* detection on the input live stream image data.
*
* @return `YES` if the image was sent to the task successfully, otherwise `NO`.
*/
- (BOOL)detectAsyncInImage:(MPPImage *)image
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error
NS_SWIFT_NAME(detectAsync(image:timestampInMilliseconds:));
- (instancetype)init NS_UNAVAILABLE;
+ (instancetype)new NS_UNAVAILABLE;
@end
NS_ASSUME_NONNULL_END
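
For orientation, here is a minimal image-mode usage sketch in Objective-C based on the interface above. The model file name, the bundled image, and the `MPPImage` UIImage initializer are illustrative assumptions.

#import <UIKit/UIKit.h>

#import "mediapipe/tasks/ios/vision/face_detector/sources/MPPFaceDetector.h"

// Creates a face detector from a bundled .tflite model and runs image-mode detection.
static void RunFaceDetectorImageExample(void) {
  // Assumed model file bundled with the app; any BlazeFace-style face detection model works.
  NSString *modelPath = [[NSBundle mainBundle] pathForResource:@"face_detection_short_range"
                                                        ofType:@"tflite"];
  NSError *error = nil;
  MPPFaceDetector *faceDetector = [[MPPFaceDetector alloc] initWithModelPath:modelPath
                                                                       error:&error];
  if (!faceDetector) {
    NSLog(@"Failed to create face detector: %@", error);
    return;
  }

  // Wrap a UIImage in an MPPImage (UIImage initializer assumed from MPPImage.h).
  UIImage *uiImage = [UIImage imageNamed:@"portrait"];
  MPPImage *image = [[MPPImage alloc] initWithUIImage:uiImage error:&error];
  if (!image) {
    NSLog(@"Failed to create MPPImage: %@", error);
    return;
  }

  // Image mode: results are returned synchronously.
  MPPFaceDetectorResult *result = [faceDetector detectInImage:image error:&error];
  NSLog(@"Detected %lu face(s)", (unsigned long)result.detections.count);
  for (MPPDetection *detection in result.detections) {
    for (MPPNormalizedKeypoint *keypoint in detection.keypoints) {
      NSLog(@"Keypoint at (%.3f, %.3f)", keypoint.location.x, keypoint.location.y);
    }
  }
}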


@@ -0,0 +1,259 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/vision/face_detector/sources/MPPFaceDetector.h"
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h"
#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h"
#import "mediapipe/tasks/ios/vision/face_detector/utils/sources/MPPFaceDetectorOptions+Helpers.h"
#import "mediapipe/tasks/ios/vision/face_detector/utils/sources/MPPFaceDetectorResult+Helpers.h"
using ::mediapipe::NormalizedRect;
using ::mediapipe::Packet;
using ::mediapipe::Timestamp;
using ::mediapipe::tasks::core::PacketMap;
using ::mediapipe::tasks::core::PacketsCallback;
static constexpr int kMicrosecondsPerMillisecond = 1000;
// Constants for the underlying MP Tasks Graph. See
// https://github.com/google/mediapipe/tree/master/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc
static NSString *const kDetectionsStreamName = @"detections_out";
static NSString *const kDetectionsTag = @"DETECTIONS";
static NSString *const kImageInStreamName = @"image_in";
static NSString *const kImageOutStreamName = @"image_out";
static NSString *const kImageTag = @"IMAGE";
static NSString *const kNormRectStreamName = @"norm_rect_in";
static NSString *const kNormRectTag = @"NORM_RECT";
static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.face_detector.FaceDetectorGraph";
static NSString *const kTaskName = @"faceDetector";
#define InputPacketMap(imagePacket, normalizedRectPacket) \
{ \
{kImageInStreamName.cppString, imagePacket}, { \
kNormRectStreamName.cppString, normalizedRectPacket \
} \
}
@interface MPPFaceDetector () {
/** iOS Vision Task Runner */
MPPVisionTaskRunner *_visionTaskRunner;
dispatch_queue_t _callbackQueue;
}
@property(nonatomic, weak) id<MPPFaceDetectorLiveStreamDelegate> faceDetectorLiveStreamDelegate;
- (void)processLiveStreamResult:(absl::StatusOr<PacketMap>)liveStreamResult;
@end
@implementation MPPFaceDetector
- (instancetype)initWithOptions:(MPPFaceDetectorOptions *)options error:(NSError **)error {
self = [super init];
if (self) {
MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc]
initWithTaskGraphName:kTaskGraphName
inputStreams:@[
[NSString stringWithFormat:@"%@:%@", kImageTag, kImageInStreamName],
[NSString stringWithFormat:@"%@:%@", kNormRectTag, kNormRectStreamName]
]
outputStreams:@[
[NSString stringWithFormat:@"%@:%@", kDetectionsTag, kDetectionsStreamName],
[NSString stringWithFormat:@"%@:%@", kImageTag, kImageOutStreamName]
]
taskOptions:options
enableFlowLimiting:options.runningMode == MPPRunningModeLiveStream
error:error];
if (!taskInfo) {
return nil;
}
PacketsCallback packetsCallback = nullptr;
if (options.faceDetectorLiveStreamDelegate) {
_faceDetectorLiveStreamDelegate = options.faceDetectorLiveStreamDelegate;
// Create a private serial dispatch queue in which the delegate method will be called
// asynchronously. This is to ensure that if the client performs a long running operation in
// the delegate method, the queue on which the C++ callback is invoked is not blocked and is
// freed up to continue with its operations.
_callbackQueue = dispatch_queue_create(
[MPPVisionTaskRunner uniqueDispatchQueueNameWithSuffix:kTaskName], NULL);
// Capture `self` weakly to avoid keeping it in memory and causing a retain cycle
// after `self` is set to `nil`.
MPPFaceDetector *__weak weakSelf = self;
packetsCallback = [=](absl::StatusOr<PacketMap> liveStreamResult) {
[weakSelf processLiveStreamResult:liveStreamResult];
};
}
_visionTaskRunner =
[[MPPVisionTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig]
runningMode:options.runningMode
packetsCallback:std::move(packetsCallback)
error:error];
if (!_visionTaskRunner) {
return nil;
}
}
return self;
}
- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error {
MPPFaceDetectorOptions *options = [[MPPFaceDetectorOptions alloc] init];
options.baseOptions.modelAssetPath = modelPath;
return [self initWithOptions:options error:error];
}
- (std::optional<PacketMap>)inputPacketMapWithMPPImage:(MPPImage *)image
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error {
std::optional<NormalizedRect> rect =
[_visionTaskRunner normalizedRectFromRegionOfInterest:CGRectZero
imageSize:CGSizeMake(image.width, image.height)
imageOrientation:image.orientation
ROIAllowed:NO
error:error];
if (!rect.has_value()) {
return std::nullopt;
}
Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image
timestampInMilliseconds:timestampInMilliseconds
error:error];
if (imagePacket.IsEmpty()) {
return std::nullopt;
}
Packet normalizedRectPacket =
[MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
timestampInMilliseconds:timestampInMilliseconds];
PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket);
return inputPacketMap;
}
- (nullable MPPFaceDetectorResult *)detectInImage:(MPPImage *)image error:(NSError **)error {
std::optional<NormalizedRect> rect =
[_visionTaskRunner normalizedRectFromRegionOfInterest:CGRectZero
imageSize:CGSizeMake(image.width, image.height)
imageOrientation:image.orientation
ROIAllowed:NO
error:error];
if (!rect.has_value()) {
return nil;
}
Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image error:error];
if (imagePacket.IsEmpty()) {
return nil;
}
Packet normalizedRectPacket =
[MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()];
PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket);
std::optional<PacketMap> outputPacketMap = [_visionTaskRunner processImagePacketMap:inputPacketMap
error:error];
if (!outputPacketMap.has_value()) {
return nil;
}
return [MPPFaceDetectorResult
faceDetectorResultWithDetectionsPacket:outputPacketMap
.value()[kDetectionsStreamName.cppString]];
}
- (nullable MPPFaceDetectorResult *)detectInVideoFrame:(MPPImage *)image
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error {
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
timestampInMilliseconds:timestampInMilliseconds
error:error];
if (!inputPacketMap.has_value()) {
return nil;
}
std::optional<PacketMap> outputPacketMap =
[_visionTaskRunner processVideoFramePacketMap:inputPacketMap.value() error:error];
if (!outputPacketMap.has_value()) {
return nil;
}
return [MPPFaceDetectorResult
faceDetectorResultWithDetectionsPacket:outputPacketMap
.value()[kDetectionsStreamName.cppString]];
}
- (BOOL)detectAsyncInImage:(MPPImage *)image
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error {
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
timestampInMilliseconds:timestampInMilliseconds
error:error];
if (!inputPacketMap.has_value()) {
return NO;
}
return [_visionTaskRunner processLiveStreamPacketMap:inputPacketMap.value() error:error];
}
- (void)processLiveStreamResult:(absl::StatusOr<PacketMap>)liveStreamResult {
if (![self.faceDetectorLiveStreamDelegate
respondsToSelector:@selector(faceDetector:
didFinishDetectionWithResult:timestampInMilliseconds:error:)]) {
return;
}
NSError *callbackError = nil;
if (![MPPCommonUtils checkCppError:liveStreamResult.status() toError:&callbackError]) {
dispatch_async(_callbackQueue, ^{
[self.faceDetectorLiveStreamDelegate faceDetector:self
didFinishDetectionWithResult:nil
timestampInMilliseconds:Timestamp::Unset().Value()
error:callbackError];
});
return;
}
PacketMap &outputPacketMap = liveStreamResult.value();
if (outputPacketMap[kImageOutStreamName.cppString].IsEmpty()) {
return;
}
MPPFaceDetectorResult *result = [MPPFaceDetectorResult
faceDetectorResultWithDetectionsPacket:liveStreamResult
.value()[kDetectionsStreamName.cppString]];
NSInteger timeStampInMilliseconds =
outputPacketMap[kImageOutStreamName.cppString].Timestamp().Value() /
kMicrosecondsPerMillisecond;
dispatch_async(_callbackQueue, ^{
[self.faceDetectorLiveStreamDelegate faceDetector:self
didFinishDetectionWithResult:result
timestampInMilliseconds:timeStampInMilliseconds
error:callbackError];
});
}
@end
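
As a companion to the live-stream path implemented above, the following is a hedged Objective-C sketch of a client-side delegate. The `FaceStreamHandler` class, the model path handling, and the camera integration are assumptions for illustration; the delegate method signature follows the protocol declared in MPPFaceDetectorOptions.h.

#import "mediapipe/tasks/ios/vision/face_detector/sources/MPPFaceDetector.h"
#import "mediapipe/tasks/ios/vision/face_detector/sources/MPPFaceDetectorOptions.h"

// Hypothetical client object that owns the detector and receives live-stream results.
@interface FaceStreamHandler : NSObject <MPPFaceDetectorLiveStreamDelegate>
@property(nonatomic, readonly) MPPFaceDetector *faceDetector;
- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error;
- (void)processFrame:(MPPImage *)image timestampInMilliseconds:(NSInteger)timestamp;
@end

@implementation FaceStreamHandler

- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error {
  self = [super init];
  if (self) {
    MPPFaceDetectorOptions *options = [[MPPFaceDetectorOptions alloc] init];
    options.baseOptions.modelAssetPath = modelPath;
    options.runningMode = MPPRunningModeLiveStream;
    // The detector holds the delegate weakly, so this handler must be retained by its owner.
    options.faceDetectorLiveStreamDelegate = self;
    _faceDetector = [[MPPFaceDetector alloc] initWithOptions:options error:error];
    if (!_faceDetector) {
      return nil;
    }
  }
  return self;
}

// Feed camera frames here; timestamps must be monotonically increasing.
- (void)processFrame:(MPPImage *)image timestampInMilliseconds:(NSInteger)timestamp {
  NSError *error = nil;
  if (![self.faceDetector detectAsyncInImage:image
                     timestampInMilliseconds:timestamp
                                       error:&error]) {
    NSLog(@"detectAsync failed: %@", error);
  }
}

// Invoked on the detector's private serial queue whenever a result (or error) is available.
- (void)faceDetector:(MPPFaceDetector *)faceDetector
    didFinishDetectionWithResult:(nullable MPPFaceDetectorResult *)result
         timestampInMilliseconds:(NSInteger)timestampInMilliseconds
                           error:(nullable NSError *)error {
  if (error) {
    NSLog(@"Detection error at %ld ms: %@", (long)timestampInMilliseconds, error);
    return;
  }
  NSLog(@"t=%ld ms: %lu face(s)", (long)timestampInMilliseconds,
        (unsigned long)result.detections.count);
}

@end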


@@ -0,0 +1,101 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
#import "mediapipe/tasks/ios/vision/core/sources/MPPRunningMode.h"
#import "mediapipe/tasks/ios/vision/face_detector/sources/MPPFaceDetectorResult.h"
NS_ASSUME_NONNULL_BEGIN
@class MPPFaceDetector;
/**
* This protocol defines an interface for the delegates of `MPPFaceDetector` to receive the
* results of performing asynchronous face detection on images (i.e., when `runningMode` =
* `MPPRunningModeLiveStream`).
*
* The delegate of `MPPFaceDetector` must adopt `MPPFaceDetectorLiveStreamDelegate` protocol.
* The methods in this protocol are optional.
*/
NS_SWIFT_NAME(FaceDetectorLiveStreamDelegate)
@protocol MPPFaceDetectorLiveStreamDelegate <NSObject>
@optional
/**
* This method notifies a delegate that the results of asynchronous face detection of
* an image submitted to the `MPPFaceDetector` are available.
*
* This method is called on a private serial dispatch queue created by the `MPPFaceDetector`
* for performing the asynchronous delegate calls.
*
* @param faceDetector The face detector which performed the face detection.
* This is useful to test equality when there are multiple instances of `MPPFaceDetector`.
* @param result The `MPPFaceDetectorResult` object that contains a list of detections, each
* detection has a bounding box that is expressed in the unrotated input frame of reference
* coordinates system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the
* underlying image data.
* @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input
* image was sent to the face detector.
* @param error An optional error parameter populated when there is an error in performing face
* detection on the input live stream image data.
*/
- (void)faceDetector:(MPPFaceDetector *)faceDetector
didFinishDetectionWithResult:(nullable MPPFaceDetectorResult *)result
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(nullable NSError *)error
NS_SWIFT_NAME(faceDetector(_:didFinishDetection:timestampInMilliseconds:error:));
@end
/** Options for setting up a `MPPFaceDetector`. */
NS_SWIFT_NAME(FaceDetectorOptions)
@interface MPPFaceDetectorOptions : MPPTaskOptions <NSCopying>
/**
* Running mode of the face detector task. Defaults to `MPPRunningModeImage`.
* `MPPFaceDetector` can be created with one of the following running modes:
* 1. `MPPRunningModeImage`: The mode for performing face detection on single image inputs.
* 2. `MPPRunningModeVideo`: The mode for performing face detection on the decoded frames of a
* video.
* 3. `MPPRunningModeLiveStream`: The mode for performing face detection on a live stream of
* input data, such as from the camera.
*/
@property(nonatomic) MPPRunningMode runningMode;
/**
* An object that conforms to the `MPPFaceDetectorLiveStreamDelegate` protocol. This object must
* implement `faceDetector:didFinishDetectionWithResult:timestampInMilliseconds:error:` to receive
* the results of performing asynchronous face detection on images (i.e., when `runningMode` =
* `MPPRunningModeLiveStream`).
*/
@property(nonatomic, weak, nullable) id<MPPFaceDetectorLiveStreamDelegate>
faceDetectorLiveStreamDelegate;
/**
* The minimum confidence score for the face detection to be considered successful. Defaults to
* 0.5.
*/
@property(nonatomic) float minDetectionConfidence;
/**
* The minimum non-maximum-suppression threshold for face detection to be considered overlapped.
* Defaults to 0.3.
*/
@property(nonatomic) float minSuppressionThreshold;
@end
NS_ASSUME_NONNULL_END
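For reference, a minimal Swift sketch of wiring up these options for the live stream mode. The type and selector names come from the `NS_SWIFT_NAME` annotations above (`FaceDetectorOptions`, `FaceDetectorLiveStreamDelegate`, `FaceDetectorResult`, `faceDetector(_:didFinishDetection:timestampInMilliseconds:error:)`); the `MediaPipeTasksVision` module name, the `FaceDetector` Swift name of `MPPFaceDetector`, and the `.liveStream` spelling of the running mode are assumptions not shown in this diff.

```swift
import MediaPipeTasksVision

// Receives results on the detector's private serial queue (see the protocol docs above).
final class FaceResultHandler: NSObject, FaceDetectorLiveStreamDelegate {
  func faceDetector(_ faceDetector: FaceDetector,
                    didFinishDetection result: FaceDetectorResult?,
                    timestampInMilliseconds: Int,
                    error: Error?) {
    guard error == nil, let result = result else { return }
    print("Faces at \(timestampInMilliseconds) ms: \(result.detections.count)")
  }
}

let handler = FaceResultHandler()
let options = FaceDetectorOptions()
options.runningMode = .liveStream          // assumed Swift spelling of MPPRunningModeLiveStream
options.minDetectionConfidence = 0.5       // default per the header above
options.minSuppressionThreshold = 0.3      // default per the header above
options.faceDetectorLiveStreamDelegate = handler  // held weakly; keep `handler` alive elsewhere
```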

View File

@ -0,0 +1,38 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/vision/face_detector/sources/MPPFaceDetectorOptions.h"
@implementation MPPFaceDetectorOptions
- (instancetype)init {
self = [super init];
if (self) {
_minDetectionConfidence = 0.5;
_minSuppressionThreshold = 0.3;
}
return self;
}
- (id)copyWithZone:(NSZone *)zone {
MPPFaceDetectorOptions *faceDetectorOptions = [super copyWithZone:zone];
faceDetectorOptions.minDetectionConfidence = self.minDetectionConfidence;
faceDetectorOptions.minSuppressionThreshold = self.minSuppressionThreshold;
faceDetectorOptions.faceDetectorLiveStreamDelegate = self.faceDetectorLiveStreamDelegate;
return faceDetectorOptions;
}
@end

View File

@ -0,0 +1,49 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/components/containers/sources/MPPDetection.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h"
NS_ASSUME_NONNULL_BEGIN
/** Represents the detection results generated by `MPPFaceDetector`. */
NS_SWIFT_NAME(FaceDetectorResult)
@interface MPPFaceDetectorResult : MPPTaskResult
/**
* The array of `MPPDetection` objects each of which has a bounding box that is expressed in the
* unrotated input frame of reference coordinates system, i.e. in `[0,image_width) x
* [0,image_height)`, which are the dimensions of the underlying image data.
*/
@property(nonatomic, readonly) NSArray<MPPDetection *> *detections;
/**
* Initializes a new `MPPFaceDetectorResult` with the given array of detections and timestamp (in
* milliseconds).
*
* @param detections An array of `MPPDetection` objects each of which has a bounding box that is
* expressed in the unrotated input frame of reference coordinates system, i.e. in `[0,image_width)
* x [0,image_height)`, which are the dimensions of the underlying image data.
* @param timestampInMilliseconds The timestamp (in milliseconds) for this result.
*
* @return An instance of `MPPFaceDetectorResult` initialized with the given array of detections
* and timestamp (in milliseconds).
*/
- (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections
timestampInMilliseconds:(NSInteger)timestampInMilliseconds;
@end
NS_ASSUME_NONNULL_END
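A short Swift sketch of consuming a `FaceDetectorResult` (its `NS_SWIFT_NAME` above), for example inside the live stream delegate callback. The `detections` property is declared in this header and the millisecond timestamp is inherited from `MPPTaskResult`; the `boundingBox` member of `Detection` comes from `MPPDetection.h`, which is not part of this diff, so treat it as an assumption.

```swift
import MediaPipeTasksVision

/// Logs every detected face. Bounding boxes are in the unrotated input frame of
/// reference, i.e. within [0, imageWidth) x [0, imageHeight).
func logFaces(in result: FaceDetectorResult) {
  for (index, detection) in result.detections.enumerated() {
    let box = detection.boundingBox  // CGRect, assumed from MPPDetection.h
    print("Face \(index): origin=(\(box.origin.x), \(box.origin.y)), " +
          "size=(\(box.size.width), \(box.size.height))")
  }
  print("Result timestamp: \(result.timestampInMilliseconds) ms")  // inherited from MPPTaskResult
}
```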

View File

@ -0,0 +1,28 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/vision/face_detector/sources/MPPFaceDetectorResult.h"
@implementation MPPFaceDetectorResult
- (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
if (self) {
_detections = [detections copy];
}
return self;
}
@end

View File

@ -0,0 +1,42 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
objc_library(
name = "MPPFaceDetectorOptionsHelpers",
srcs = ["sources/MPPFaceDetectorOptions+Helpers.mm"],
hdrs = ["sources/MPPFaceDetectorOptions+Helpers.h"],
deps = [
"//mediapipe/framework:calculator_options_cc_proto",
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
"//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol",
"//mediapipe/tasks/ios/core/utils:MPPBaseOptionsHelpers",
"//mediapipe/tasks/ios/vision/face_detector:MPPFaceDetectorOptions",
],
)
objc_library(
name = "MPPFaceDetectorResultHelpers",
srcs = ["sources/MPPFaceDetectorResult+Helpers.mm"],
hdrs = ["sources/MPPFaceDetectorResult+Helpers.h"],
deps = [
"//mediapipe/framework:packet",
"//mediapipe/tasks/ios/components/containers/utils:MPPDetectionHelpers",
"//mediapipe/tasks/ios/vision/face_detector:MPPFaceDetectorResult",
],
)

View File

@ -0,0 +1,36 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef __cplusplus
#error "This file requires Objective-C++."
#endif // __cplusplus
#include "mediapipe/framework/calculator_options.pb.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h"
#import "mediapipe/tasks/ios/vision/face_detector/sources/MPPFaceDetectorOptions.h"
NS_ASSUME_NONNULL_BEGIN
@interface MPPFaceDetectorOptions (Helpers) <MPPTaskOptionsProtocol>
/**
* Populates the provided `CalculatorOptions` proto container with the current settings.
*
* @param optionsProto The `CalculatorOptions` proto object to copy the settings to.
*/
- (void)copyToProto:(::mediapipe::CalculatorOptions *)optionsProto;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,39 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/vision/face_detector/utils/sources/MPPFaceDetectorOptions+Helpers.h"
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
using CalculatorOptionsProto = ::mediapipe::CalculatorOptions;
using FaceDetectorGraphOptionsProto =
::mediapipe::tasks::vision::face_detector::proto::FaceDetectorGraphOptions;
@implementation MPPFaceDetectorOptions (Helpers)
- (void)copyToProto:(CalculatorOptionsProto *)optionsProto {
FaceDetectorGraphOptionsProto *graphOptions =
optionsProto->MutableExtension(FaceDetectorGraphOptionsProto::ext);
graphOptions->Clear();
[self.baseOptions copyToProto:graphOptions->mutable_base_options()];
graphOptions->set_min_detection_confidence(self.minDetectionConfidence);
graphOptions->set_min_suppression_threshold(self.minSuppressionThreshold);
}
@end

View File

@ -0,0 +1,39 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef __cplusplus
#error "This file requires Objective-C++."
#endif // __cplusplus
#include "mediapipe/framework/packet.h"
#import "mediapipe/tasks/ios/vision/face_detector/sources/MPPFaceDetectorResult.h"
NS_ASSUME_NONNULL_BEGIN
@interface MPPFaceDetectorResult (Helpers)
/**
* Creates an `MPPFaceDetectorResult` from a MediaPipe packet containing a
* `std::vector<DetectionProto>`.
*
* @param packet A MediaPipe packet wrapping a `std::vector<DetectionProto>`.
*
* @return An `MPPFaceDetectorResult` object that contains a list of detections.
*/
+ (nullable MPPFaceDetectorResult *)faceDetectorResultWithDetectionsPacket:
(const ::mediapipe::Packet &)packet;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,45 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/vision/face_detector/utils/sources/MPPFaceDetectorResult+Helpers.h"
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.h"
using DetectionProto = ::mediapipe::Detection;
using ::mediapipe::Packet;
static constexpr int kMicrosecondsPerMillisecond = 1000;
@implementation MPPFaceDetectorResult (Helpers)
+ (nullable MPPFaceDetectorResult *)faceDetectorResultWithDetectionsPacket:(const Packet &)packet {
NSMutableArray<MPPDetection *> *detections;
if (packet.ValidateAsType<std::vector<DetectionProto>>().ok()) {
const std::vector<DetectionProto> &detectionProtos = packet.Get<std::vector<DetectionProto>>();
detections = [NSMutableArray arrayWithCapacity:(NSUInteger)detectionProtos.size()];
for (const auto &detectionProto : detectionProtos) {
[detections addObject:[MPPDetection detectionWithProto:detectionProto]];
}
} else {
detections = [NSMutableArray arrayWithCapacity:0];
}
return
[[MPPFaceDetectorResult alloc] initWithDetections:detections
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
kMicrosecondsPerMillisecond)];
}
@end

View File

@ -87,7 +87,7 @@ NS_SWIFT_NAME(GestureRecognizerOptions)
gestureRecognizerLiveStreamDelegate; gestureRecognizerLiveStreamDelegate;
/** Sets the maximum number of hands that can be detected by the GestureRecognizer. */ /** Sets the maximum number of hands that can be detected by the GestureRecognizer. */
@property(nonatomic) NSInteger numberOfHands NS_SWIFT_NAME(numHands); @property(nonatomic) NSInteger numHands;
/** Sets minimum confidence score for the hand detection to be considered successful */ /** Sets minimum confidence score for the hand detection to be considered successful */
@property(nonatomic) float minHandDetectionConfidence; @property(nonatomic) float minHandDetectionConfidence;
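Since the renamed property now matches the name that `NS_SWIFT_NAME(numHands)` already exposed, Swift call sites are unchanged by this rename. A minimal sketch (the `MediaPipeTasksVision` module name is an assumption):

```swift
import MediaPipeTasksVision

let options = GestureRecognizerOptions()
options.numHands = 2                      // formerly `numberOfHands` in Objective-C, always `numHands` in Swift
options.minHandDetectionConfidence = 0.5  // matches the default set in the initializer diff below
```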

View File

@ -19,7 +19,7 @@
- (instancetype)init { - (instancetype)init {
self = [super init]; self = [super init];
if (self) { if (self) {
_numberOfHands = 1; _numHands = 1;
_minHandDetectionConfidence = 0.5f; _minHandDetectionConfidence = 0.5f;
_minHandPresenceConfidence = 0.5f; _minHandPresenceConfidence = 0.5f;
_minTrackingConfidence = 0.5f; _minTrackingConfidence = 0.5f;
@ -33,7 +33,7 @@
gestureRecognizerOptions.runningMode = self.runningMode; gestureRecognizerOptions.runningMode = self.runningMode;
gestureRecognizerOptions.gestureRecognizerLiveStreamDelegate = gestureRecognizerOptions.gestureRecognizerLiveStreamDelegate =
self.gestureRecognizerLiveStreamDelegate; self.gestureRecognizerLiveStreamDelegate;
gestureRecognizerOptions.numberOfHands = self.numberOfHands; gestureRecognizerOptions.numHands = self.numHands;
gestureRecognizerOptions.minHandDetectionConfidence = self.minHandDetectionConfidence; gestureRecognizerOptions.minHandDetectionConfidence = self.minHandDetectionConfidence;
gestureRecognizerOptions.minHandPresenceConfidence = self.minHandPresenceConfidence; gestureRecognizerOptions.minHandPresenceConfidence = self.minHandPresenceConfidence;
gestureRecognizerOptions.minTrackingConfidence = self.minTrackingConfidence; gestureRecognizerOptions.minTrackingConfidence = self.minTrackingConfidence;

View File

@ -60,7 +60,7 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto
HandDetectorGraphOptionsProto *handDetectorGraphOptionsProto = HandDetectorGraphOptionsProto *handDetectorGraphOptionsProto =
handLandmarkerGraphOptionsProto->mutable_hand_detector_graph_options(); handLandmarkerGraphOptionsProto->mutable_hand_detector_graph_options();
handDetectorGraphOptionsProto->Clear(); handDetectorGraphOptionsProto->Clear();
handDetectorGraphOptionsProto->set_num_hands(self.numberOfHands); handDetectorGraphOptionsProto->set_num_hands(self.numHands);
handDetectorGraphOptionsProto->set_min_detection_confidence(self.minHandDetectionConfidence); handDetectorGraphOptionsProto->set_min_detection_confidence(self.minHandDetectionConfidence);
HandLandmarksDetectorGraphOptionsProto *handLandmarksDetectorGraphOptionsProto = HandLandmarksDetectorGraphOptionsProto *handLandmarksDetectorGraphOptionsProto =

View File

@ -80,7 +80,7 @@ NS_SWIFT_NAME(ObjectDetector)
* Creates a new instance of `MPPObjectDetector` from the given `MPPObjectDetectorOptions`. * Creates a new instance of `MPPObjectDetector` from the given `MPPObjectDetectorOptions`.
* *
* @param options The options of type `MPPObjectDetectorOptions` to use for configuring the * @param options The options of type `MPPObjectDetectorOptions` to use for configuring the
* `MPPImageClassifMPPObjectDetectorier`. * `MPPObjectDetector`.
* @param error An optional error parameter populated when there is an error in initializing the * @param error An optional error parameter populated when there is an error in initializing the
* object detector. * object detector.
* *
@ -96,7 +96,7 @@ NS_SWIFT_NAME(ObjectDetector)
* `MPPImage`. Only use this method when the `MPPObjectDetector` is created with * `MPPImage`. Only use this method when the `MPPObjectDetector` is created with
* `MPPRunningModeImage`. * `MPPRunningModeImage`.
* *
* This method supports classification of RGBA images. If your `MPPImage` has a source type of * This method supports detecting objects in RGBA images. If your `MPPImage` has a source type of
* `MPPImageSourceTypePixelBuffer` or `MPPImageSourceTypeSampleBuffer`, the underlying pixel buffer * `MPPImageSourceTypePixelBuffer` or `MPPImageSourceTypeSampleBuffer`, the underlying pixel buffer
* must have one of the following pixel format types: * must have one of the following pixel format types:
* 1. kCVPixelFormatType_32BGRA * 1. kCVPixelFormatType_32BGRA
@ -123,7 +123,7 @@ NS_SWIFT_NAME(ObjectDetector)
* the provided `MPPImage`. Only use this method when the `MPPObjectDetector` is created with * the provided `MPPImage`. Only use this method when the `MPPObjectDetector` is created with
* `MPPRunningModeVideo`. * `MPPRunningModeVideo`.
* *
* This method supports classification of RGBA images. If your `MPPImage` has a source type of * This method supports detecting objects in RGBA images. If your `MPPImage` has a source type of
* `MPPImageSourceTypePixelBuffer` or `MPPImageSourceTypeSampleBuffer`, the underlying pixel buffer * `MPPImageSourceTypePixelBuffer` or `MPPImageSourceTypeSampleBuffer`, the underlying pixel buffer
* must have one of the following pixel format types: * must have one of the following pixel format types:
* 1. kCVPixelFormatType_32BGRA * 1. kCVPixelFormatType_32BGRA
@ -161,7 +161,7 @@ NS_SWIFT_NAME(ObjectDetector)
* It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent * It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent
* to the object detector. The input timestamps must be monotonically increasing. * to the object detector. The input timestamps must be monotonically increasing.
* *
* This method supports classification of RGBA images. If your `MPPImage` has a source type of * This method supports detecting objects in RGBA images. If your `MPPImage` has a source type of
* `MPPImageSourceTypePixelBuffer` or `MPPImageSourceTypeSampleBuffer`, the underlying pixel buffer * `MPPImageSourceTypePixelBuffer` or `MPPImageSourceTypeSampleBuffer`, the underlying pixel buffer
* must have one of the following pixel format types: * must have one of the following pixel format types:
* 1. kCVPixelFormatType_32BGRA * 1. kCVPixelFormatType_32BGRA
@ -170,8 +170,8 @@ NS_SWIFT_NAME(ObjectDetector)
* If the input `MPPImage` has a source type of `MPPImageSourceTypeImage` ensure that the color * If the input `MPPImage` has a source type of `MPPImageSourceTypeImage` ensure that the color
* space is RGB with an Alpha channel. * space is RGB with an Alpha channel.
* *
* If this method is used for classifying live camera frames using `AVFoundation`, ensure that you * If this method is used for detecting objects in live camera frames using `AVFoundation`, ensure
* request `AVCaptureVideoDataOutput` to output frames in `kCMPixelFormat_32RGBA` using its * that you request `AVCaptureVideoDataOutput` to output frames in `kCMPixelFormat_32RGBA` using its
* `videoSettings` property. * `videoSettings` property.
* *
* @param image A live stream image data of type `MPPImage` on which object detection is to be * @param image A live stream image data of type `MPPImage` on which object detection is to be

View File

@ -25,8 +25,12 @@ using ::mediapipe::Packet;
+ (nullable MPPObjectDetectorResult *)objectDetectorResultWithDetectionsPacket: + (nullable MPPObjectDetectorResult *)objectDetectorResultWithDetectionsPacket:
(const Packet &)packet { (const Packet &)packet {
NSInteger timestampInMilliseconds = (NSInteger)(packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond);
if (!packet.ValidateAsType<std::vector<DetectionProto>>().ok()) { if (!packet.ValidateAsType<std::vector<DetectionProto>>().ok()) {
return nil; return [[MPPObjectDetectorResult alloc] initWithDetections:@[]
timestampInMilliseconds:timestampInMilliseconds];
} }
const std::vector<DetectionProto> &detectionProtos = packet.Get<std::vector<DetectionProto>>(); const std::vector<DetectionProto> &detectionProtos = packet.Get<std::vector<DetectionProto>>();
@ -39,8 +43,7 @@ using ::mediapipe::Packet;
return return
[[MPPObjectDetectorResult alloc] initWithDetections:detections [[MPPObjectDetectorResult alloc] initWithDetections:detections
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() / timestampInMilliseconds:timestampInMilliseconds];
kMicroSecondsPerMilliSecond)];
} }
@end @end

View File

@ -71,6 +71,7 @@ android_library(
srcs = [ srcs = [
"objectdetector/ObjectDetectionResult.java", "objectdetector/ObjectDetectionResult.java",
"objectdetector/ObjectDetector.java", "objectdetector/ObjectDetector.java",
"objectdetector/ObjectDetectorResult.java",
], ],
javacopts = [ javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF", "-Xep:AndroidJdkLibsChecker:OFF",

View File

@ -14,15 +14,16 @@
package com.google.mediapipe.tasks.vision.objectdetector; package com.google.mediapipe.tasks.vision.objectdetector;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.core.TaskResult; import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.formats.proto.DetectionProto.Detection; import com.google.mediapipe.formats.proto.DetectionProto.Detection;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
/** Represents the detection results generated by {@link ObjectDetector}. */ /**
@AutoValue * Represents the detection results generated by {@link ObjectDetector}.
*
* @deprecated Use {@link ObjectDetectorResult} instead.
*/
@Deprecated
public abstract class ObjectDetectionResult implements TaskResult { public abstract class ObjectDetectionResult implements TaskResult {
@Override @Override
@ -36,15 +37,10 @@ public abstract class ObjectDetectionResult implements TaskResult {
* *
* @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages. * @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages.
* @param timestampMs a timestamp for this result. * @param timestampMs a timestamp for this result.
* @deprecated Use {@link ObjectDetectorResult#create} instead.
*/ */
@Deprecated
public static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) { public static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) {
List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>(); return ObjectDetectorResult.create(detectionList, timestampMs);
for (Detection detectionProto : detectionList) {
detections.add(
com.google.mediapipe.tasks.components.containers.Detection.createFromProto(
detectionProto));
}
return new AutoValue_ObjectDetectionResult(
timestampMs, Collections.unmodifiableList(detections));
} }
} }

View File

@ -99,11 +99,16 @@ public final class ObjectDetector extends BaseVisionTaskApi {
private static final String TAG = ObjectDetector.class.getSimpleName(); private static final String TAG = ObjectDetector.class.getSimpleName();
private static final String IMAGE_IN_STREAM_NAME = "image_in"; private static final String IMAGE_IN_STREAM_NAME = "image_in";
private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in"; private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
@SuppressWarnings("ConstantCaseForConstants")
private static final List<String> INPUT_STREAMS = private static final List<String> INPUT_STREAMS =
Collections.unmodifiableList( Collections.unmodifiableList(
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
@SuppressWarnings("ConstantCaseForConstants")
private static final List<String> OUTPUT_STREAMS = private static final List<String> OUTPUT_STREAMS =
Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out")); Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out"));
private static final int DETECTIONS_OUT_STREAM_INDEX = 0; private static final int DETECTIONS_OUT_STREAM_INDEX = 0;
private static final int IMAGE_OUT_STREAM_INDEX = 1; private static final int IMAGE_OUT_STREAM_INDEX = 1;
private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.ObjectDetectorGraph"; private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.ObjectDetectorGraph";
@ -166,19 +171,19 @@ public final class ObjectDetector extends BaseVisionTaskApi {
public static ObjectDetector createFromOptions( public static ObjectDetector createFromOptions(
Context context, ObjectDetectorOptions detectorOptions) { Context context, ObjectDetectorOptions detectorOptions) {
// TODO: Consolidate OutputHandler and TaskRunner. // TODO: Consolidate OutputHandler and TaskRunner.
OutputHandler<ObjectDetectionResult, MPImage> handler = new OutputHandler<>(); OutputHandler<ObjectDetectorResult, MPImage> handler = new OutputHandler<>();
handler.setOutputPacketConverter( handler.setOutputPacketConverter(
new OutputHandler.OutputPacketConverter<ObjectDetectionResult, MPImage>() { new OutputHandler.OutputPacketConverter<ObjectDetectorResult, MPImage>() {
@Override @Override
public ObjectDetectionResult convertToTaskResult(List<Packet> packets) { public ObjectDetectorResult convertToTaskResult(List<Packet> packets) {
// If there is no object detected in the image, just returns empty lists. // If there is no object detected in the image, just returns empty lists.
if (packets.get(DETECTIONS_OUT_STREAM_INDEX).isEmpty()) { if (packets.get(DETECTIONS_OUT_STREAM_INDEX).isEmpty()) {
return ObjectDetectionResult.create( return ObjectDetectorResult.create(
new ArrayList<>(), new ArrayList<>(),
BaseVisionTaskApi.generateResultTimestampMs( BaseVisionTaskApi.generateResultTimestampMs(
detectorOptions.runningMode(), packets.get(DETECTIONS_OUT_STREAM_INDEX))); detectorOptions.runningMode(), packets.get(DETECTIONS_OUT_STREAM_INDEX)));
} }
return ObjectDetectionResult.create( return ObjectDetectorResult.create(
PacketGetter.getProtoVector( PacketGetter.getProtoVector(
packets.get(DETECTIONS_OUT_STREAM_INDEX), Detection.parser()), packets.get(DETECTIONS_OUT_STREAM_INDEX), Detection.parser()),
BaseVisionTaskApi.generateResultTimestampMs( BaseVisionTaskApi.generateResultTimestampMs(
@ -235,7 +240,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* @param image a MediaPipe {@link MPImage} object for processing. * @param image a MediaPipe {@link MPImage} object for processing.
* @throws MediaPipeException if there is an internal error. * @throws MediaPipeException if there is an internal error.
*/ */
public ObjectDetectionResult detect(MPImage image) { public ObjectDetectorResult detect(MPImage image) {
return detect(image, ImageProcessingOptions.builder().build()); return detect(image, ImageProcessingOptions.builder().build());
} }
@ -258,10 +263,9 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* region-of-interest. * region-of-interest.
* @throws MediaPipeException if there is an internal error. * @throws MediaPipeException if there is an internal error.
*/ */
public ObjectDetectionResult detect( public ObjectDetectorResult detect(MPImage image, ImageProcessingOptions imageProcessingOptions) {
MPImage image, ImageProcessingOptions imageProcessingOptions) {
validateImageProcessingOptions(imageProcessingOptions); validateImageProcessingOptions(imageProcessingOptions);
return (ObjectDetectionResult) processImageData(image, imageProcessingOptions); return (ObjectDetectorResult) processImageData(image, imageProcessingOptions);
} }
/** /**
@ -282,7 +286,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* @param timestampMs the input timestamp (in milliseconds). * @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error. * @throws MediaPipeException if there is an internal error.
*/ */
public ObjectDetectionResult detectForVideo(MPImage image, long timestampMs) { public ObjectDetectorResult detectForVideo(MPImage image, long timestampMs) {
return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
} }
@ -309,10 +313,10 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* region-of-interest. * region-of-interest.
* @throws MediaPipeException if there is an internal error. * @throws MediaPipeException if there is an internal error.
*/ */
public ObjectDetectionResult detectForVideo( public ObjectDetectorResult detectForVideo(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
validateImageProcessingOptions(imageProcessingOptions); validateImageProcessingOptions(imageProcessingOptions);
return (ObjectDetectionResult) processVideoData(image, imageProcessingOptions, timestampMs); return (ObjectDetectorResult) processVideoData(image, imageProcessingOptions, timestampMs);
} }
/** /**
@ -435,7 +439,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* object detector is in the live stream mode. * object detector is in the live stream mode.
*/ */
public abstract Builder setResultListener( public abstract Builder setResultListener(
ResultListener<ObjectDetectionResult, MPImage> value); ResultListener<ObjectDetectorResult, MPImage> value);
/** Sets an optional {@link ErrorListener}}. */ /** Sets an optional {@link ErrorListener}}. */
public abstract Builder setErrorListener(ErrorListener value); public abstract Builder setErrorListener(ErrorListener value);
@ -476,11 +480,13 @@ public final class ObjectDetector extends BaseVisionTaskApi {
abstract Optional<Float> scoreThreshold(); abstract Optional<Float> scoreThreshold();
@SuppressWarnings("AutoValueImmutableFields")
abstract List<String> categoryAllowlist(); abstract List<String> categoryAllowlist();
@SuppressWarnings("AutoValueImmutableFields")
abstract List<String> categoryDenylist(); abstract List<String> categoryDenylist();
abstract Optional<ResultListener<ObjectDetectionResult, MPImage>> resultListener(); abstract Optional<ResultListener<ObjectDetectorResult, MPImage>> resultListener();
abstract Optional<ErrorListener> errorListener(); abstract Optional<ErrorListener> errorListener();

View File

@ -0,0 +1,44 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.objectdetector;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/** Represents the detection results generated by {@link ObjectDetector}. */
@AutoValue
@SuppressWarnings("deprecation")
public abstract class ObjectDetectorResult extends ObjectDetectionResult {
/**
* Creates an {@link ObjectDetectorResult} instance from a list of {@link Detection} protobuf
* messages.
*
* @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages.
* @param timestampMs a timestamp for this result.
*/
public static ObjectDetectorResult create(List<Detection> detectionList, long timestampMs) {
List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>();
for (Detection detectionProto : detectionList) {
detections.add(
com.google.mediapipe.tasks.components.containers.Detection.createFromProto(
detectionProto));
}
return new AutoValue_ObjectDetectorResult(
timestampMs, Collections.unmodifiableList(detections));
}
}

View File

@ -69,7 +69,7 @@ public class ObjectDetectorTest {
.build(); .build();
ObjectDetector objectDetector = ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
} }
@ -77,7 +77,7 @@ public class ObjectDetectorTest {
public void detect_successWithNoOptions() throws Exception { public void detect_successWithNoOptions() throws Exception {
ObjectDetector objectDetector = ObjectDetector objectDetector =
ObjectDetector.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE); ObjectDetector.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Check if the object with the highest score is cat. // Check if the object with the highest score is cat.
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE); assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
} }
@ -91,7 +91,7 @@ public class ObjectDetectorTest {
.build(); .build();
ObjectDetector objectDetector = ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// results should have 8 detected objects because maxResults was set to 8. // results should have 8 detected objects because maxResults was set to 8.
assertThat(results.detections()).hasSize(8); assertThat(results.detections()).hasSize(8);
} }
@ -105,7 +105,7 @@ public class ObjectDetectorTest {
.build(); .build();
ObjectDetector objectDetector = ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// The score threshold should block all other objects, except cat. // The score threshold should block all other objects, except cat.
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
} }
@ -119,7 +119,7 @@ public class ObjectDetectorTest {
.build(); .build();
ObjectDetector objectDetector = ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// The score threshold should block objects. // The score threshold should block objects.
assertThat(results.detections()).isEmpty(); assertThat(results.detections()).isEmpty();
} }
@ -133,7 +133,7 @@ public class ObjectDetectorTest {
.build(); .build();
ObjectDetector objectDetector = ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Because of the allowlist, results should only contain cat, and there are 6 detected // Because of the allowlist, results should only contain cat, and there are 6 detected
// bounding boxes of cats in CAT_AND_DOG_IMAGE. // bounding boxes of cats in CAT_AND_DOG_IMAGE.
assertThat(results.detections()).hasSize(5); assertThat(results.detections()).hasSize(5);
@ -148,7 +148,7 @@ public class ObjectDetectorTest {
.build(); .build();
ObjectDetector objectDetector = ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Because of the denylist, the highest result is not cat anymore. // Because of the denylist, the highest result is not cat anymore.
assertThat(results.detections().get(0).categories().get(0).categoryName()) assertThat(results.detections().get(0).categories().get(0).categoryName())
.isNotEqualTo("cat"); .isNotEqualTo("cat");
@ -160,7 +160,7 @@ public class ObjectDetectorTest {
ObjectDetector.createFromFile( ObjectDetector.createFromFile(
ApplicationProvider.getApplicationContext(), ApplicationProvider.getApplicationContext(),
TestUtils.loadFile(ApplicationProvider.getApplicationContext(), MODEL_FILE)); TestUtils.loadFile(ApplicationProvider.getApplicationContext(), MODEL_FILE));
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Check if the object with the highest score is cat. // Check if the object with the highest score is cat.
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE); assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
} }
@ -172,7 +172,7 @@ public class ObjectDetectorTest {
ApplicationProvider.getApplicationContext(), ApplicationProvider.getApplicationContext(),
TestUtils.loadToDirectByteBuffer( TestUtils.loadToDirectByteBuffer(
ApplicationProvider.getApplicationContext(), MODEL_FILE)); ApplicationProvider.getApplicationContext(), MODEL_FILE));
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Check if the object with the highest score is cat. // Check if the object with the highest score is cat.
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE); assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
} }
@ -191,7 +191,7 @@ public class ObjectDetectorTest {
.build(); .build();
ObjectDetector objectDetector = ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
} }
@ -256,7 +256,7 @@ public class ObjectDetectorTest {
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageProcessingOptions imageProcessingOptions = ImageProcessingOptions imageProcessingOptions =
ImageProcessingOptions.builder().setRotationDegrees(-90).build(); ImageProcessingOptions.builder().setRotationDegrees(-90).build();
ObjectDetectionResult results = ObjectDetectorResult results =
objectDetector.detect( objectDetector.detect(
getImageFromAsset(CAT_AND_DOG_ROTATED_IMAGE), imageProcessingOptions); getImageFromAsset(CAT_AND_DOG_ROTATED_IMAGE), imageProcessingOptions);
@ -302,7 +302,7 @@ public class ObjectDetectorTest {
ObjectDetectorOptions.builder() ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(mode) .setRunningMode(mode)
.setResultListener((objectDetectionResult, inputImage) -> {}) .setResultListener((objectDetectorResult, inputImage) -> {})
.build()); .build());
assertThat(exception) assertThat(exception)
.hasMessageThat() .hasMessageThat()
@ -381,7 +381,7 @@ public class ObjectDetectorTest {
ObjectDetectorOptions.builder() ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM) .setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener((objectDetectionResult, inputImage) -> {}) .setResultListener((objectDetectorResult, inputImage) -> {})
.build(); .build();
ObjectDetector objectDetector = ObjectDetector objectDetector =
@ -411,7 +411,7 @@ public class ObjectDetectorTest {
.build(); .build();
ObjectDetector objectDetector = ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
} }
@ -426,7 +426,7 @@ public class ObjectDetectorTest {
ObjectDetector objectDetector = ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
ObjectDetectionResult results = ObjectDetectorResult results =
objectDetector.detectForVideo( objectDetector.detectForVideo(
getImageFromAsset(CAT_AND_DOG_IMAGE), /* timestampMs= */ i); getImageFromAsset(CAT_AND_DOG_IMAGE), /* timestampMs= */ i);
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
@ -441,8 +441,8 @@ public class ObjectDetectorTest {
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM) .setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener( .setResultListener(
(objectDetectionResult, inputImage) -> { (objectDetectorResult, inputImage) -> {
assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE); assertContainsOnlyCat(objectDetectorResult, CAT_BOUNDING_BOX, CAT_SCORE);
assertImageSizeIsExpected(inputImage); assertImageSizeIsExpected(inputImage);
}) })
.setMaxResults(1) .setMaxResults(1)
@ -468,8 +468,8 @@ public class ObjectDetectorTest {
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM) .setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener( .setResultListener(
(objectDetectionResult, inputImage) -> { (objectDetectorResult, inputImage) -> {
assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE); assertContainsOnlyCat(objectDetectorResult, CAT_BOUNDING_BOX, CAT_SCORE);
assertImageSizeIsExpected(inputImage); assertImageSizeIsExpected(inputImage);
}) })
.setMaxResults(1) .setMaxResults(1)
@ -483,6 +483,16 @@ public class ObjectDetectorTest {
} }
} }
@Test
@SuppressWarnings("deprecation")
public void detect_canUseDeprecatedApi() throws Exception {
ObjectDetector objectDetector =
ObjectDetector.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Check if the object with the highest score is cat.
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
}
private static MPImage getImageFromAsset(String filePath) throws Exception { private static MPImage getImageFromAsset(String filePath) throws Exception {
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
InputStream istr = assetManager.open(filePath); InputStream istr = assetManager.open(filePath);
@ -491,7 +501,7 @@ public class ObjectDetectorTest {
// Checks if results has one and only detection result, which is a cat. // Checks if results has one and only detection result, which is a cat.
private static void assertContainsOnlyCat( private static void assertContainsOnlyCat(
ObjectDetectionResult result, RectF expectedBoundingBox, float expectedScore) { ObjectDetectorResult result, RectF expectedBoundingBox, float expectedScore) {
assertThat(result.detections()).hasSize(1); assertThat(result.detections()).hasSize(1);
Detection catResult = result.detections().get(0); Detection catResult = result.detections().get(0);
assertApproximatelyEqualBoundingBoxes(catResult.boundingBox(), expectedBoundingBox); assertApproximatelyEqualBoundingBoxes(catResult.boundingBox(), expectedBoundingBox);

View File

@ -34,6 +34,8 @@ py_library(
], ],
deps = [ deps = [
":optional_dependencies", ":optional_dependencies",
"//mediapipe/calculators/tensor:inference_calculator_py_pb2",
"//mediapipe/tasks/cc/core/proto:acceleration_py_pb2",
"//mediapipe/tasks/cc/core/proto:base_options_py_pb2", "//mediapipe/tasks/cc/core/proto:base_options_py_pb2",
"//mediapipe/tasks/cc/core/proto:external_file_py_pb2", "//mediapipe/tasks/cc/core/proto:external_file_py_pb2",
], ],

View File

@ -14,13 +14,19 @@
"""Base options for MediaPipe Task APIs.""" """Base options for MediaPipe Task APIs."""
import dataclasses import dataclasses
import enum
import os import os
import platform
from typing import Any, Optional from typing import Any, Optional
from mediapipe.calculators.tensor import inference_calculator_pb2
from mediapipe.tasks.cc.core.proto import acceleration_pb2
from mediapipe.tasks.cc.core.proto import base_options_pb2 from mediapipe.tasks.cc.core.proto import base_options_pb2
from mediapipe.tasks.cc.core.proto import external_file_pb2 from mediapipe.tasks.cc.core.proto import external_file_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_DelegateProto = inference_calculator_pb2.InferenceCalculatorOptions.Delegate
_AccelerationProto = acceleration_pb2.Acceleration
_BaseOptionsProto = base_options_pb2.BaseOptions _BaseOptionsProto = base_options_pb2.BaseOptions
_ExternalFileProto = external_file_pb2.ExternalFile _ExternalFileProto = external_file_pb2.ExternalFile
@ -41,11 +47,17 @@ class BaseOptions:
Attributes: Attributes:
model_asset_path: Path to the model asset file. model_asset_path: Path to the model asset file.
model_asset_buffer: The model asset file contents as bytes. model_asset_buffer: The model asset file contents as bytes.
delegate: Acceleration to use. Supported values are GPU and CPU. GPU support
is currently limited to Ubuntu platforms.
""" """
class Delegate(enum.Enum):
CPU = 0
GPU = 1
model_asset_path: Optional[str] = None model_asset_path: Optional[str] = None
model_asset_buffer: Optional[bytes] = None model_asset_buffer: Optional[bytes] = None
# TODO: Allow Python API to specify acceleration settings. delegate: Optional[Delegate] = None
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def to_pb2(self) -> _BaseOptionsProto: def to_pb2(self) -> _BaseOptionsProto:
@ -55,17 +67,44 @@ class BaseOptions:
else: else:
full_path = None full_path = None
platform_name = platform.system()
if self.delegate == BaseOptions.Delegate.GPU:
if platform_name == 'Linux':
acceleration_proto = _AccelerationProto(gpu=_DelegateProto.Gpu())
else:
raise NotImplementedError(
'GPU Delegate is not yet supported for ' + platform_name
)
elif self.delegate == BaseOptions.Delegate.CPU:
acceleration_proto = _AccelerationProto(tflite=_DelegateProto.TfLite())
else:
acceleration_proto = None
return _BaseOptionsProto( return _BaseOptionsProto(
model_asset=_ExternalFileProto( model_asset=_ExternalFileProto(
file_name=full_path, file_content=self.model_asset_buffer)) file_name=full_path, file_content=self.model_asset_buffer
),
acceleration=acceleration_proto,
)
@classmethod @classmethod
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def create_from_pb2(cls, pb2_obj: _BaseOptionsProto) -> 'BaseOptions': def create_from_pb2(cls, pb2_obj: _BaseOptionsProto) -> 'BaseOptions':
"""Creates a `BaseOptions` object from the given protobuf object.""" """Creates a `BaseOptions` object from the given protobuf object."""
delegate = None
if pb2_obj.HasField('acceleration'):
delegate = (
BaseOptions.Delegate.GPU
if pb2_obj.acceleration.HasField('gpu')
else BaseOptions.Delegate.CPU
)
return BaseOptions( return BaseOptions(
model_asset_path=pb2_obj.model_asset.file_name, model_asset_path=pb2_obj.model_asset.file_name,
model_asset_buffer=pb2_obj.model_asset.file_content) model_asset_buffer=pb2_obj.model_asset.file_content,
delegate=delegate,
)
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object. """Checks if this object is equal to the given object.

View File

@ -59,7 +59,11 @@ const DEFAULT_SCORE_THRESHOLD = 0.5;
* This API expects a pre-trained face landmarker model asset bundle. * This API expects a pre-trained face landmarker model asset bundle.
*/ */
export class FaceLandmarker extends VisionTaskRunner { export class FaceLandmarker extends VisionTaskRunner {
private result: FaceLandmarkerResult = {faceLandmarks: []}; private result: FaceLandmarkerResult = {
faceLandmarks: [],
faceBlendshapes: [],
facialTransformationMatrixes: []
};
private outputFaceBlendshapes = false; private outputFaceBlendshapes = false;
private outputFacialTransformationMatrixes = false; private outputFacialTransformationMatrixes = false;
@ -256,13 +260,11 @@ export class FaceLandmarker extends VisionTaskRunner {
} }
private resetResults(): void { private resetResults(): void {
this.result = {faceLandmarks: []}; this.result = {
if (this.outputFaceBlendshapes) { faceLandmarks: [],
this.result.faceBlendshapes = []; faceBlendshapes: [],
} facialTransformationMatrixes: []
if (this.outputFacialTransformationMatrixes) { };
this.result.facialTransformationMatrixes = [];
}
} }
/** Sets the default values for the graph. */ /** Sets the default values for the graph. */
@ -286,7 +288,7 @@ export class FaceLandmarker extends VisionTaskRunner {
/** Adds new blendshapes from the given proto. */ /** Adds new blendshapes from the given proto. */
private addBlenshape(data: Uint8Array[]): void { private addBlenshape(data: Uint8Array[]): void {
if (!this.result.faceBlendshapes) { if (!this.outputFaceBlendshapes) {
return; return;
} }
@ -300,7 +302,7 @@ export class FaceLandmarker extends VisionTaskRunner {
/** Adds new transformation matrixes from the given proto. */ /** Adds new transformation matrixes from the given proto. */
private addFacialTransformationMatrixes(data: Uint8Array[]): void { private addFacialTransformationMatrixes(data: Uint8Array[]): void {
if (!this.result.facialTransformationMatrixes) { if (!this.outputFacialTransformationMatrixes) {
return; return;
} }

View File

@ -29,8 +29,8 @@ export declare interface FaceLandmarkerResult {
faceLandmarks: NormalizedLandmark[][]; faceLandmarks: NormalizedLandmark[][];
/** Optional face blendshapes results. */ /** Optional face blendshapes results. */
faceBlendshapes?: Classifications[]; faceBlendshapes: Classifications[];
/** Optional facial transformation matrix. */ /** Optional facial transformation matrix. */
facialTransformationMatrixes?: Matrix[]; facialTransformationMatrixes: Matrix[];
} }

View File

@ -30,6 +30,7 @@ from setuptools.command import build_py
from setuptools.command import install from setuptools.command import install
__version__ = 'dev' __version__ = 'dev'
MP_DISABLE_GPU = os.environ.get('MEDIAPIPE_DISABLE_GPU') != '0'
IS_WINDOWS = (platform.system() == 'Windows') IS_WINDOWS = (platform.system() == 'Windows')
IS_MAC = (platform.system() == 'Darwin') IS_MAC = (platform.system() == 'Darwin')
MP_ROOT_PATH = os.path.dirname(os.path.abspath(__file__)) MP_ROOT_PATH = os.path.dirname(os.path.abspath(__file__))
@ -279,10 +280,16 @@ class BuildModules(build_ext.build_ext):
'build', 'build',
'--compilation_mode=opt', '--compilation_mode=opt',
'--copt=-DNDEBUG', '--copt=-DNDEBUG',
'--define=MEDIAPIPE_DISABLE_GPU=1',
'--action_env=PYTHON_BIN_PATH=' + _normalize_path(sys.executable), '--action_env=PYTHON_BIN_PATH=' + _normalize_path(sys.executable),
binary_graph_target, binary_graph_target,
] ]
if MP_DISABLE_GPU:
bazel_command.append('--define=MEDIAPIPE_DISABLE_GPU=1')
else:
bazel_command.append('--copt=-DMESA_EGL_NO_X11_HEADERS')
bazel_command.append('--copt=-DEGL_NO_X11')
if not self.link_opencv and not IS_WINDOWS: if not self.link_opencv and not IS_WINDOWS:
bazel_command.append('--define=OPENCV=source') bazel_command.append('--define=OPENCV=source')
if subprocess.call(bazel_command) != 0: if subprocess.call(bazel_command) != 0:
@ -300,14 +307,21 @@ class GenerateMetadataSchema(build_ext.build_ext):
'object_detector_metadata_schema_py', 'object_detector_metadata_schema_py',
'schema_py', 'schema_py',
]: ]:
bazel_command = [ bazel_command = [
'bazel', 'bazel',
'build', 'build',
'--compilation_mode=opt', '--compilation_mode=opt',
'--define=MEDIAPIPE_DISABLE_GPU=1',
'--action_env=PYTHON_BIN_PATH=' + _normalize_path(sys.executable), '--action_env=PYTHON_BIN_PATH=' + _normalize_path(sys.executable),
'//mediapipe/tasks/metadata:' + target, '//mediapipe/tasks/metadata:' + target,
] ]
if MP_DISABLE_GPU:
bazel_command.append('--define=MEDIAPIPE_DISABLE_GPU=1')
else:
bazel_command.append('--copt=-DMESA_EGL_NO_X11_HEADERS')
bazel_command.append('--copt=-DEGL_NO_X11')
if subprocess.call(bazel_command) != 0: if subprocess.call(bazel_command) != 0:
sys.exit(-1) sys.exit(-1)
_copy_to_build_lib_dir( _copy_to_build_lib_dir(
@ -393,7 +407,8 @@ class BuildExtension(build_ext.build_ext):
'build', 'build',
'--compilation_mode=opt', '--compilation_mode=opt',
'--copt=-DNDEBUG', '--copt=-DNDEBUG',
'--define=MEDIAPIPE_DISABLE_GPU=1', '--copt=-DMESA_EGL_NO_X11_HEADERS',
'--copt=-DEGL_NO_X11',
'--action_env=PYTHON_BIN_PATH=' + _normalize_path(sys.executable), '--action_env=PYTHON_BIN_PATH=' + _normalize_path(sys.executable),
str(ext.bazel_target + '.so'), str(ext.bazel_target + '.so'),
] ]