diff --git a/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m b/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m index 58abb5c70..a2fd68482 100644 --- a/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m +++ b/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m @@ -27,6 +27,8 @@ static NSDictionary *const kMultiObjectsRotatedImage = @{@"name" : @"multi_objects_rotated", @"type" : @"jpg"}; static const int kMobileNetCategoriesCount = 1001; static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; +static NSString *const kLiveStreamTestsDictImageClassifierKey = @"image_classifier"; +static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation"; #define AssertEqualErrors(error, expectedError) \ XCTAssertNotNil(error); \ @@ -54,11 +56,14 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; XCTAssertEqual(imageClassifierResult.classificationResult.classifications.count, 1); \ XCTAssertEqual(imageClassifierResult.classificationResult.classifications[0].headIndex, 0); -@interface MPPImageClassifierTests : XCTestCase +@interface MPPImageClassifierTests : XCTestCase { + NSDictionary *liveStreamSucceedsTestDict; + NSDictionary *outOfOrderTimestampTestDict; +} + @end @implementation MPPImageClassifierTests - #pragma mark Results + (NSArray *)expectedResultCategoriesForClassifyBurgerImageWithFloatModel { @@ -436,43 +441,43 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; #pragma mark Running Mode Tests -- (void)testCreateImageClassifierFailsWithResultListenerInNonLiveStreamMode { +- (void)testCreateImageClassifierFailsWithDelegateInNonLiveStreamMode { MPPRunningMode runningModesToTest[] = {MPPRunningModeImage, MPPRunningModeVideo}; for (int i = 0; i < sizeof(runningModesToTest) / sizeof(runningModesToTest[0]); i++) { MPPImageClassifierOptions *options = [self imageClassifierOptionsWithModelName:kFloatModelName]; options.runningMode = runningModesToTest[i]; - options.completion = ^(MPPImageClassifierResult *result, NSError *error) { - }; + options.imageClassifierLiveStreamDelegate = self; [self assertCreateImageClassifierWithOptions:options failsWithExpectedError: - [NSError - errorWithDomain:kExpectedErrorDomain - code:MPPTasksErrorCodeInvalidArgumentError - userInfo:@{ - NSLocalizedDescriptionKey : - @"The vision task is in image or video mode, a " - @"user-defined result callback should not be provided." - }]]; + [NSError errorWithDomain:kExpectedErrorDomain + code:MPPTasksErrorCodeInvalidArgumentError + userInfo:@{ + NSLocalizedDescriptionKey : + @"The vision task is in image or video mode. The " + @"delegate must not be set in the task's options." + }]]; } } -- (void)testCreateImageClassifierFailsWithMissingResultListenerInLiveStreamMode { +- (void)testCreateImageClassifierFailsWithMissingDelegateInLiveStreamMode { MPPImageClassifierOptions *options = [self imageClassifierOptionsWithModelName:kFloatModelName]; options.runningMode = MPPRunningModeLiveStream; [self assertCreateImageClassifierWithOptions:options failsWithExpectedError: - [NSError errorWithDomain:kExpectedErrorDomain - code:MPPTasksErrorCodeInvalidArgumentError - userInfo:@{ - NSLocalizedDescriptionKey : - @"The vision task is in live stream mode, a " - @"user-defined result callback must be provided." - }]]; + [NSError + errorWithDomain:kExpectedErrorDomain + code:MPPTasksErrorCodeInvalidArgumentError + userInfo:@{ + NSLocalizedDescriptionKey : + @"The vision task is in live stream mode. An object " + @"must be set as the delegate of the task in its " + @"options to ensure asynchronous delivery of results." + }]]; } - (void)testClassifyFailsWithCallingWrongApiInImageMode { @@ -553,9 +558,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; MPPImageClassifierOptions *options = [self imageClassifierOptionsWithModelName:kFloatModelName]; options.runningMode = MPPRunningModeLiveStream; - options.completion = ^(MPPImageClassifierResult *result, NSError *error) { - - }; + options.imageClassifierLiveStreamDelegate = self; MPPImageClassifier *imageClassifier = [self imageClassifierWithOptionsSucceeds:options]; @@ -619,16 +622,20 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; options.maxResults = maxResults; options.runningMode = MPPRunningModeLiveStream; - options.completion = ^(MPPImageClassifierResult *result, NSError *error) { - [self assertImageClassifierResult:result - hasExpectedCategoriesCount:maxResults - expectedCategories: - [MPPImageClassifierTests - expectedResultCategoriesForClassifyBurgerImageWithFloatModel]]; - }; + options.imageClassifierLiveStreamDelegate = self; + + XCTestExpectation *expectation = [[XCTestExpectation alloc] + initWithDescription:@"classifyWithOutOfOrderTimestampsAndLiveStream"]; + + expectation.expectedFulfillmentCount = 1; MPPImageClassifier *imageClassifier = [self imageClassifierWithOptionsSucceeds:options]; + outOfOrderTimestampTestDict = @{ + kLiveStreamTestsDictImageClassifierKey : imageClassifier, + kLiveStreamTestsDictExpectationKey : expectation + }; + MPPImage *image = [self imageWithFileInfo:kBurgerImage]; XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampInMilliseconds:1 error:nil]); @@ -644,6 +651,8 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; @"INVALID_ARGUMENT: Input timestamp must be monotonically increasing." }]; AssertEqualErrors(error, expectedError); + + [self waitForExpectations:@[ expectation ] timeout:1e-2f]; } - (void)testClassifyWithLiveStreamModeSucceeds { @@ -653,24 +662,63 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; options.maxResults = maxResults; options.runningMode = MPPRunningModeLiveStream; - options.completion = ^(MPPImageClassifierResult *result, NSError *error) { - [self assertImageClassifierResult:result - hasExpectedCategoriesCount:maxResults - expectedCategories: - [MPPImageClassifierTests - expectedResultCategoriesForClassifyBurgerImageWithFloatModel]]; - }; + options.imageClassifierLiveStreamDelegate = self; + + NSInteger iterationCount = 100; + + // Because of flow limiting, we cannot ensure that the callback will be + // invoked `iterationCount` times. + // An normal expectation will fail if expectation.fullfill() is not called + // `expectation.expectedFulfillmentCount` times. + // If `expectation.isInverted = true`, the test will only succeed if + // expectation is not fullfilled for the specified `expectedFulfillmentCount`. + // Since in our case we cannot predict how many times the expectation is + // supposed to be fullfilled setting, + // `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and + // `expectation.isInverted = true` ensures that test succeeds if + // expectation is fullfilled <= `iterationCount` times. + XCTestExpectation *expectation = + [[XCTestExpectation alloc] initWithDescription:@"classifyWithLiveStream"]; + + expectation.expectedFulfillmentCount = iterationCount + 1; + expectation.inverted = YES; MPPImageClassifier *imageClassifier = [self imageClassifierWithOptionsSucceeds:options]; + liveStreamSucceedsTestDict = @{ + kLiveStreamTestsDictImageClassifierKey : imageClassifier, + kLiveStreamTestsDictExpectationKey : expectation + }; + // TODO: Mimic initialization from CMSampleBuffer as live stream mode is most likely to be used // with the iOS camera. AVCaptureVideoDataOutput sample buffer delegates provide frames of type // `CMSampleBuffer`. MPPImage *image = [self imageWithFileInfo:kBurgerImage]; - for (int i = 0; i < 3; i++) { + for (int i = 0; i < iterationCount; i++) { XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampInMilliseconds:i error:nil]); } + + [self waitForExpectations:@[ expectation ] timeout:1e-2f]; +} + +- (void)imageClassifier:(MPPImageClassifier *)imageClassifier + didFinishClassificationWithResult:(MPPImageClassifierResult *)imageClassifierResult + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + error:(NSError *)error { + NSInteger maxResults = 3; + [self assertImageClassifierResult:imageClassifierResult + hasExpectedCategoriesCount:maxResults + expectedCategories: + [MPPImageClassifierTests + expectedResultCategoriesForClassifyBurgerImageWithFloatModel]]; + + if (imageClassifier == outOfOrderTimestampTestDict[kLiveStreamTestsDictImageClassifierKey]) { + [outOfOrderTimestampTestDict[kLiveStreamTestsDictExpectationKey] fulfill]; + } else if (imageClassifier == + liveStreamSucceedsTestDict[kLiveStreamTestsDictImageClassifierKey]) { + [liveStreamSucceedsTestDict[kLiveStreamTestsDictExpectationKey] fulfill]; + } } @end diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h index 345687877..024eee0aa 100644 --- a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h +++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h @@ -164,8 +164,11 @@ NS_SWIFT_NAME(ImageClassifier) * Sends live stream image data of type `MPPImage` to perform image classification using the whole * image as region of interest. Rotation will be applied according to the `orientation` property of * the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with - * `MPPRunningModeLiveStream`. Results are provided asynchronously via the `completion` callback - * provided in the `MPPImageClassifierOptions`. + * `MPPRunningModeLiveStream`. + * The object which needs to be continuously notified of the available results of image + * classification must confirm to `MPPImageClassifierLiveStreamDelegate` protocol and implement the + * `imageClassifier:didFinishClassificationWithResult:timestampInMilliseconds:error:` + * delegate method. * * It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent * to the image classifier. The input timestamps must be monotonically increasing. @@ -185,11 +188,14 @@ NS_SWIFT_NAME(ImageClassifier) NS_SWIFT_NAME(classifyAsync(image:timestampInMilliseconds:)); /** - * Sends live stream image data of type `MPPImage` to perform image classification, cropped to the + * Sends live stream image data of type ``MPPImage`` to perform image classification, cropped to the * specified region of interest.. Rotation will be applied according to the `orientation` property * of the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with - * `MPPRunningModeLiveStream`. Results are provided asynchronously via the `completion` callback - * provided in the `MPPImageClassifierOptions`. + * `MPPRunningModeLiveStream`. + * The object which needs to be continuously notified of the available results of image + * classification must confirm to `MPPImageClassifierLiveStreamDelegate` protocol and implement the + * `imageClassifier:didFinishClassificationWithResult:timestampInMilliseconds:error:` delegate + * method. * * It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent * to the image classifier. The input timestamps must be monotonically increasing. diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm index 3e345a5d0..408153c01 100644 --- a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm +++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm @@ -27,6 +27,7 @@ namespace { using ::mediapipe::NormalizedRect; using ::mediapipe::Packet; +using ::mediapipe::Timestamp; using ::mediapipe::tasks::core::PacketMap; using ::mediapipe::tasks::core::PacketsCallback; } // namespace @@ -38,9 +39,9 @@ static NSString *const kImageOutStreamName = @"image_out"; static NSString *const kImageTag = @"IMAGE"; static NSString *const kNormRectStreamName = @"norm_rect_in"; static NSString *const kNormRectTag = @"NORM_RECT"; - static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; +static NSString *const kTaskName = @"imageClassifier"; #define InputPacketMap(imagePacket, normalizedRectPacket) \ { \ @@ -53,6 +54,8 @@ static NSString *const kTaskGraphName = /** iOS Vision Task Runner */ MPPVisionTaskRunner *_visionTaskRunner; } +@property(nonatomic, weak) id + imageClassifierLiveStreamDelegate; @end @implementation MPPImageClassifier @@ -81,16 +84,58 @@ static NSString *const kTaskGraphName = PacketsCallback packetsCallback = nullptr; - if (options.completion) { + if (options.imageClassifierLiveStreamDelegate) { + _imageClassifierLiveStreamDelegate = options.imageClassifierLiveStreamDelegate; + // Capturing `self` as weak in order to avoid `self` being kept in memory + // and cause a retain cycle, after self is set to `nil`. + MPPImageClassifier *__weak weakSelf = self; + + // Create a private serial dispatch queue in which the deleagte method will be called + // asynchronously. This is to ensure that if the client performs a long running operation in + // the delegate method, the queue on which the C++ callbacks is invoked is not blocked and is + // freed up to continue with its operations. + const char *queueName = [MPPVisionTaskRunner uniqueDispatchQueueNameWithSuffix:kTaskName]; + dispatch_queue_t callbackQueue = dispatch_queue_create(queueName, NULL); packetsCallback = [=](absl::StatusOr status_or_packets) { - NSError *callbackError = nil; - MPPImageClassifierResult *result; - if ([MPPCommonUtils checkCppError:status_or_packets.status() toError:&callbackError]) { - result = [MPPImageClassifierResult - imageClassifierResultWithClassificationsPacket: - status_or_packets.value()[kClassificationsStreamName.cppString]]; + if (!weakSelf) { + return; } - options.completion(result, callbackError); + if (![weakSelf.imageClassifierLiveStreamDelegate + respondsToSelector:@selector + (imageClassifier: + didFinishClassificationWithResult:timestampInMilliseconds:error:)]) { + return; + } + + NSError *callbackError = nil; + if (![MPPCommonUtils checkCppError:status_or_packets.status() toError:&callbackError]) { + dispatch_async(callbackQueue, ^{ + [weakSelf.imageClassifierLiveStreamDelegate imageClassifier:weakSelf + didFinishClassificationWithResult:nil + timestampInMilliseconds:Timestamp::Unset().Value() + error:callbackError]; + }); + return; + } + + PacketMap &outputPacketMap = status_or_packets.value(); + if (outputPacketMap[kImageOutStreamName.cppString].IsEmpty()) { + return; + } + + MPPImageClassifierResult *result = + [MPPImageClassifierResult imageClassifierResultWithClassificationsPacket: + outputPacketMap[kClassificationsStreamName.cppString]]; + + NSInteger timeStampInMilliseconds = + outputPacketMap[kImageOutStreamName.cppString].Timestamp().Value() / + kMicroSecondsPerMilliSecond; + dispatch_async(callbackQueue, ^{ + [weakSelf.imageClassifierLiveStreamDelegate imageClassifier:weakSelf + didFinishClassificationWithResult:result + timestampInMilliseconds:timeStampInMilliseconds + error:callbackError]; + }); }; } diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h index 2e6022041..fc76560c2 100644 --- a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h +++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h @@ -20,20 +20,68 @@ NS_ASSUME_NONNULL_BEGIN +@class MPPImageClassifier; + +/** + * This protocol defines an interface for the delegates of `MPPImageClassifier` object to receive + * results of asynchronous classification of images + * (i.e, when `runningMode = MPPRunningModeLiveStream`). + * + * The delegate of `MPPImageClassifier` must adopt `MPPImageClassifierLiveStreamDelegate` protocol. + * The methods in this protocol are optional. + */ +NS_SWIFT_NAME(ImageClassifierLiveStreamDelegate) +@protocol MPPImageClassifierLiveStreamDelegate + +@optional +/** + * This method notifies a delegate that the results of asynchronous classification of + * an image submitted to the `MPPImageClassifier` is available. + * + * This method is called on a private serial queue created by the `MPPImageClassifier` + * for performing the asynchronous delegates calls. + * + * @param imageClassifier The image classifier which performed the classification. + * This is useful to test equality when there are multiple instances of `MPPImageClassifier`. + * @param result An `MPPImageClassifierResult` object that contains a list of image classifications. + * @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input + * image was sent to the image classifier. + * @param error An optional error parameter populated when there is an error in performing image + * classification on the input live stream image data. + * + */ +- (void)imageClassifier:(MPPImageClassifier *)imageClassifier + didFinishClassificationWithResult:(nullable MPPImageClassifierResult *)result + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + error:(nullable NSError *)error + NS_SWIFT_NAME(imageClassifier(_:didFinishClassification:timestampInMilliseconds:error:)); +@end + /** * Options for setting up a `MPPImageClassifier`. */ NS_SWIFT_NAME(ImageClassifierOptions) @interface MPPImageClassifierOptions : MPPTaskOptions +/** + * Running mode of the image classifier task. Defaults to `MPPRunningModeImage`. + * `MPPImageClassifier` can be created with one of the following running modes: + * 1. `MPPRunningModeImage`: The mode for performing classification on single image inputs. + * 2. `MPPRunningModeVideo`: The mode for performing classification on the decoded frames of a + * video. + * 3. `MPPRunningModeLiveStream`: The mode for performing classification on a live stream of input + * data, such as from the camera. + */ @property(nonatomic) MPPRunningMode runningMode; /** - * The user-defined result callback for processing live stream data. The result callback should only - * be specified when the running mode is set to the live stream mode. - * TODO: Add parameter `MPPImage` in the callback. + * An object that confirms to `MPPImageClassifierLiveStreamDelegate` protocol. This object must + * implement `objectDetector:didFinishDetectionWithResult:timestampInMilliseconds:error:` to receive + * the results of asynchronous classification on images (i.e, when `runningMode = + * MPPRunningModeLiveStream`). */ -@property(nonatomic, copy) void (^completion)(MPPImageClassifierResult *result, NSError *error); +@property(nonatomic, weak, nullable) id + imageClassifierLiveStreamDelegate; /** * The locale to use for display names specified through the TFLite Model Metadata, if any. Defaults diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.m b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.m index e109dcc3b..8d3815ff3 100644 --- a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.m +++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.m @@ -33,7 +33,7 @@ imageClassifierOptions.categoryDenylist = self.categoryDenylist; imageClassifierOptions.categoryAllowlist = self.categoryAllowlist; imageClassifierOptions.displayNamesLocale = self.displayNamesLocale; - imageClassifierOptions.completion = self.completion; + imageClassifierOptions.imageClassifierLiveStreamDelegate = self.imageClassifierLiveStreamDelegate; return imageClassifierOptions; }