Updated MPPImageClassifier to use delegates instead of completion blocks for callback.

This commit is contained in:
Prianka Liz Kariat 2023-05-04 16:43:18 +05:30
parent 1323a5271c
commit ab4b07646c
5 changed files with 205 additions and 58 deletions

View File

@ -27,6 +27,8 @@ static NSDictionary *const kMultiObjectsRotatedImage =
@{@"name" : @"multi_objects_rotated", @"type" : @"jpg"}; @{@"name" : @"multi_objects_rotated", @"type" : @"jpg"};
static const int kMobileNetCategoriesCount = 1001; static const int kMobileNetCategoriesCount = 1001;
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
static NSString *const kLiveStreamTestsDictImageClassifierKey = @"image_classifier";
static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
#define AssertEqualErrors(error, expectedError) \ #define AssertEqualErrors(error, expectedError) \
XCTAssertNotNil(error); \ XCTAssertNotNil(error); \
@ -54,11 +56,14 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
XCTAssertEqual(imageClassifierResult.classificationResult.classifications.count, 1); \ XCTAssertEqual(imageClassifierResult.classificationResult.classifications.count, 1); \
XCTAssertEqual(imageClassifierResult.classificationResult.classifications[0].headIndex, 0); XCTAssertEqual(imageClassifierResult.classificationResult.classifications[0].headIndex, 0);
@interface MPPImageClassifierTests : XCTestCase @interface MPPImageClassifierTests : XCTestCase <MPPImageClassifierLiveStreamDelegate> {
NSDictionary *liveStreamSucceedsTestDict;
NSDictionary *outOfOrderTimestampTestDict;
}
@end @end
@implementation MPPImageClassifierTests @implementation MPPImageClassifierTests
#pragma mark Results #pragma mark Results
+ (NSArray<MPPCategory *> *)expectedResultCategoriesForClassifyBurgerImageWithFloatModel { + (NSArray<MPPCategory *> *)expectedResultCategoriesForClassifyBurgerImageWithFloatModel {
@ -436,42 +441,42 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
#pragma mark Running Mode Tests #pragma mark Running Mode Tests
- (void)testCreateImageClassifierFailsWithResultListenerInNonLiveStreamMode { - (void)testCreateImageClassifierFailsWithDelegateInNonLiveStreamMode {
MPPRunningMode runningModesToTest[] = {MPPRunningModeImage, MPPRunningModeVideo}; MPPRunningMode runningModesToTest[] = {MPPRunningModeImage, MPPRunningModeVideo};
for (int i = 0; i < sizeof(runningModesToTest) / sizeof(runningModesToTest[0]); i++) { for (int i = 0; i < sizeof(runningModesToTest) / sizeof(runningModesToTest[0]); i++) {
MPPImageClassifierOptions *options = [self imageClassifierOptionsWithModelName:kFloatModelName]; MPPImageClassifierOptions *options = [self imageClassifierOptionsWithModelName:kFloatModelName];
options.runningMode = runningModesToTest[i]; options.runningMode = runningModesToTest[i];
options.completion = ^(MPPImageClassifierResult *result, NSError *error) { options.imageClassifierLiveStreamDelegate = self;
};
[self [self
assertCreateImageClassifierWithOptions:options assertCreateImageClassifierWithOptions:options
failsWithExpectedError: failsWithExpectedError:
[NSError [NSError errorWithDomain:kExpectedErrorDomain
errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{ userInfo:@{
NSLocalizedDescriptionKey : NSLocalizedDescriptionKey :
@"The vision task is in image or video mode, a " @"The vision task is in image or video mode. The "
@"user-defined result callback should not be provided." @"delegate must not be set in the task's options."
}]]; }]];
} }
} }
- (void)testCreateImageClassifierFailsWithMissingResultListenerInLiveStreamMode { - (void)testCreateImageClassifierFailsWithMissingDelegateInLiveStreamMode {
MPPImageClassifierOptions *options = [self imageClassifierOptionsWithModelName:kFloatModelName]; MPPImageClassifierOptions *options = [self imageClassifierOptionsWithModelName:kFloatModelName];
options.runningMode = MPPRunningModeLiveStream; options.runningMode = MPPRunningModeLiveStream;
[self assertCreateImageClassifierWithOptions:options [self assertCreateImageClassifierWithOptions:options
failsWithExpectedError: failsWithExpectedError:
[NSError errorWithDomain:kExpectedErrorDomain [NSError
errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{ userInfo:@{
NSLocalizedDescriptionKey : NSLocalizedDescriptionKey :
@"The vision task is in live stream mode, a " @"The vision task is in live stream mode. An object "
@"user-defined result callback must be provided." @"must be set as the delegate of the task in its "
@"options to ensure asynchronous delivery of results."
}]]; }]];
} }
@ -553,9 +558,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
MPPImageClassifierOptions *options = [self imageClassifierOptionsWithModelName:kFloatModelName]; MPPImageClassifierOptions *options = [self imageClassifierOptionsWithModelName:kFloatModelName];
options.runningMode = MPPRunningModeLiveStream; options.runningMode = MPPRunningModeLiveStream;
options.completion = ^(MPPImageClassifierResult *result, NSError *error) { options.imageClassifierLiveStreamDelegate = self;
};
MPPImageClassifier *imageClassifier = [self imageClassifierWithOptionsSucceeds:options]; MPPImageClassifier *imageClassifier = [self imageClassifierWithOptionsSucceeds:options];
@ -619,16 +622,20 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
options.maxResults = maxResults; options.maxResults = maxResults;
options.runningMode = MPPRunningModeLiveStream; options.runningMode = MPPRunningModeLiveStream;
options.completion = ^(MPPImageClassifierResult *result, NSError *error) { options.imageClassifierLiveStreamDelegate = self;
[self assertImageClassifierResult:result
hasExpectedCategoriesCount:maxResults XCTestExpectation *expectation = [[XCTestExpectation alloc]
expectedCategories: initWithDescription:@"classifyWithOutOfOrderTimestampsAndLiveStream"];
[MPPImageClassifierTests
expectedResultCategoriesForClassifyBurgerImageWithFloatModel]]; expectation.expectedFulfillmentCount = 1;
};
MPPImageClassifier *imageClassifier = [self imageClassifierWithOptionsSucceeds:options]; MPPImageClassifier *imageClassifier = [self imageClassifierWithOptionsSucceeds:options];
outOfOrderTimestampTestDict = @{
kLiveStreamTestsDictImageClassifierKey : imageClassifier,
kLiveStreamTestsDictExpectationKey : expectation
};
MPPImage *image = [self imageWithFileInfo:kBurgerImage]; MPPImage *image = [self imageWithFileInfo:kBurgerImage];
XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampInMilliseconds:1 error:nil]); XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampInMilliseconds:1 error:nil]);
@ -644,6 +651,8 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
@"INVALID_ARGUMENT: Input timestamp must be monotonically increasing." @"INVALID_ARGUMENT: Input timestamp must be monotonically increasing."
}]; }];
AssertEqualErrors(error, expectedError); AssertEqualErrors(error, expectedError);
[self waitForExpectations:@[ expectation ] timeout:1e-2f];
} }
- (void)testClassifyWithLiveStreamModeSucceeds { - (void)testClassifyWithLiveStreamModeSucceeds {
@ -653,24 +662,63 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
options.maxResults = maxResults; options.maxResults = maxResults;
options.runningMode = MPPRunningModeLiveStream; options.runningMode = MPPRunningModeLiveStream;
options.completion = ^(MPPImageClassifierResult *result, NSError *error) { options.imageClassifierLiveStreamDelegate = self;
[self assertImageClassifierResult:result
hasExpectedCategoriesCount:maxResults NSInteger iterationCount = 100;
expectedCategories:
[MPPImageClassifierTests // Because of flow limiting, we cannot ensure that the callback will be
expectedResultCategoriesForClassifyBurgerImageWithFloatModel]]; // invoked `iterationCount` times.
}; // An normal expectation will fail if expectation.fullfill() is not called
// `expectation.expectedFulfillmentCount` times.
// If `expectation.isInverted = true`, the test will only succeed if
// expectation is not fullfilled for the specified `expectedFulfillmentCount`.
// Since in our case we cannot predict how many times the expectation is
// supposed to be fullfilled setting,
// `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
// `expectation.isInverted = true` ensures that test succeeds if
// expectation is fullfilled <= `iterationCount` times.
XCTestExpectation *expectation =
[[XCTestExpectation alloc] initWithDescription:@"classifyWithLiveStream"];
expectation.expectedFulfillmentCount = iterationCount + 1;
expectation.inverted = YES;
MPPImageClassifier *imageClassifier = [self imageClassifierWithOptionsSucceeds:options]; MPPImageClassifier *imageClassifier = [self imageClassifierWithOptionsSucceeds:options];
liveStreamSucceedsTestDict = @{
kLiveStreamTestsDictImageClassifierKey : imageClassifier,
kLiveStreamTestsDictExpectationKey : expectation
};
// TODO: Mimic initialization from CMSampleBuffer as live stream mode is most likely to be used // TODO: Mimic initialization from CMSampleBuffer as live stream mode is most likely to be used
// with the iOS camera. AVCaptureVideoDataOutput sample buffer delegates provide frames of type // with the iOS camera. AVCaptureVideoDataOutput sample buffer delegates provide frames of type
// `CMSampleBuffer`. // `CMSampleBuffer`.
MPPImage *image = [self imageWithFileInfo:kBurgerImage]; MPPImage *image = [self imageWithFileInfo:kBurgerImage];
for (int i = 0; i < 3; i++) { for (int i = 0; i < iterationCount; i++) {
XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampInMilliseconds:i error:nil]); XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampInMilliseconds:i error:nil]);
} }
[self waitForExpectations:@[ expectation ] timeout:1e-2f];
}
- (void)imageClassifier:(MPPImageClassifier *)imageClassifier
didFinishClassificationWithResult:(MPPImageClassifierResult *)imageClassifierResult
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError *)error {
NSInteger maxResults = 3;
[self assertImageClassifierResult:imageClassifierResult
hasExpectedCategoriesCount:maxResults
expectedCategories:
[MPPImageClassifierTests
expectedResultCategoriesForClassifyBurgerImageWithFloatModel]];
if (imageClassifier == outOfOrderTimestampTestDict[kLiveStreamTestsDictImageClassifierKey]) {
[outOfOrderTimestampTestDict[kLiveStreamTestsDictExpectationKey] fulfill];
} else if (imageClassifier ==
liveStreamSucceedsTestDict[kLiveStreamTestsDictImageClassifierKey]) {
[liveStreamSucceedsTestDict[kLiveStreamTestsDictExpectationKey] fulfill];
}
} }
@end @end

View File

@ -164,8 +164,11 @@ NS_SWIFT_NAME(ImageClassifier)
* Sends live stream image data of type `MPPImage` to perform image classification using the whole * Sends live stream image data of type `MPPImage` to perform image classification using the whole
* image as region of interest. Rotation will be applied according to the `orientation` property of * image as region of interest. Rotation will be applied according to the `orientation` property of
* the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with * the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with
* `MPPRunningModeLiveStream`. Results are provided asynchronously via the `completion` callback * `MPPRunningModeLiveStream`.
* provided in the `MPPImageClassifierOptions`. * The object which needs to be continuously notified of the available results of image
* classification must confirm to `MPPImageClassifierLiveStreamDelegate` protocol and implement the
* `imageClassifier:didFinishClassificationWithResult:timestampInMilliseconds:error:`
* delegate method.
* *
* It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent * It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent
* to the image classifier. The input timestamps must be monotonically increasing. * to the image classifier. The input timestamps must be monotonically increasing.
@ -185,11 +188,14 @@ NS_SWIFT_NAME(ImageClassifier)
NS_SWIFT_NAME(classifyAsync(image:timestampInMilliseconds:)); NS_SWIFT_NAME(classifyAsync(image:timestampInMilliseconds:));
/** /**
* Sends live stream image data of type `MPPImage` to perform image classification, cropped to the * Sends live stream image data of type ``MPPImage`` to perform image classification, cropped to the
* specified region of interest.. Rotation will be applied according to the `orientation` property * specified region of interest.. Rotation will be applied according to the `orientation` property
* of the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with * of the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with
* `MPPRunningModeLiveStream`. Results are provided asynchronously via the `completion` callback * `MPPRunningModeLiveStream`.
* provided in the `MPPImageClassifierOptions`. * The object which needs to be continuously notified of the available results of image
* classification must confirm to `MPPImageClassifierLiveStreamDelegate` protocol and implement the
* `imageClassifier:didFinishClassificationWithResult:timestampInMilliseconds:error:` delegate
* method.
* *
* It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent * It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent
* to the image classifier. The input timestamps must be monotonically increasing. * to the image classifier. The input timestamps must be monotonically increasing.

View File

@ -27,6 +27,7 @@
namespace { namespace {
using ::mediapipe::NormalizedRect; using ::mediapipe::NormalizedRect;
using ::mediapipe::Packet; using ::mediapipe::Packet;
using ::mediapipe::Timestamp;
using ::mediapipe::tasks::core::PacketMap; using ::mediapipe::tasks::core::PacketMap;
using ::mediapipe::tasks::core::PacketsCallback; using ::mediapipe::tasks::core::PacketsCallback;
} // namespace } // namespace
@ -38,9 +39,9 @@ static NSString *const kImageOutStreamName = @"image_out";
static NSString *const kImageTag = @"IMAGE"; static NSString *const kImageTag = @"IMAGE";
static NSString *const kNormRectStreamName = @"norm_rect_in"; static NSString *const kNormRectStreamName = @"norm_rect_in";
static NSString *const kNormRectTag = @"NORM_RECT"; static NSString *const kNormRectTag = @"NORM_RECT";
static NSString *const kTaskGraphName = static NSString *const kTaskGraphName =
@"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; @"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
static NSString *const kTaskName = @"imageClassifier";
#define InputPacketMap(imagePacket, normalizedRectPacket) \ #define InputPacketMap(imagePacket, normalizedRectPacket) \
{ \ { \
@ -53,6 +54,8 @@ static NSString *const kTaskGraphName =
/** iOS Vision Task Runner */ /** iOS Vision Task Runner */
MPPVisionTaskRunner *_visionTaskRunner; MPPVisionTaskRunner *_visionTaskRunner;
} }
@property(nonatomic, weak) id<MPPImageClassifierLiveStreamDelegate>
imageClassifierLiveStreamDelegate;
@end @end
@implementation MPPImageClassifier @implementation MPPImageClassifier
@ -81,16 +84,58 @@ static NSString *const kTaskGraphName =
PacketsCallback packetsCallback = nullptr; PacketsCallback packetsCallback = nullptr;
if (options.completion) { if (options.imageClassifierLiveStreamDelegate) {
_imageClassifierLiveStreamDelegate = options.imageClassifierLiveStreamDelegate;
// Capturing `self` as weak in order to avoid `self` being kept in memory
// and cause a retain cycle, after self is set to `nil`.
MPPImageClassifier *__weak weakSelf = self;
// Create a private serial dispatch queue in which the deleagte method will be called
// asynchronously. This is to ensure that if the client performs a long running operation in
// the delegate method, the queue on which the C++ callbacks is invoked is not blocked and is
// freed up to continue with its operations.
const char *queueName = [MPPVisionTaskRunner uniqueDispatchQueueNameWithSuffix:kTaskName];
dispatch_queue_t callbackQueue = dispatch_queue_create(queueName, NULL);
packetsCallback = [=](absl::StatusOr<PacketMap> status_or_packets) { packetsCallback = [=](absl::StatusOr<PacketMap> status_or_packets) {
NSError *callbackError = nil; if (!weakSelf) {
MPPImageClassifierResult *result; return;
if ([MPPCommonUtils checkCppError:status_or_packets.status() toError:&callbackError]) {
result = [MPPImageClassifierResult
imageClassifierResultWithClassificationsPacket:
status_or_packets.value()[kClassificationsStreamName.cppString]];
} }
options.completion(result, callbackError); if (![weakSelf.imageClassifierLiveStreamDelegate
respondsToSelector:@selector
(imageClassifier:
didFinishClassificationWithResult:timestampInMilliseconds:error:)]) {
return;
}
NSError *callbackError = nil;
if (![MPPCommonUtils checkCppError:status_or_packets.status() toError:&callbackError]) {
dispatch_async(callbackQueue, ^{
[weakSelf.imageClassifierLiveStreamDelegate imageClassifier:weakSelf
didFinishClassificationWithResult:nil
timestampInMilliseconds:Timestamp::Unset().Value()
error:callbackError];
});
return;
}
PacketMap &outputPacketMap = status_or_packets.value();
if (outputPacketMap[kImageOutStreamName.cppString].IsEmpty()) {
return;
}
MPPImageClassifierResult *result =
[MPPImageClassifierResult imageClassifierResultWithClassificationsPacket:
outputPacketMap[kClassificationsStreamName.cppString]];
NSInteger timeStampInMilliseconds =
outputPacketMap[kImageOutStreamName.cppString].Timestamp().Value() /
kMicroSecondsPerMilliSecond;
dispatch_async(callbackQueue, ^{
[weakSelf.imageClassifierLiveStreamDelegate imageClassifier:weakSelf
didFinishClassificationWithResult:result
timestampInMilliseconds:timeStampInMilliseconds
error:callbackError];
});
}; };
} }

View File

@ -20,20 +20,68 @@
NS_ASSUME_NONNULL_BEGIN NS_ASSUME_NONNULL_BEGIN
@class MPPImageClassifier;
/**
* This protocol defines an interface for the delegates of `MPPImageClassifier` object to receive
* results of asynchronous classification of images
* (i.e, when `runningMode = MPPRunningModeLiveStream`).
*
* The delegate of `MPPImageClassifier` must adopt `MPPImageClassifierLiveStreamDelegate` protocol.
* The methods in this protocol are optional.
*/
NS_SWIFT_NAME(ImageClassifierLiveStreamDelegate)
@protocol MPPImageClassifierLiveStreamDelegate <NSObject>
@optional
/**
* This method notifies a delegate that the results of asynchronous classification of
* an image submitted to the `MPPImageClassifier` is available.
*
* This method is called on a private serial queue created by the `MPPImageClassifier`
* for performing the asynchronous delegates calls.
*
* @param imageClassifier The image classifier which performed the classification.
* This is useful to test equality when there are multiple instances of `MPPImageClassifier`.
* @param result An `MPPImageClassifierResult` object that contains a list of image classifications.
* @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input
* image was sent to the image classifier.
* @param error An optional error parameter populated when there is an error in performing image
* classification on the input live stream image data.
*
*/
- (void)imageClassifier:(MPPImageClassifier *)imageClassifier
didFinishClassificationWithResult:(nullable MPPImageClassifierResult *)result
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(nullable NSError *)error
NS_SWIFT_NAME(imageClassifier(_:didFinishClassification:timestampInMilliseconds:error:));
@end
/** /**
* Options for setting up a `MPPImageClassifier`. * Options for setting up a `MPPImageClassifier`.
*/ */
NS_SWIFT_NAME(ImageClassifierOptions) NS_SWIFT_NAME(ImageClassifierOptions)
@interface MPPImageClassifierOptions : MPPTaskOptions <NSCopying> @interface MPPImageClassifierOptions : MPPTaskOptions <NSCopying>
/**
* Running mode of the image classifier task. Defaults to `MPPRunningModeImage`.
* `MPPImageClassifier` can be created with one of the following running modes:
* 1. `MPPRunningModeImage`: The mode for performing classification on single image inputs.
* 2. `MPPRunningModeVideo`: The mode for performing classification on the decoded frames of a
* video.
* 3. `MPPRunningModeLiveStream`: The mode for performing classification on a live stream of input
* data, such as from the camera.
*/
@property(nonatomic) MPPRunningMode runningMode; @property(nonatomic) MPPRunningMode runningMode;
/** /**
* The user-defined result callback for processing live stream data. The result callback should only * An object that confirms to `MPPImageClassifierLiveStreamDelegate` protocol. This object must
* be specified when the running mode is set to the live stream mode. * implement `objectDetector:didFinishDetectionWithResult:timestampInMilliseconds:error:` to receive
* TODO: Add parameter `MPPImage` in the callback. * the results of asynchronous classification on images (i.e, when `runningMode =
* MPPRunningModeLiveStream`).
*/ */
@property(nonatomic, copy) void (^completion)(MPPImageClassifierResult *result, NSError *error); @property(nonatomic, weak, nullable) id<MPPImageClassifierLiveStreamDelegate>
imageClassifierLiveStreamDelegate;
/** /**
* The locale to use for display names specified through the TFLite Model Metadata, if any. Defaults * The locale to use for display names specified through the TFLite Model Metadata, if any. Defaults

View File

@ -33,7 +33,7 @@
imageClassifierOptions.categoryDenylist = self.categoryDenylist; imageClassifierOptions.categoryDenylist = self.categoryDenylist;
imageClassifierOptions.categoryAllowlist = self.categoryAllowlist; imageClassifierOptions.categoryAllowlist = self.categoryAllowlist;
imageClassifierOptions.displayNamesLocale = self.displayNamesLocale; imageClassifierOptions.displayNamesLocale = self.displayNamesLocale;
imageClassifierOptions.completion = self.completion; imageClassifierOptions.imageClassifierLiveStreamDelegate = self.imageClassifierLiveStreamDelegate;
return imageClassifierOptions; return imageClassifierOptions;
} }