diff --git a/mediapipe/tasks/ios/test/vision/image_classifier/ImageClassifierTests.swift b/mediapipe/tasks/ios/test/vision/image_classifier/ImageClassifierTests.swift index 3de431078..62bf0d487 100644 --- a/mediapipe/tasks/ios/test/vision/image_classifier/ImageClassifierTests.swift +++ b/mediapipe/tasks/ios/test/vision/image_classifier/ImageClassifierTests.swift @@ -499,6 +499,7 @@ class ImageClassifierTests: XCTestCase { imageClassifierOptions.runningMode = runningMode imageClassifierOptions.completion = {( result: ImageClassifierResult?, + timestampMs: Int, error: Error?) -> () in } @@ -620,6 +621,7 @@ class ImageClassifierTests: XCTestCase { imageClassifierOptions.runningMode = .liveStream imageClassifierOptions.completion = {( result: ImageClassifierResult?, + timestampMs: Int, error: Error?) -> () in } @@ -687,4 +689,125 @@ class ImageClassifierTests: XCTestCase { ) } } + + func testClassifyWithOutOfOrderTimestampsAndLiveStreamModeSucceeds() throws { + let imageClassifierOptions = + try XCTUnwrap( + imageClassifierOptionsWithModelPath( + ImageClassifierTests.floatModelPath)) + + imageClassifierOptions.runningMode = .liveStream + + let maxResults = 3 + imageClassifierOptions.maxResults = maxResults + + let expectation = expectation( + description: "classifyWithOutOfOrderTimestampsAndLiveStream") + expectation.expectedFulfillmentCount = 1; + + imageClassifierOptions.completion = {( + result: ImageClassifierResult?, + timestampMs: Int, + error: Error?) -> () in + do { + try self.assertImageClassifierResult( + try XCTUnwrap(result), + hasCategoryCount: maxResults, + andCategories: + ImageClassifierTests + .expectedResultsClassifyBurgerImageWithFloatModel) + } + catch { + // Any errors will be thrown by the wait() method of the expectation. + } + expectation.fulfill() + } + + let imageClassifier = try XCTUnwrap(ImageClassifier(options: + imageClassifierOptions)) + + let mpImage = try XCTUnwrap( + imageWithFileInfo(ImageClassifierTests.burgerImage)) + + XCTAssertNoThrow( + try imageClassifier.classifyAsync( + image: mpImage, + timestampMs: 100)) + + XCTAssertThrowsError( + try imageClassifier.classifyAsync( + image: mpImage, + timestampMs: 0)) {(error) in + assertEqualErrorDescriptions( + error, + expectedLocalizedDescription: """ + INVALID_ARGUMENT: Input timestamp must be monotonically \ + increasing. + """) + } + + wait(for:[expectation], timeout: 0.1) + } + + func testClassifyWithLiveStreamModeSucceeds() throws { + let imageClassifierOptions = + try XCTUnwrap( + imageClassifierOptionsWithModelPath( + ImageClassifierTests.floatModelPath)) + + imageClassifierOptions.runningMode = .liveStream + + let maxResults = 3 + imageClassifierOptions.maxResults = maxResults + + let iterationCount = 100; + + // Because of flow limiting, we cannot ensure that the callback will be + // invoked `iterationCount` times. + // An normal expectation will fail if expectation.fullfill() is not called + // `expectation.expectedFulfillmentCount` times. + // If `expectation.isInverted = true`, the test will only succeed if + // expectation is not fullfilled for the specified `expectedFulfillmentCount`. + // Since in our case we cannot predict how many times the expectation is + // supposed to be fullfilled setting, + // `expectation.expectedFulfillmentCount` = `iterationCount` and + // `expectation.isInverted = true` ensures that test succeeds if + // expectation is not fullfilled `iterationCount` times. + let expectation = expectation(description: "liveStreamClassify") + expectation.expectedFulfillmentCount = iterationCount; + expectation.isInverted = true; + + imageClassifierOptions.completion = {( + result: ImageClassifierResult?, + timestampMs: Int, + error: Error?) -> () in + do { + try self.assertImageClassifierResult( + try XCTUnwrap(result), + hasCategoryCount: maxResults, + andCategories: + ImageClassifierTests + .expectedResultsClassifyBurgerImageWithFloatModel) + } + catch { + // Any errors will be thrown by the wait() method of the expectation. + } + expectation.fulfill() + } + + let imageClassifier = try XCTUnwrap(ImageClassifier(options: + imageClassifierOptions)) + + let mpImage = try XCTUnwrap( + imageWithFileInfo(ImageClassifierTests.burgerImage)) + + for i in 0.. status_or_packets) { NSError *callbackError = nil; - MPPImageClassifierResult *result; - if ([MPPCommonUtils checkCppError:status_or_packets.status() toError:&callbackError]) { - result = [MPPImageClassifierResult - imageClassifierResultWithClassificationsPacket: - status_or_packets.value()[kClassificationsStreamName.cppString]]; + if (![MPPCommonUtils checkCppError:status_or_packets.status() toError:&callbackError]) { + options.completion(nil, Timestamp::Unset().Value(), callbackError); + return; } - options.completion(result, callbackError); + + PacketMap &outputPacketMap = status_or_packets.value(); + if (outputPacketMap[kImageOutStreamName.cppString].IsEmpty()) { + return; + } + + MPPImageClassifierResult *result = [MPPImageClassifierResult + imageClassifierResultWithClassificationsPacket: + outputPacketMap[kClassificationsStreamName.cppString]]; + + options.completion(result, outputPacketMap[kImageOutStreamName.cppString].Timestamp().Value() / + kMicroSecondsPerMilliSecond, callbackError); }; } diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h index 2e6022041..a79241711 100644 --- a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h +++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h @@ -33,7 +33,7 @@ NS_SWIFT_NAME(ImageClassifierOptions) * be specified when the running mode is set to the live stream mode. * TODO: Add parameter `MPPImage` in the callback. */ -@property(nonatomic, copy) void (^completion)(MPPImageClassifierResult *result, NSError *error); +@property(nonatomic, copy) void (^completion)(MPPImageClassifierResult *result, NSInteger timestmapMs, NSError *error); /** * The locale to use for display names specified through the TFLite Model Metadata, if any. Defaults diff --git a/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.h b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.h index 0375ac2a5..68d939f45 100644 --- a/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.h +++ b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.h @@ -18,6 +18,8 @@ NS_ASSUME_NONNULL_BEGIN +static const int kMicroSecondsPerMilliSecond = 1000; + @interface MPPImageClassifierResult (Helpers) /** @@ -28,7 +30,7 @@ NS_ASSUME_NONNULL_BEGIN * * @return An `MPPImageClassifierResult` object that contains a list of image classifications. */ -+ (MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket: ++ (nullable MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket: (const mediapipe::Packet &)packet; @end diff --git a/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.mm b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.mm index 09e21b278..03fbdd793 100644 --- a/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.mm +++ b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.mm @@ -17,8 +17,6 @@ #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" -static const int kMicroSecondsPerMilliSecond = 1000; - namespace { using ClassificationResultProto = ::mediapipe::tasks::components::containers::proto::ClassificationResult; @@ -27,10 +25,18 @@ using ::mediapipe::Packet; @implementation MPPImageClassifierResult (Helpers) -+ (MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket: ++ (nullable MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket: (const Packet &)packet { - MPPClassificationResult *classificationResult = [MPPClassificationResult - classificationResultWithProto:packet.Get()]; + + MPPClassificationResult *classificationResult; + MPPImageClassifierResult *imageClassifierResult; + + if (!packet.ValidateAsType().ok()) { + return nil; + } + + classificationResult = [MPPClassificationResult + classificationResultWithProto:packet.Get()]; return [[MPPImageClassifierResult alloc] initWithClassificationResult:classificationResult