From ab135190e586b0928ff8fa9d2ca101da20e2cdb1 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 4 May 2023 17:03:40 +0530 Subject: [PATCH] Updated iOS object detector to use delegates instead of callbacks for async calls --- .../object_detector/MPPObjectDetectorTests.m | 94 +++++++++++-------- .../sources/MPPObjectDetector.h | 3 + .../sources/MPPObjectDetector.mm | 43 +++++++-- .../sources/MPPObjectDetectorOptions.h | 61 +++++++++++- .../sources/MPPObjectDetectorOptions.m | 2 +- 5 files changed, 148 insertions(+), 55 deletions(-) diff --git a/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m b/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m index fd9466b7d..9ccfece91 100644 --- a/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m +++ b/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m @@ -25,6 +25,8 @@ static NSDictionary *const kCatsAndDogsRotatedImage = static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; static const float pixelDifferenceTolerance = 10.0f; static const float scoreDifferenceTolerance = 0.02f; +static NSString *const kLiveStreamTestsDictObjectDetectorKey = @"object_detector"; +static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation"; #define AssertEqualErrors(error, expectedError) \ XCTAssertNotNil(error); \ @@ -58,7 +60,10 @@ static const float scoreDifferenceTolerance = 0.02f; XCTAssertEqualWithAccuracy(boundingBox.size.height, expectedBoundingBox.size.height, \ pixelDifferenceTolerance, @"index i = %d", idx); -@interface MPPObjectDetectorTests : XCTestCase +@interface MPPObjectDetectorTests : XCTestCase { + NSDictionary *liveStreamSucceedsTestDict; + NSDictionary *outOfOrderTimestampTestDict; +} @end @implementation MPPObjectDetectorTests @@ -446,31 +451,28 @@ static const float scoreDifferenceTolerance = 0.02f; #pragma mark Running Mode Tests -- 
(void)testCreateObjectDetectorFailsWithResultListenerInNonLiveStreamMode { +- (void)testCreateObjectDetectorFailsWithDelegateInNonLiveStreamMode { MPPRunningMode runningModesToTest[] = {MPPRunningModeImage, MPPRunningModeVideo}; for (int i = 0; i < sizeof(runningModesToTest) / sizeof(runningModesToTest[0]); i++) { MPPObjectDetectorOptions *options = [self objectDetectorOptionsWithModelName:kModelName]; options.runningMode = runningModesToTest[i]; - options.completion = - ^(MPPObjectDetectionResult *result, NSInteger timestampInMilliseconds, NSError *error) { - }; + options.objectDetectorLiveStreamDelegate = self; [self assertCreateObjectDetectorWithOptions:options failsWithExpectedError: - [NSError - errorWithDomain:kExpectedErrorDomain - code:MPPTasksErrorCodeInvalidArgumentError - userInfo:@{ - NSLocalizedDescriptionKey : - @"The vision task is in image or video mode, a " - @"user-defined result callback should not be provided." - }]]; + [NSError errorWithDomain:kExpectedErrorDomain + code:MPPTasksErrorCodeInvalidArgumentError + userInfo:@{ + NSLocalizedDescriptionKey : + @"The vision task is in image or video mode. The " + @"delegate must not be set in the task's options." + }]]; } } -- (void)testCreateObjectDetectorFailsWithMissingResultListenerInLiveStreamMode { +- (void)testCreateObjectDetectorFailsWithMissingDelegateInLiveStreamMode { MPPObjectDetectorOptions *options = [self objectDetectorOptionsWithModelName:kModelName]; options.runningMode = MPPRunningModeLiveStream; @@ -481,8 +483,11 @@ static const float scoreDifferenceTolerance = 0.02f; code:MPPTasksErrorCodeInvalidArgumentError userInfo:@{ NSLocalizedDescriptionKey : - @"The vision task is in live stream mode, a " - @"user-defined result callback must be provided." + @"The vision task is in live stream mode. An " + @"object must be set as the " + @"delegate of the task in its options to ensure " + @"asynchronous delivery of " + @"results." 
}]]; } @@ -563,10 +568,7 @@ static const float scoreDifferenceTolerance = 0.02f; MPPObjectDetectorOptions *options = [self objectDetectorOptionsWithModelName:kModelName]; options.runningMode = MPPRunningModeLiveStream; - options.completion = - ^(MPPObjectDetectionResult *result, NSInteger timestampInMilliseconds, NSError *error) { - - }; + options.objectDetectorLiveStreamDelegate = self; MPPObjectDetector *objectDetector = [self objectDetectorWithOptionsSucceeds:options]; @@ -631,23 +633,17 @@ static const float scoreDifferenceTolerance = 0.02f; options.maxResults = maxResults; options.runningMode = MPPRunningModeLiveStream; + options.objectDetectorLiveStreamDelegate = self; XCTestExpectation *expectation = [[XCTestExpectation alloc] initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"]; expectation.expectedFulfillmentCount = 1; - options.completion = - ^(MPPObjectDetectionResult *result, NSInteger timestampInMilliseconds, NSError *error) { - [self assertObjectDetectionResult:result - isEqualToExpectedResult: - [MPPObjectDetectorTests - expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds: - timestampInMilliseconds] - expectedDetectionsCount:maxResults]; - [expectation fulfill]; - }; - MPPObjectDetector *objectDetector = [self objectDetectorWithOptionsSucceeds:options]; + liveStreamSucceedsTestDict = @{ + kLiveStreamTestsDictObjectDetectorKey : objectDetector, + kLiveStreamTestsDictExpectationKey : expectation + }; MPPImage *image = [self imageWithFileInfo:kCatsAndDogsImage]; @@ -693,19 +689,15 @@ static const float scoreDifferenceTolerance = 0.02f; expectation.expectedFulfillmentCount = iterationCount + 1; expectation.inverted = YES; - options.completion = - ^(MPPObjectDetectionResult *result, NSInteger timestampInMilliseconds, NSError *error) { - [self assertObjectDetectionResult:result - isEqualToExpectedResult: - [MPPObjectDetectorTests - expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds: - 
timestampInMilliseconds] - expectedDetectionsCount:maxResults]; - [expectation fulfill]; - }; + options.objectDetectorLiveStreamDelegate = self; MPPObjectDetector *objectDetector = [self objectDetectorWithOptionsSucceeds:options]; + liveStreamSucceedsTestDict = @{ + kLiveStreamTestsDictObjectDetectorKey : objectDetector, + kLiveStreamTestsDictExpectationKey : expectation + }; + // TODO: Mimic initialization from CMSampleBuffer as live stream mode is most likely to be used // with the iOS camera. AVCaptureVideoDataOutput sample buffer delegates provide frames of type // `CMSampleBuffer`. @@ -718,4 +710,24 @@ static const float scoreDifferenceTolerance = 0.02f; [self waitForExpectations:@[ expectation ] timeout:0.5]; } +#pragma mark MPPObjectDetectorLiveStreamDelegate Methods +- (void)objectDetector:(MPPObjectDetector *)objectDetector + didFinishDetectionWithResult:(MPPObjectDetectionResult *)objectDetectionResult + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + error:(NSError *)error { + NSInteger maxResults = 4; + [self assertObjectDetectionResult:objectDetectionResult + isEqualToExpectedResult: + [MPPObjectDetectorTests + expectedDetectionResultForCatsAndDogsImageWithTimestampInMilliseconds: + timestampInMilliseconds] + expectedDetectionsCount:maxResults]; + + if (objectDetector == outOfOrderTimestampTestDict[kLiveStreamTestsDictObjectDetectorKey]) { + [outOfOrderTimestampTestDict[kLiveStreamTestsDictExpectationKey] fulfill]; + } else if (objectDetector == liveStreamSucceedsTestDict[kLiveStreamTestsDictObjectDetectorKey]) { + [liveStreamSucceedsTestDict[kLiveStreamTestsDictExpectationKey] fulfill]; + } +} + @end diff --git a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.h b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.h index 4443f56d1..249ee0434 100644 --- a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.h +++ b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.h 
@@ -137,6 +137,9 @@ NS_SWIFT_NAME(ObjectDetector) * the provided `MPPImage`. Only use this method when the `MPPObjectDetector` is created with * `MPPRunningModeLiveStream`. Results are provided asynchronously via the `completion` callback * provided in the `MPPObjectDetectorOptions`. + * The object which needs to be continuously notified of the available results of object + * detection must conform to `MPPObjectDetectorLiveStreamDelegate` protocol and implement the + * `objectDetector:didFinishDetectionWithResult:timestampInMilliseconds:error:` delegate method. * * It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent * to the object detector. The input timestamps must be monotonically increasing. diff --git a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.mm b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.mm index f0914cdb1..5dfbfdab8 100644 --- a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.mm +++ b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.mm @@ -37,8 +37,8 @@ static NSString *const kImageOutStreamName = @"image_out"; static NSString *const kImageTag = @"IMAGE"; static NSString *const kNormRectStreamName = @"norm_rect_in"; static NSString *const kNormRectTag = @"NORM_RECT"; - static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorGraph"; +static NSString *const kTaskName = @"objectDetector"; #define InputPacketMap(imagePacket, normalizedRectPacket) \ { \ @@ -51,6 +51,7 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG /** iOS Vision Task Runner */ MPPVisionTaskRunner *_visionTaskRunner; } +@property(nonatomic, weak) id objectDetectorLiveStreamDelegate; @end @implementation MPPObjectDetector @@ -78,11 +79,32 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG PacketsCallback packetsCallback = nullptr; - if (options.completion) { + if 
(options.objectDetectorLiveStreamDelegate) { + _objectDetectorLiveStreamDelegate = options.objectDetectorLiveStreamDelegate; + + // Capturing `self` as weak in order to avoid `self` being kept in memory + // and cause a retain cycle, after self is set to `nil`. + MPPObjectDetector *__weak weakSelf = self; + dispatch_queue_t callbackQueue = + dispatch_queue_create([MPPVisionTaskRunner uniqueDispatchQueueNameWithSuffix:kTaskName], NULL); packetsCallback = [=](absl::StatusOr statusOrPackets) { + if (!weakSelf) { + return; + } + if (![weakSelf.objectDetectorLiveStreamDelegate + respondsToSelector:@selector + (objectDetector:didFinishDetectionWithResult:timestampInMilliseconds:error:)]) { + return; + } + NSError *callbackError = nil; if (![MPPCommonUtils checkCppError:statusOrPackets.status() toError:&callbackError]) { - options.completion(nil, Timestamp::Unset().Value(), callbackError); + dispatch_async(callbackQueue, ^{ + [weakSelf.objectDetectorLiveStreamDelegate objectDetector:weakSelf + didFinishDetectionWithResult:nil + timestampInMilliseconds:Timestamp::Unset().Value() + error:callbackError]; + }); return; } @@ -95,10 +117,15 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG objectDetectionResultWithDetectionsPacket:statusOrPackets.value()[kDetectionsStreamName .cppString]]; - options.completion(result, - outputPacketMap[kImageOutStreamName.cppString].Timestamp().Value() / - kMicroSecondsPerMilliSecond, - callbackError); + NSInteger timeStampInMilliseconds = + outputPacketMap[kImageOutStreamName.cppString].Timestamp().Value() / + kMicroSecondsPerMilliSecond; + dispatch_async(callbackQueue, ^{ + [weakSelf.objectDetectorLiveStreamDelegate objectDetector:weakSelf + didFinishDetectionWithResult:result + timestampInMilliseconds:timeStampInMilliseconds + error:callbackError]; + }); }; } @@ -112,6 +139,7 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG return nil; } } + return self; } @@ -224,5 +252,4 
@@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG return [_visionTaskRunner processLiveStreamPacketMap:inputPacketMap.value() error:error]; } - @end diff --git a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.h b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.h index 79bc9baa6..c60c8acac 100644 --- a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.h +++ b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.h @@ -20,19 +20,70 @@ NS_ASSUME_NONNULL_BEGIN +@class MPPObjectDetector; + +/** + * This protocol defines an interface for the delegates of `MPPObjectDetector` object to receive + * results of performing asynchronous object detection on images + * (i.e., when `runningMode` = `MPPRunningModeLiveStream`). + * + * The delegate of `MPPObjectDetector` must adopt `MPPObjectDetectorLiveStreamDelegate` protocol. + * The methods in this protocol are optional. + */ +NS_SWIFT_NAME(ObjectDetectorLiveStreamDelegate) +@protocol MPPObjectDetectorLiveStreamDelegate + +@optional + +/** + * This method notifies a delegate that the results of asynchronous object detection of + * an image submitted to the `MPPObjectDetector` are available. + * + * This method is called on a private serial dispatch queue created by the `MPPObjectDetector` + * for performing the asynchronous delegate calls. + * + * @param objectDetector The object detector which performed the object detection. + * This is useful to test equality when there are multiple instances of `MPPObjectDetector`. + * @param result The `MPPObjectDetectionResult` object that contains a list of detections, each + * detection has a bounding box that is expressed in the unrotated input frame of reference + * coordinates system, i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions of the + * underlying image data. 
+ * @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input + * image was sent to the object detector. + * @param error An optional error parameter populated when there is an error in performing object + * detection on the input live stream image data. + * + */ +- (void)objectDetector:(MPPObjectDetector *)objectDetector + didFinishDetectionWithResult:(nullable MPPObjectDetectionResult *)result + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + error:(nullable NSError *)error + NS_SWIFT_NAME(objectDetector(_:didFinishDetection:timestampInMilliseconds:error:)); +@end + /** Options for setting up a `MPPObjectDetector`. */ NS_SWIFT_NAME(ObjectDetectorOptions) @interface MPPObjectDetectorOptions : MPPTaskOptions +/** + * Running mode of the object detector task. Defaults to `MPPRunningModeImage`. + * `MPPObjectDetector` can be created with one of the following running modes: + * 1. `MPPRunningModeImage`: The mode for performing object detection on single image inputs. + * 2. `MPPRunningModeVideo`: The mode for performing object detection on the decoded frames of a + * video. + * 3. `MPPRunningModeLiveStream`: The mode for performing object detection on a live stream of + * input data, such as from the camera. + */ @property(nonatomic) MPPRunningMode runningMode; /** - * The user-defined result callback for processing live stream data. The result callback should only - * be specified when the running mode is set to the live stream mode. - * TODO: Add parameter `MPPImage` in the callback. + * An object that conforms to `MPPObjectDetectorLiveStreamDelegate` protocol. This object must + * implement `objectDetector:didFinishDetectionWithResult:timestampInMilliseconds:error:` to receive + * the results of performing asynchronous object detection on images (i.e., when `runningMode` = + * `MPPRunningModeLiveStream`). 
*/ -@property(nonatomic, copy) void (^completion) - (MPPObjectDetectionResult *__nullable result, NSInteger timestampMs, NSError *error); +@property(nonatomic, weak, nullable) id + objectDetectorLiveStreamDelegate; /** * The locale to use for display names specified through the TFLite Model Metadata, if any. Defaults diff --git a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.m b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.m index 73f8ce5b5..b93a6b30b 100644 --- a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.m +++ b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectorOptions.m @@ -33,7 +33,7 @@ objectDetectorOptions.categoryDenylist = self.categoryDenylist; objectDetectorOptions.categoryAllowlist = self.categoryAllowlist; objectDetectorOptions.displayNamesLocale = self.displayNamesLocale; - objectDetectorOptions.completion = self.completion; + objectDetectorOptions.objectDetectorLiveStreamDelegate = self.objectDetectorLiveStreamDelegate; return objectDetectorOptions; }