Added flow limiter capability to callback in iOS Image Classifier

This commit is contained in:
Prianka Liz Kariat 2023-03-29 20:52:56 +05:30
parent 3bd8b75bc5
commit b354795d00
6 changed files with 191 additions and 19 deletions

View File

@ -499,6 +499,7 @@ class ImageClassifierTests: XCTestCase {
imageClassifierOptions.runningMode = runningMode
imageClassifierOptions.completion = {(
result: ImageClassifierResult?,
timestampMs: Int,
error: Error?) -> () in
}
@ -620,6 +621,7 @@ class ImageClassifierTests: XCTestCase {
imageClassifierOptions.runningMode = .liveStream
imageClassifierOptions.completion = {(
result: ImageClassifierResult?,
timestampMs: Int,
error: Error?) -> () in
}
@ -687,4 +689,125 @@ class ImageClassifierTests: XCTestCase {
)
}
}
func testClassifyWithOutOfOrderTimestampsAndLiveStreamModeSucceeds() throws {
let imageClassifierOptions =
try XCTUnwrap(
imageClassifierOptionsWithModelPath(
ImageClassifierTests.floatModelPath))
imageClassifierOptions.runningMode = .liveStream
let maxResults = 3
imageClassifierOptions.maxResults = maxResults
let expectation = expectation(
description: "classifyWithOutOfOrderTimestampsAndLiveStream")
expectation.expectedFulfillmentCount = 1;
imageClassifierOptions.completion = {(
result: ImageClassifierResult?,
timestampMs: Int,
error: Error?) -> () in
do {
try self.assertImageClassifierResult(
try XCTUnwrap(result),
hasCategoryCount: maxResults,
andCategories:
ImageClassifierTests
.expectedResultsClassifyBurgerImageWithFloatModel)
}
catch {
// Any errors will be thrown by the wait() method of the expectation.
}
expectation.fulfill()
}
let imageClassifier = try XCTUnwrap(ImageClassifier(options:
imageClassifierOptions))
let mpImage = try XCTUnwrap(
imageWithFileInfo(ImageClassifierTests.burgerImage))
XCTAssertNoThrow(
try imageClassifier.classifyAsync(
image: mpImage,
timestampMs: 100))
XCTAssertThrowsError(
try imageClassifier.classifyAsync(
image: mpImage,
timestampMs: 0)) {(error) in
assertEqualErrorDescriptions(
error,
expectedLocalizedDescription: """
INVALID_ARGUMENT: Input timestamp must be monotonically \
increasing.
""")
}
wait(for:[expectation], timeout: 0.1)
}
func testClassifyWithLiveStreamModeSucceeds() throws {
let imageClassifierOptions =
try XCTUnwrap(
imageClassifierOptionsWithModelPath(
ImageClassifierTests.floatModelPath))
imageClassifierOptions.runningMode = .liveStream
let maxResults = 3
imageClassifierOptions.maxResults = maxResults
let iterationCount = 100;
// Because of flow limiting, we cannot ensure that the callback will be
// invoked `iterationCount` times.
// An normal expectation will fail if expectation.fullfill() is not called
// `expectation.expectedFulfillmentCount` times.
// If `expectation.isInverted = true`, the test will only succeed if
// expectation is not fullfilled for the specified `expectedFulfillmentCount`.
// Since in our case we cannot predict how many times the expectation is
// supposed to be fullfilled setting,
// `expectation.expectedFulfillmentCount` = `iterationCount` and
// `expectation.isInverted = true` ensures that test succeeds if
// expectation is not fullfilled `iterationCount` times.
let expectation = expectation(description: "liveStreamClassify")
expectation.expectedFulfillmentCount = iterationCount;
expectation.isInverted = true;
imageClassifierOptions.completion = {(
result: ImageClassifierResult?,
timestampMs: Int,
error: Error?) -> () in
do {
try self.assertImageClassifierResult(
try XCTUnwrap(result),
hasCategoryCount: maxResults,
andCategories:
ImageClassifierTests
.expectedResultsClassifyBurgerImageWithFloatModel)
}
catch {
// Any errors will be thrown by the wait() method of the expectation.
}
expectation.fulfill()
}
let imageClassifier = try XCTUnwrap(ImageClassifier(options:
imageClassifierOptions))
let mpImage = try XCTUnwrap(
imageWithFileInfo(ImageClassifierTests.burgerImage))
for i in 0..<iterationCount {
XCTAssertNoThrow(
try imageClassifier.classifyAsync(
image: mpImage,
timestampMs: i))
}
wait(for:[expectation], timeout: 0.5)
}
}

View File

@ -445,7 +445,8 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
MPPImageClassifierOptions *options = [self imageClassifierOptionsWithModelName:kFloatModelName];
options.runningMode = runningModesToTest[i];
options.completion = ^(MPPImageClassifierResult *result, NSError *error) {
options.completion =
^(MPPImageClassifierResult *result, NSInteger timestampMs, NSError *error) {
};
[self
@ -554,7 +555,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
MPPImageClassifierOptions *options = [self imageClassifierOptionsWithModelName:kFloatModelName];
options.runningMode = MPPRunningModeLiveStream;
options.completion = ^(MPPImageClassifierResult *result, NSError *error) {
options.completion = ^(MPPImageClassifierResult *result, NSInteger timestampMs, NSError *error) {
};
@ -617,13 +618,19 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
NSInteger maxResults = 3;
options.maxResults = maxResults;
XCTestExpectation *expectation = [[XCTestExpectation alloc]
initWithDescription:@"classifyWithOutOfOrderTimestampsAndLiveStream"];
expectation.expectedFulfillmentCount = 1;
options.runningMode = MPPRunningModeLiveStream;
options.completion = ^(MPPImageClassifierResult *result, NSError *error) {
options.completion = ^(MPPImageClassifierResult *result, NSInteger timestampMs, NSError *error) {
[self assertImageClassifierResult:result
hasExpectedCategoriesCount:maxResults
expectedCategories:
[MPPImageClassifierTests
expectedResultCategoriesForClassifyBurgerImageWithFloatModel]];
[expectation fulfill];
};
MPPImageClassifier *imageClassifier = [self imageClassifierWithOptionsSucceeds:options];
@ -643,6 +650,8 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
@"INVALID_ARGUMENT: Input timestamp must be monotonically increasing."
}];
AssertEqualErrors(error, expectedError);
[self waitForExpectations:@[ expectation ] timeout:0.1];
}
- (void)testClassifyWithLiveStreamModeSucceeds {
@ -651,13 +660,33 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
NSInteger maxResults = 3;
options.maxResults = maxResults;
NSInteger iterationCount = 100;
// Because of flow limiting, we cannot ensure that the callback will be
// invoked `iterationCount` times.
// An normal expectation will fail if expectation.fullfill() is not called
// `expectation.expectedFulfillmentCount` times.
// If `expectation.isInverted = true`, the test will only succeed if
// expectation is not fullfilled for the specified `expectedFulfillmentCount`.
// Since in our case we cannot predict how many times the expectation is
// supposed to be fullfilled setting,
// `expectation.expectedFulfillmentCount` = `iterationCount` and
// `expectation.isInverted = true` ensures that test succeeds if
// expectation is not fullfilled `iterationCount` times.
XCTestExpectation *expectation =
[[XCTestExpectation alloc] initWithDescription:@"classifyWithLiveStream"];
expectation.expectedFulfillmentCount = iterationCount;
expectation.inverted = YES;
options.runningMode = MPPRunningModeLiveStream;
options.completion = ^(MPPImageClassifierResult *result, NSError *error) {
options.completion = ^(MPPImageClassifierResult *result, NSInteger timestampMs, NSError *error) {
[self assertImageClassifierResult:result
hasExpectedCategoriesCount:maxResults
expectedCategories:
[MPPImageClassifierTests
expectedResultCategoriesForClassifyBurgerImageWithFloatModel]];
[expectation fulfill];
};
MPPImageClassifier *imageClassifier = [self imageClassifierWithOptionsSucceeds:options];
@ -667,9 +696,11 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
// `CMSampleBuffer`.
MPPImage *image = [self imageWithFileInfo:kBurgerImage];
for (int i = 0; i < 3; i++) {
for (int i = 0; i < iterationCount; i++) {
XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampMs:i error:nil]);
}
[self waitForExpectations:@[ expectation ] timeout:0.5];
}
@end

View File

@ -27,6 +27,7 @@
namespace {
using ::mediapipe::NormalizedRect;
using ::mediapipe::Packet;
using ::mediapipe::Timestamp;
using ::mediapipe::tasks::core::PacketMap;
using ::mediapipe::tasks::core::PacketsCallback;
} // namespace
@ -84,13 +85,22 @@ static NSString *const kTaskGraphName =
if (options.completion) {
packetsCallback = [=](absl::StatusOr<PacketMap> status_or_packets) {
NSError *callbackError = nil;
MPPImageClassifierResult *result;
if ([MPPCommonUtils checkCppError:status_or_packets.status() toError:&callbackError]) {
result = [MPPImageClassifierResult
imageClassifierResultWithClassificationsPacket:
status_or_packets.value()[kClassificationsStreamName.cppString]];
if (![MPPCommonUtils checkCppError:status_or_packets.status() toError:&callbackError]) {
options.completion(nil, Timestamp::Unset().Value(), callbackError);
return;
}
options.completion(result, callbackError);
PacketMap &outputPacketMap = status_or_packets.value();
if (outputPacketMap[kImageOutStreamName.cppString].IsEmpty()) {
return;
}
MPPImageClassifierResult *result = [MPPImageClassifierResult
imageClassifierResultWithClassificationsPacket:
outputPacketMap[kClassificationsStreamName.cppString]];
options.completion(result, outputPacketMap[kImageOutStreamName.cppString].Timestamp().Value() /
kMicroSecondsPerMilliSecond, callbackError);
};
}

View File

@ -33,7 +33,7 @@ NS_SWIFT_NAME(ImageClassifierOptions)
* be specified when the running mode is set to the live stream mode.
* TODO: Add parameter `MPPImage` in the callback.
*/
@property(nonatomic, copy) void (^completion)(MPPImageClassifierResult *result, NSError *error);
@property(nonatomic, copy) void (^completion)(MPPImageClassifierResult *result, NSInteger timestmapMs, NSError *error);
/**
* The locale to use for display names specified through the TFLite Model Metadata, if any. Defaults

View File

@ -18,6 +18,8 @@
NS_ASSUME_NONNULL_BEGIN
static const int kMicroSecondsPerMilliSecond = 1000;
@interface MPPImageClassifierResult (Helpers)
/**
@ -28,7 +30,7 @@ NS_ASSUME_NONNULL_BEGIN
*
* @return An `MPPImageClassifierResult` object that contains a list of image classifications.
*/
+ (MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket:
+ (nullable MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket:
(const mediapipe::Packet &)packet;
@end

View File

@ -17,8 +17,6 @@
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
static const int kMicroSecondsPerMilliSecond = 1000;
namespace {
using ClassificationResultProto =
::mediapipe::tasks::components::containers::proto::ClassificationResult;
@ -27,9 +25,17 @@ using ::mediapipe::Packet;
@implementation MPPImageClassifierResult (Helpers)
+ (MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket:
+ (nullable MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket:
(const Packet &)packet {
MPPClassificationResult *classificationResult = [MPPClassificationResult
MPPClassificationResult *classificationResult;
MPPImageClassifierResult *imageClassifierResult;
if (!packet.ValidateAsType<ClassificationResultProto>().ok()) {
return nil;
}
classificationResult = [MPPClassificationResult
classificationResultWithProto:packet.Get<ClassificationResultProto>()];
return [[MPPImageClassifierResult alloc]