diff --git a/mediapipe/tasks/ios/vision/face_detector/BUILD b/mediapipe/tasks/ios/vision/face_detector/BUILD
index e4fc15616..eb34da1b6 100644
--- a/mediapipe/tasks/ios/vision/face_detector/BUILD
+++ b/mediapipe/tasks/ios/vision/face_detector/BUILD
@@ -55,7 +55,7 @@ objc_library(
         "//mediapipe/tasks/ios/core:MPPTaskInfo",
         "//mediapipe/tasks/ios/vision/core:MPPImage",
         "//mediapipe/tasks/ios/vision/core:MPPVisionPacketCreator",
-        "//mediapipe/tasks/ios/vision/core:MPPVisionTaskRunner",
+        "//mediapipe/tasks/ios/vision/core:MPPVisionTaskRunnerRefactored",
         "//mediapipe/tasks/ios/vision/face_detector/utils:MPPFaceDetectorOptionsHelpers",
         "//mediapipe/tasks/ios/vision/face_detector/utils:MPPFaceDetectorResultHelpers",
     ],
diff --git a/mediapipe/tasks/ios/vision/face_detector/sources/MPPFaceDetector.mm b/mediapipe/tasks/ios/vision/face_detector/sources/MPPFaceDetector.mm
index 7cb525fb0..6da599fd7 100644
--- a/mediapipe/tasks/ios/vision/face_detector/sources/MPPFaceDetector.mm
+++ b/mediapipe/tasks/ios/vision/face_detector/sources/MPPFaceDetector.mm
@@ -18,12 +18,10 @@
 #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
 #import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
 #import "mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h"
-#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h"
+#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunnerRefactored.h"
 #import "mediapipe/tasks/ios/vision/face_detector/utils/sources/MPPFaceDetectorOptions+Helpers.h"
 #import "mediapipe/tasks/ios/vision/face_detector/utils/sources/MPPFaceDetectorResult+Helpers.h"

-using ::mediapipe::NormalizedRect;
-using ::mediapipe::Packet;
 using ::mediapipe::Timestamp;
 using ::mediapipe::tasks::core::PacketMap;
 using ::mediapipe::tasks::core::PacketsCallback;
@@ -49,6 +47,12 @@ static NSString *const kTaskName = @"faceDetector";
     }                                                     \
   }

+#define FaceDetectorResultWithOutputPacketMap(outputPacketMap)                                     \
+  (                                                                                                \
+      [MPPFaceDetectorResult                                                                       \
+          faceDetectorResultWithDetectionsPacket:outputPacketMap[kDetectionsStreamName.cppString]] \
+  )
+
 @interface MPPFaceDetector () {
   /** iOS Vision Task Runner */
   MPPVisionTaskRunner *_visionTaskRunner;
@@ -102,11 +106,13 @@ static NSString *const kTaskName = @"faceDetector";
     };
   }

-  _visionTaskRunner =
-      [[MPPVisionTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig]
-                                                     runningMode:options.runningMode
-                                                 packetsCallback:std::move(packetsCallback)
-                                                           error:error];
+  _visionTaskRunner = [[MPPVisionTaskRunner alloc] initWithTaskInfo:taskInfo
+                                                        runningMode:options.runningMode
+                                                         roiAllowed:NO
+                                                    packetsCallback:std::move(packetsCallback)
+                                               imageInputStreamName:kImageInStreamName
+                                            normRectInputStreamName:kNormRectStreamName
+                                                              error:error];

   if (!_visionTaskRunner) {
     return nil;
@@ -124,95 +130,29 @@ static NSString *const kTaskName = @"faceDetector";
   return [self initWithOptions:options error:error];
 }

-- (std::optional<PacketMap>)inputPacketMapWithMPPImage:(MPPImage *)image
-                               timestampInMilliseconds:(NSInteger)timestampInMilliseconds
-                                                 error:(NSError **)error {
-  std::optional<NormalizedRect> rect =
-      [_visionTaskRunner normalizedRectWithImageOrientation:image.orientation
-                                                   imageSize:CGSizeMake(image.width, image.height)
-                                                       error:error];
-  if (!rect.has_value()) {
-    return std::nullopt;
-  }
-
-  Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image
-                                                timestampInMilliseconds:timestampInMilliseconds
-                                                                  error:error];
-  if (imagePacket.IsEmpty()) {
-    return std::nullopt;
-  }
-
-  Packet normalizedRectPacket =
-      [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
-                                     timestampInMilliseconds:timestampInMilliseconds];
-
-  PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket);
-  return inputPacketMap;
-}
-
 - (nullable MPPFaceDetectorResult *)detectInImage:(MPPImage *)image error:(NSError **)error {
-  std::optional<NormalizedRect> rect =
-      [_visionTaskRunner normalizedRectWithImageOrientation:image.orientation
-                                                   imageSize:CGSizeMake(image.width, image.height)
-                                                       error:error];
-  if (!rect.has_value()) {
-    return nil;
-  }
+  std::optional<PacketMap> outputPacketMap = [_visionTaskRunner processImage:image error:error];

-  Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image error:error];
-  if (imagePacket.IsEmpty()) {
-    return nil;
-  }
-
-  Packet normalizedRectPacket =
-      [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()];
-
-  PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket);
-
-  std::optional<PacketMap> outputPacketMap = [_visionTaskRunner processImagePacketMap:inputPacketMap
-                                                                                 error:error];
-  if (!outputPacketMap.has_value()) {
-    return nil;
-  }
-
-  return [MPPFaceDetectorResult
-      faceDetectorResultWithDetectionsPacket:outputPacketMap
-                                                 .value()[kDetectionsStreamName.cppString]];
+  return [MPPFaceDetector faceDetectorResultWithOptionalOutputPacketMap:outputPacketMap];
 }

 - (nullable MPPFaceDetectorResult *)detectInVideoFrame:(MPPImage *)image
                                timestampInMilliseconds:(NSInteger)timestampInMilliseconds
                                                  error:(NSError **)error {
-  std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
-                                                     timestampInMilliseconds:timestampInMilliseconds
-                                                                       error:error];
-  if (!inputPacketMap.has_value()) {
-    return nil;
-  }
-
   std::optional<PacketMap> outputPacketMap =
-      [_visionTaskRunner processVideoFramePacketMap:inputPacketMap.value() error:error];
+      [_visionTaskRunner processVideoFrame:image
+                   timestampInMilliseconds:timestampInMilliseconds
+                                     error:error];

-  if (!outputPacketMap.has_value()) {
-    return nil;
-  }
-
-  return [MPPFaceDetectorResult
-      faceDetectorResultWithDetectionsPacket:outputPacketMap
-                                                 .value()[kDetectionsStreamName.cppString]];
+  return [MPPFaceDetector faceDetectorResultWithOptionalOutputPacketMap:outputPacketMap];
 }

 - (BOOL)detectAsyncInImage:(MPPImage *)image
    timestampInMilliseconds:(NSInteger)timestampInMilliseconds
                      error:(NSError **)error {
-  std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
-                                                     timestampInMilliseconds:timestampInMilliseconds
-                                                                       error:error];
-  if (!inputPacketMap.has_value()) {
-    return NO;
-  }
-
-  return [_visionTaskRunner processLiveStreamPacketMap:inputPacketMap.value() error:error];
+  return [_visionTaskRunner processLiveStreamImage:image
+                           timestampInMilliseconds:timestampInMilliseconds
+                                             error:error];
 }

 - (void)processLiveStreamResult:(absl::StatusOr<PacketMap>)liveStreamResult {
@@ -237,9 +177,7 @@ static NSString *const kTaskName = @"faceDetector";
     return;
   }

-  MPPFaceDetectorResult *result = [MPPFaceDetectorResult
-      faceDetectorResultWithDetectionsPacket:liveStreamResult
-                                                 .value()[kDetectionsStreamName.cppString]];
+  MPPFaceDetectorResult *result = FaceDetectorResultWithOutputPacketMap(liveStreamResult.value());

   NSInteger timeStampInMilliseconds =
       outputPacketMap[kImageOutStreamName.cppString].Timestamp().Value() /
@@ -252,4 +190,13 @@ static NSString *const kTaskName = @"faceDetector";
   });
 }

++ (nullable MPPFaceDetectorResult *)faceDetectorResultWithOptionalOutputPacketMap:
+    (std::optional<PacketMap>)outputPacketMap {
+  if (!outputPacketMap.has_value()) {
+    return nil;
+  }
+
+  return FaceDetectorResultWithOutputPacketMap(outputPacketMap.value());
+}
+
 @end
diff --git a/mediapipe/tasks/ios/vision/image_classifier/BUILD b/mediapipe/tasks/ios/vision/image_classifier/BUILD
index cf89249c4..daff017dc 100644
--- a/mediapipe/tasks/ios/vision/image_classifier/BUILD
+++ b/mediapipe/tasks/ios/vision/image_classifier/BUILD
@@ -57,7 +57,7 @@ objc_library(
         "//mediapipe/tasks/ios/core:MPPTaskInfo",
         "//mediapipe/tasks/ios/vision/core:MPPImage",
         "//mediapipe/tasks/ios/vision/core:MPPVisionPacketCreator",
-        "//mediapipe/tasks/ios/vision/core:MPPVisionTaskRunner",
+        "//mediapipe/tasks/ios/vision/core:MPPVisionTaskRunnerRefactored",
         "//mediapipe/tasks/ios/vision/image_classifier/utils:MPPImageClassifierOptionsHelpers",
         "//mediapipe/tasks/ios/vision/image_classifier/utils:MPPImageClassifierResultHelpers",
     ],
diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm
index 5d2595cd1..3e1592e11 100644
--- a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm
+++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm
@@ -18,7 +18,7 @@
 #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
 #import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
 #import "mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h"
-#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h"
+#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunnerRefactored.h"
 #import "mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierOptions+Helpers.h"
 #import "mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.h"

@@ -52,6 +52,13 @@ static const int kMicroSecondsPerMilliSecond = 1000;
     }                                                     \
   }

+#define ImageClassifierResultWithOutputPacketMap(outputPacketMap)                                   \
+  (                                                                                                 \
+      [MPPImageClassifierResult                                                                     \
+          imageClassifierResultWithClassificationsPacket:outputPacketMap[kClassificationsStreamName \
+                                                             .cppString]]                           \
+  )
+
 @interface MPPImageClassifier () {
   /** iOS Vision Task Runner */
   MPPVisionTaskRunner *_visionTaskRunner;
@@ -63,43 +70,7 @@ static const int kMicroSecondsPerMilliSecond = 1000;

 @implementation MPPImageClassifier

-- (void)processLiveStreamResult:(absl::StatusOr<PacketMap>)liveStreamResult {
-  if (![self.imageClassifierLiveStreamDelegate
-          respondsToSelector:@selector
-          (imageClassifier:didFinishClassificationWithResult:timestampInMilliseconds:error:)]) {
-    return;
-  }
-
-  NSError *callbackError = nil;
-  if (![MPPCommonUtils checkCppError:liveStreamResult.status() toError:&callbackError]) {
-    dispatch_async(_callbackQueue, ^{
-      [self.imageClassifierLiveStreamDelegate imageClassifier:self
-                            didFinishClassificationWithResult:nil
-                                      timestampInMilliseconds:Timestamp::Unset().Value()
-                                                        error:callbackError];
-    });
-    return;
-  }
-
-  PacketMap &outputPacketMap = liveStreamResult.value();
-  if (outputPacketMap[kImageOutStreamName.cppString].IsEmpty()) {
-    return;
-  }
-
-  MPPImageClassifierResult *result = [MPPImageClassifierResult
-      imageClassifierResultWithClassificationsPacket:outputPacketMap[kClassificationsStreamName
-                                                                         .cppString]];
-
-  NSInteger timeStampInMilliseconds =
-      outputPacketMap[kImageOutStreamName.cppString].Timestamp().Value() /
-      kMicroSecondsPerMilliSecond;
-  dispatch_async(_callbackQueue, ^{
-    [self.imageClassifierLiveStreamDelegate imageClassifier:self
-                          didFinishClassificationWithResult:result
-                                    timestampInMilliseconds:timeStampInMilliseconds
-                                                      error:callbackError];
-  });
-}
+#pragma mark - Public

 - (instancetype)initWithOptions:(MPPImageClassifierOptions *)options error:(NSError **)error {
   self = [super init];
@@ -143,11 +114,13 @@ static const int kMicroSecondsPerMilliSecond = 1000;
     };
   }

-  _visionTaskRunner =
-      [[MPPVisionTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig]
-                                                     runningMode:options.runningMode
-                                                 packetsCallback:std::move(packetsCallback)
-                                                           error:error];
+  _visionTaskRunner = [[MPPVisionTaskRunner alloc] initWithTaskInfo:taskInfo
+                                                        runningMode:options.runningMode
+                                                         roiAllowed:YES
+                                                    packetsCallback:std::move(packetsCallback)
+                                               imageInputStreamName:kImageInStreamName
+                                            normRectInputStreamName:kNormRectStreamName
+                                                              error:error];

   if (!_visionTaskRunner) {
     return nil;
@@ -167,90 +140,28 @@ static const int kMicroSecondsPerMilliSecond = 1000;
 - (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image
                                      regionOfInterest:(CGRect)roi
                                                 error:(NSError **)error {
-  std::optional<NormalizedRect> rect =
-      [_visionTaskRunner normalizedRectWithRegionOfInterest:roi
-                                            imageOrientation:image.orientation
-                                                   imageSize:CGSizeMake(image.width, image.height)
-                                                       error:error];
-  if (!rect.has_value()) {
-    return nil;
-  }
+  std::optional<PacketMap> outputPacketMap = [_visionTaskRunner processImage:image
+                                                             regionOfInterest:roi
+                                                                        error:error];

-  Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image error:error];
-  if (imagePacket.IsEmpty()) {
-    return nil;
-  }
-
-  Packet normalizedRectPacket =
-      [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()];
-
-  PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket);
-
-  std::optional<PacketMap> outputPacketMap = [_visionTaskRunner processImagePacketMap:inputPacketMap
-                                                                                 error:error];
-  if (!outputPacketMap.has_value()) {
-    return nil;
-  }
-
-  return
-      [MPPImageClassifierResult imageClassifierResultWithClassificationsPacket:
-                                    outputPacketMap.value()[kClassificationsStreamName.cppString]];
+  return [MPPImageClassifier imageClassifierResultWithOptionalOutputPacketMap:outputPacketMap];
 }

 - (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image error:(NSError **)error {
   return [self classifyImage:image regionOfInterest:CGRectZero error:error];
 }

-- (std::optional<PacketMap>)inputPacketMapWithMPPImage:(MPPImage *)image
-                               timestampInMilliseconds:(NSInteger)timestampInMilliseconds
-                                      regionOfInterest:(CGRect)roi
-                                                 error:(NSError **)error {
-  std::optional<NormalizedRect> rect =
-      [_visionTaskRunner normalizedRectWithRegionOfInterest:roi
-                                            imageOrientation:image.orientation
-                                                   imageSize:CGSizeMake(image.width, image.height)
-                                                       error:error];
-  if (!rect.has_value()) {
-    return std::nullopt;
-  }
-
-  Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image
-                                                timestampInMilliseconds:timestampInMilliseconds
-                                                                  error:error];
-  if (imagePacket.IsEmpty()) {
-    return std::nullopt;
-  }
-
-  Packet normalizedRectPacket =
-      [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
-                                     timestampInMilliseconds:timestampInMilliseconds];
-
-  PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket);
-  return inputPacketMap;
-}
-
 - (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
                                   timestampInMilliseconds:(NSInteger)timestampInMilliseconds
                                          regionOfInterest:(CGRect)roi
                                                     error:(NSError **)error {
-  std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
-                                                     timestampInMilliseconds:timestampInMilliseconds
-                                                             regionOfInterest:roi
-                                                                        error:error];
-  if (!inputPacketMap.has_value()) {
-    return nil;
-  }
-
   std::optional<PacketMap> outputPacketMap =
-      [_visionTaskRunner processVideoFramePacketMap:inputPacketMap.value() error:error];
+      [_visionTaskRunner processVideoFrame:image
+                          regionOfInterest:roi
+                   timestampInMilliseconds:timestampInMilliseconds
+                                     error:error];

-  if (!outputPacketMap.has_value()) {
-    return nil;
-  }
-
-  return
-      [MPPImageClassifierResult imageClassifierResultWithClassificationsPacket:
-                                    outputPacketMap.value()[kClassificationsStreamName.cppString]];
+  return [MPPImageClassifier imageClassifierResultWithOptionalOutputPacketMap:outputPacketMap];
 }

 - (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
@@ -266,15 +177,10 @@ static const int kMicroSecondsPerMilliSecond = 1000;
     timestampInMilliseconds:(NSInteger)timestampInMilliseconds
            regionOfInterest:(CGRect)roi
                       error:(NSError **)error {
-  std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
-                                                     timestampInMilliseconds:timestampInMilliseconds
-                                                             regionOfInterest:roi
-                                                                        error:error];
-  if (!inputPacketMap.has_value()) {
-    return NO;
-  }
-
-  return [_visionTaskRunner processLiveStreamPacketMap:inputPacketMap.value() error:error];
+  return [_visionTaskRunner processLiveStreamImage:image
+                                  regionOfInterest:roi
+                           timestampInMilliseconds:timestampInMilliseconds
+                                             error:error];
 }

 - (BOOL)classifyAsyncImage:(MPPImage *)image
@@ -286,4 +192,51 @@ static const int kMicroSecondsPerMilliSecond = 1000;
                      error:error];
 }

+#pragma mark - Private
+
+- (void)processLiveStreamResult:(absl::StatusOr<PacketMap>)liveStreamResult {
+  if (![self.imageClassifierLiveStreamDelegate
+          respondsToSelector:@selector
+          (imageClassifier:didFinishClassificationWithResult:timestampInMilliseconds:error:)]) {
+    return;
+  }
+
+  NSError *callbackError = nil;
+  if (![MPPCommonUtils checkCppError:liveStreamResult.status() toError:&callbackError]) {
+    dispatch_async(_callbackQueue, ^{
+      [self.imageClassifierLiveStreamDelegate imageClassifier:self
+                            didFinishClassificationWithResult:nil
+                                      timestampInMilliseconds:Timestamp::Unset().Value()
+                                                        error:callbackError];
+    });
+    return;
+  }
+
+  PacketMap &outputPacketMap = liveStreamResult.value();
+  if (outputPacketMap[kImageOutStreamName.cppString].IsEmpty()) {
+    return;
+  }
+
+  MPPImageClassifierResult *result = ImageClassifierResultWithOutputPacketMap(outputPacketMap);
+
+  NSInteger timeStampInMilliseconds =
+      outputPacketMap[kImageOutStreamName.cppString].Timestamp().Value() /
+      kMicroSecondsPerMilliSecond;
+  dispatch_async(_callbackQueue, ^{
+    [self.imageClassifierLiveStreamDelegate imageClassifier:self
+                          didFinishClassificationWithResult:result
+                                    timestampInMilliseconds:timeStampInMilliseconds
+                                                      error:callbackError];
+  });
+}
+
++ (nullable MPPImageClassifierResult *)imageClassifierResultWithOptionalOutputPacketMap:
+    (std::optional<PacketMap>)outputPacketMap {
+  if (!outputPacketMap.has_value()) {
+    return nil;
+  }
+
+  return ImageClassifierResultWithOutputPacketMap(outputPacketMap.value());
+}
+
 @end