From 0d4c365b78640c50322853912369cd57e8169ea2 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Tue, 28 Mar 2023 22:55:06 +0530 Subject: [PATCH] Added iOS Image Classifier Swift Tests --- .../ios/test/vision/image_classifier/BUILD | 31 + .../ImageClassifierTests.swift | 779 ++++++++++++++++++ .../vision/utils/sources/MPPImage+TestUtils.h | 4 +- 3 files changed, 812 insertions(+), 2 deletions(-) create mode 100644 mediapipe/tasks/ios/test/vision/image_classifier/ImageClassifierTests.swift diff --git a/mediapipe/tasks/ios/test/vision/image_classifier/BUILD b/mediapipe/tasks/ios/test/vision/image_classifier/BUILD index c274e6e2e..d4d3b37b9 100644 --- a/mediapipe/tasks/ios/test/vision/image_classifier/BUILD +++ b/mediapipe/tasks/ios/test/vision/image_classifier/BUILD @@ -1,4 +1,8 @@ load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test") +load( + "@build_bazel_rules_swift//swift:swift.bzl", + "swift_library", +) load( "//mediapipe/tasks:ios/ios.bzl", "MPP_TASK_MINIMUM_OS_VERSION", @@ -53,3 +57,30 @@ ios_unit_test( ":MPPImageClassifierObjcTestLibrary", ], ) + +swift_library( + name = "MPPImageClassifierSwiftTestLibrary", + testonly = 1, + srcs = ["ImageClassifierTests.swift"], + data = [ + "//mediapipe/tasks/testdata/vision:test_images", + "//mediapipe/tasks/testdata/vision:test_models", + ], + tags = TFL_DEFAULT_TAGS, + deps = [ + "//mediapipe/tasks/ios/common:MPPCommon", + "//mediapipe/tasks/ios/test/vision/utils:MPPImageTestUtils", + "//mediapipe/tasks/ios/vision/image_classifier:MPPImageClassifier", + ], +) + +ios_unit_test( + name = "MPPImageClassifierSwiftTest", + minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS, + deps = [ + ":MPPImageClassifierSwiftTestLibrary", + ], +) + diff --git a/mediapipe/tasks/ios/test/vision/image_classifier/ImageClassifierTests.swift b/mediapipe/tasks/ios/test/vision/image_classifier/ImageClassifierTests.swift new file mode 100644 index 000000000..8c34a1ab2 --- /dev/null +++ b/mediapipe/tasks/ios/test/vision/image_classifier/ImageClassifierTests.swift @@ -0,0 +1,779 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import MPPCommon +import MPPImageTestUtils +import XCTest + +@testable import MPPImageClassifier + +typealias FileInfo = (name: String, type: String) + +class ImageClassifierTests: XCTestCase { + + static let bundle = Bundle(for: ImageClassifierTests.self) + + static let floatModelPath = bundle.path( + forResource: "mobilenet_v2_1.0_224", + ofType: "tflite") + + static let burgerImage = FileInfo(name: "burger", type: "jpg") + static let burgerRotatedImage = FileInfo(name: "burger_rotated", type: "jpg") + static let multiObjectsImage = FileInfo(name: "multi_objects", type: "jpg") + static let multiObjectsRotatedImage = FileInfo(name: "multi_objects_rotated", type: "jpg") + + static let mobileNetCategoriesCount: Int = 1001; + + static let expectedResultsClassifyBurgerImageWithFloatModel = [ + ResultCategory( + index: 934, + score: 0.786005, + categoryName: "cheeseburger", + displayName: nil), + ResultCategory( + index: 932, + score: 0.023508, + categoryName: "bagel", + displayName: nil), + ResultCategory( + index: 925, + score: 0.021172, + categoryName: "guacamole", + displayName: nil), + ] + + func assertEqualErrorDescriptions( + _ error: Error, expectedLocalizedDescription: String + ) { + XCTAssertEqual( + error.localizedDescription, + expectedLocalizedDescription) + } + + func assertCategoriesAreEqual( + category: ResultCategory, + expectedCategory: ResultCategory, + indexInCategoryList: Int + ) { + XCTAssertEqual( + category.index, + expectedCategory.index, + String( + format: """ + category[%d].index and expectedCategory[%d].index are not equal. + """, indexInCategoryList)) + XCTAssertEqual( + category.score, + expectedCategory.score, + accuracy: 1e-3, + String( + format: """ + category[%d].score and expectedCategory[%d].score are not equal. + """, indexInCategoryList)) + XCTAssertEqual( + category.categoryName, + expectedCategory.categoryName, + String( + format: """ + category[%d].categoryName and expectedCategory[%d].categoryName are \ + not equal. + """, indexInCategoryList)) + XCTAssertEqual( + category.displayName, + expectedCategory.displayName, + String( + format: """ + category[%d].displayName and expectedCategory[%d].displayName are \ + not equal. + """, indexInCategoryList)) + } + + func assertEqualCategoryArrays( + categoryArray: [ResultCategory], + expectedCategoryArray: [ResultCategory] + ) { + XCTAssertEqual( + categoryArray.count, + expectedCategoryArray.count) + + for (index, (category, expectedCategory)) in zip(categoryArray, expectedCategoryArray) + .enumerated() + { + assertCategoriesAreEqual( + category: category, + expectedCategory: expectedCategory, + indexInCategoryList: index) + } + } + + func assertImageClassifierResultHasOneHead( + _ imageClassifierResult: ImageClassifierResult + ) { + XCTAssertEqual(imageClassifierResult.classificationResult.classifications.count, 1) + XCTAssertEqual(imageClassifierResult.classificationResult.classifications[0].headIndex, 0) + } + + func imageClassifierOptionsWithModelPath( + _ modelPath: String? + ) throws -> ImageClassifierOptions { + let modelPath = try XCTUnwrap(modelPath) + + let imageClassifierOptions = ImageClassifierOptions() + imageClassifierOptions.baseOptions.modelAssetPath = modelPath + + return imageClassifierOptions + } + + func assertCreateImageClassifierThrowsError( + imageClassifierOptions: ImageClassifierOptions, + expectedErrorDescription: String + ) { + do { + let imageClassifier = try ImageClassifier(options: imageClassifierOptions) + XCTAssertNil(imageClassifier) + } catch { + assertEqualErrorDescriptions( + error, + expectedLocalizedDescription: expectedErrorDescription) + } + } + + func assertImageClassifierResult( + _ imageClassifierResult: ImageClassifierResult, + hasCategoryCount expectedCategoryCount: Int, + andCategories expectedCategories: [ResultCategory] + ) throws { + assertImageClassifierResultHasOneHead(imageClassifierResult) + let categories = imageClassifierResult.classificationResult.classifications[0].categories + + XCTAssertEqual(categories.count, expectedCategoryCount) + assertEqualCategoryArrays( + categoryArray: + Array(categories.prefix(expectedCategories.count)), + expectedCategoryArray: expectedCategories) + } + + func assertResultsForClassifyImage( + _ image: MPImage, + usingImageClassifier imageClassifier: ImageClassifier, + hasCategoryCount expectedCategoryCount: Int, + andCategories expectedCategories: [ResultCategory] + ) throws { + let imageClassifierResult = + try XCTUnwrap( + imageClassifier.classify(image: image)) + + try assertImageClassifierResult( + imageClassifierResult, + hasCategoryCount: expectedCategoryCount, + andCategories: expectedCategories + ) + } + + func assertResultsForClassifyImageWithFileInfo( + _ fileInfo: FileInfo, + usingImageClassifier imageClassifier: ImageClassifier, + hasCategoryCount expectedCategoryCount: Int, + andCategories expectedCategories: [ResultCategory] + ) throws { + let mpImage = try XCTUnwrap( + MPImage.imageFromBundle( + withClass: type(of: self), + filename: fileInfo.name, + type: fileInfo.type)) + + try assertResultsForClassifyImage( + mpImage, + usingImageClassifier: imageClassifier, + hasCategoryCount: expectedCategoryCount, + andCategories: expectedCategories + ) + } + + // func testCreateImageClassifierWithInvalidMaxResultsFails() throws { + // let textClassifierOptions = + // try XCTUnwrap( + // textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath)) + // textClassifierOptions.maxResults = 0 + + // assertCreateTextClassifierThrowsError( + // textClassifierOptions: textClassifierOptions, + // expectedErrorDescription: """ + // INVALID_ARGUMENT: Invalid `max_results` option: value must be != 0. + // """) + // } + + func testCreateImageClassifierWithCategoryAllowlistAndDenylistFails() throws { + + let imageClassifierOptions = + try XCTUnwrap( + imageClassifierOptionsWithModelPath(ImageClassifierTests.floatModelPath)) + imageClassifierOptions.categoryAllowlist = ["bagel"] + imageClassifierOptions.categoryDenylist = ["guacamole"] + + assertCreateImageClassifierThrowsError( + imageClassifierOptions: imageClassifierOptions, + expectedErrorDescription: """ + INVALID_ARGUMENT: `category_allowlist` and `category_denylist` are \ + mutually exclusive options. + """) + } + + func testClassifyWithModelPathAndFloatModelSucceeds() throws { + + let modelPath = try XCTUnwrap(ImageClassifierTests.floatModelPath) + let imageClassifier = try XCTUnwrap(ImageClassifier(modelPath: modelPath)) + + try assertResultsForClassifyImageWithFileInfo( + ImageClassifierTests.burgerImage, + usingImageClassifier: imageClassifier, + hasCategoryCount: ImageClassifierTests.mobileNetCategoriesCount, + andCategories: ImageClassifierTests.expectedResultsClassifyBurgerImageWithFloatModel) + } + + func testClassifyWithOptionsAndFloatModelSucceeds() throws { + + let imageClassifierOptions = + try XCTUnwrap( + imageClassifierOptionsWithModelPath(ImageClassifierTests.floatModelPath)) + + let imageClassifier = try XCTUnwrap(ImageClassifier(options: imageClassifierOptions)) + + try assertResultsForClassifyImageWithFileInfo( + ImageClassifierTests.burgerImage, + usingImageClassifier: imageClassifier, + hasCategoryCount: ImageClassifierTests.mobileNetCategoriesCount, + andCategories: ImageClassifierTests.expectedResultsClassifyBurgerImageWithFloatModel) + } + + func testClassifyWithScoreThresholdSucceeds() throws { + + let imageClassifierOptions = + try XCTUnwrap( + imageClassifierOptionsWithModelPath(ImageClassifierTests.floatModelPath)) + imageClassifierOptions.scoreThreshold = 0.25 + + let imageClassifier = try XCTUnwrap(ImageClassifier(options: imageClassifierOptions)) + + let expectedCategories = [ + ResultCategory( + index: 934, + score: 0.786005, + categoryName: "cheeseburger", + displayName: nil), + ] + + try assertResultsForClassifyImageWithFileInfo( + ImageClassifierTests.burgerImage, + usingImageClassifier: imageClassifier, + hasCategoryCount: expectedCategories.count, + andCategories: expectedCategories) + } + + func testClassifyWithAllowlistSucceeds() throws { + + let imageClassifierOptions = + try XCTUnwrap( + imageClassifierOptionsWithModelPath(ImageClassifierTests.floatModelPath)) + imageClassifierOptions.categoryAllowlist = ["cheeseburger", "guacamole", "meat loaf"] + + let imageClassifier = try XCTUnwrap(ImageClassifier(options: imageClassifierOptions)) + + let expectedCategories = [ + ResultCategory( + index: 934, + score: 0.786005, + categoryName: "cheeseburger", + displayName: nil), + ResultCategory( + index: 925, + score: 0.021172, + categoryName: "guacamole", + displayName: nil), + ResultCategory( + index: 963, + score: 0.006279315, + categoryName: "meat loaf", + displayName: nil), + ] + + try assertResultsForClassifyImageWithFileInfo( + ImageClassifierTests.burgerImage, + usingImageClassifier: imageClassifier, + hasCategoryCount: expectedCategories.count, + andCategories: expectedCategories) + } + + func testClassifyWithDenylistSucceeds() throws { + + let imageClassifierOptions = + try XCTUnwrap( + imageClassifierOptionsWithModelPath(ImageClassifierTests.floatModelPath)) + imageClassifierOptions.categoryDenylist = ["bagel"] + + let maxResults = 3; + imageClassifierOptions.maxResults = maxResults; + + let imageClassifier = try XCTUnwrap(ImageClassifier(options: imageClassifierOptions)) + + let expectedCategories = [ + ResultCategory( + index: 934, + score: 0.786005, + categoryName: "cheeseburger", + displayName: nil), + ResultCategory( + index: 925, + score: 0.021172, + categoryName: "guacamole", + displayName: nil), + ResultCategory( + index: 963, + score: 0.006279315, + categoryName: "meat loaf", + displayName: nil), + ] + + try assertResultsForClassifyImageWithFileInfo( + ImageClassifierTests.burgerImage, + usingImageClassifier: imageClassifier, + hasCategoryCount: maxResults, + andCategories: expectedCategories) + } + + func testClassifyWithRegionOfInterestSucceeds() throws { + + let imageClassifierOptions = + try XCTUnwrap( + imageClassifierOptionsWithModelPath(ImageClassifierTests.floatModelPath)) + + let maxResults = 1; + imageClassifierOptions.maxResults = maxResults; + + let imageClassifier = try XCTUnwrap(ImageClassifier(options: imageClassifierOptions)) + + let mpImage = try XCTUnwrap( + MPImage.imageFromBundle( + withClass: type(of: self), + filename: ImageClassifierTests.multiObjectsImage.name, + type: ImageClassifierTests.multiObjectsImage.type)) + + let imageClassifierResult = try XCTUnwrap( + imageClassifier.classify( + image: mpImage, + regionOfInterest: CGRect( + x: 0.450, + y: 0.308, + width: 0.164, + height: 0.426))) + + let expectedCategories = [ + ResultCategory( + index: 806, + score: 0.997122, + categoryName: "soccer ball", + displayName: nil), + ] + + + try assertImageClassifierResult( + imageClassifierResult, + hasCategoryCount: maxResults, + andCategories: expectedCategories) + } + + func testClassifyWithOrientationSucceeds() throws { + + let imageClassifierOptions = + try XCTUnwrap( + imageClassifierOptionsWithModelPath(ImageClassifierTests.floatModelPath)) + + let maxResults = 3; + imageClassifierOptions.maxResults = maxResults; + + let imageClassifier = try XCTUnwrap(ImageClassifier(options: imageClassifierOptions)) + + let expectedCategories = [ + ResultCategory( + index: 934, + score: 0.622074, + categoryName: "cheeseburger", + displayName: nil), + ResultCategory( + index: 963, + score: 0.051214, + categoryName: "meat loaf", + displayName: nil), + ResultCategory( + index: 925, + score: 0.048719, + categoryName: "guacamole", + displayName: nil), + ] + + let mpImage = try XCTUnwrap( + MPImage.imageFromBundle( + withClass: type(of: self), + filename: ImageClassifierTests.burgerRotatedImage.name, + type: ImageClassifierTests.burgerRotatedImage.type, + orientation: .right)) + + try assertResultsForClassifyImage( + mpImage, + usingImageClassifier: imageClassifier, + hasCategoryCount: expectedCategories.count, + andCategories: expectedCategories) + } + + func testClassifyWithOrientationAndRegionOfInterestSucceeds() throws { + + let imageClassifierOptions = + try XCTUnwrap( + imageClassifierOptionsWithModelPath(ImageClassifierTests.floatModelPath)) + + let maxResults = 3; + imageClassifierOptions.maxResults = maxResults; + + let imageClassifier = try XCTUnwrap(ImageClassifier(options: imageClassifierOptions)) + + let expectedCategories = [ + ResultCategory( + index: 560, + score: 0.682305, + categoryName: "folding chair", + displayName: nil), + ] + + let mpImage = try XCTUnwrap( + MPImage.imageFromBundle( + withClass: type(of: self), + filename: ImageClassifierTests.multiObjectsRotatedImage.name, + type: ImageClassifierTests.multiObjectsRotatedImage.type, + orientation: .right)) + + let imageClassifierResult = try XCTUnwrap( + imageClassifier.classify( + image: mpImage, + regionOfInterest: CGRect( + x: 0.0, + y: 0.1763, + width: 0.5642, + height: 0.1286))) + + + try assertImageClassifierResult( + imageClassifierResult, + hasCategoryCount: maxResults, + andCategories: expectedCategories) + } + + func testImageClassifierFailsWithResultListenerInNonLiveStreamMode() throws { + + let runningModesToTest = [RunningMode.image, RunningMode.video]; + + for runningMode in runningModesToTest { + let imageClassifierOptions = + try XCTUnwrap( + imageClassifierOptionsWithModelPath( + ImageClassifierTests.floatModelPath)) + imageClassifierOptions.runningMode = runningMode + imageClassifierOptions.completion = {(result: ImageClassifierResult?, error: Error?) -> () in + } + + assertCreateImageClassifierThrowsError( + imageClassifierOptions: imageClassifierOptions, + expectedErrorDescription: """ + The vision task is in image or video mode, a user-defined result \ + callback should not be provided. + """) + } + } + + func testImageClassifierFailsWithMissingResultListenerInLiveStreamMode() throws { + + let imageClassifierOptions = + try XCTUnwrap( + imageClassifierOptionsWithModelPath( + ImageClassifierTests.floatModelPath)) + imageClassifierOptions.runningMode = .liveStream + + assertCreateImageClassifierThrowsError( + imageClassifierOptions: imageClassifierOptions, + expectedErrorDescription: """ + The vision task is in live stream mode, a user-defined result callback \ + must be provided. + """) + } + + func testClassifyFailsWithCallingWrongApiInImageMode() throws { + + let imageClassifierOptions = + try XCTUnwrap( + imageClassifierOptionsWithModelPath( + ImageClassifierTests.floatModelPath)) + + let imageClassifier = try XCTUnwrap(ImageClassifier(options: + imageClassifierOptions)) + + let mpImage = try XCTUnwrap( + MPImage.imageFromBundle( + withClass: type(of: self), + filename: ImageClassifierTests.multiObjectsRotatedImage.name, + type: ImageClassifierTests.multiObjectsRotatedImage.type)) + + do { + try imageClassifier.classifyAsync( + image: mpImage, + timestampMs:0) + } catch { + assertEqualErrorDescriptions( + error, + expectedLocalizedDescription: """ + The vision task is not initialized with live stream mode. Current \ + Running Mode: Image + """) + } + + do { + let imagClassifierResult = try imageClassifier.classify( + videoFrame: mpImage, + timestampMs: 0) + XCTAssertNil(imagClassifierResult) + } catch { + assertEqualErrorDescriptions( + error, + expectedLocalizedDescription: """ + The vision task is not initialized with video mode. Current Running Mode: Image + """) + } + } + + func testClassifyFailsWithCallingWrongApiInVideoMode() throws { + + let imageClassifierOptions = + try XCTUnwrap( + imageClassifierOptionsWithModelPath( + ImageClassifierTests.floatModelPath)) + + imageClassifierOptions.runningMode = .video + + let imageClassifier = try XCTUnwrap(ImageClassifier(options: + imageClassifierOptions)) + + let mpImage = try XCTUnwrap( + MPImage.imageFromBundle( + withClass: type(of: self), + filename: ImageClassifierTests.multiObjectsRotatedImage.name, + type: ImageClassifierTests.multiObjectsRotatedImage.type)) + + do { + try imageClassifier.classifyAsync( + image: mpImage, + timestampMs:0) + } catch { + assertEqualErrorDescriptions( + error, + expectedLocalizedDescription: """ + The vision task is not initialized with live stream mode. Current \ + Running Mode: Video + """) + } + + do { + let imagClassifierResult = try imageClassifier.classify( + image: mpImage) + XCTAssertNil(imagClassifierResult) + } catch { + assertEqualErrorDescriptions( + error, + expectedLocalizedDescription: """ + The vision task is not initialized with image mode. Current Running \ + Mode: Video + """) + } + } + + func testClassifyFailsWithCallingWrongApiLiveStreamInMode() throws { + let imageClassifierOptions = + try XCTUnwrap( + imageClassifierOptionsWithModelPath( + ImageClassifierTests.floatModelPath)) + + imageClassifierOptions.runningMode = .liveStream + imageClassifierOptions.completion = {( + result: ImageClassifierResult?, + error: Error?) -> () in + } + + let imageClassifier = try XCTUnwrap(ImageClassifier(options: + imageClassifierOptions)) + + let mpImage = try XCTUnwrap( + MPImage.imageFromBundle( + withClass: type(of: self), + filename: ImageClassifierTests.multiObjectsRotatedImage.name, + type: ImageClassifierTests.multiObjectsRotatedImage.type)) + + do { + let imagClassifierResult = try imageClassifier.classify( + image: mpImage) + XCTAssertNil(imagClassifierResult) + } catch { + assertEqualErrorDescriptions( + error, + expectedLocalizedDescription: """ + The vision task is not initialized with image mode. Current Running \ + Mode: Live Stream + """) + } + + do { + let imagClassifierResult = try imageClassifier.classify( + videoFrame: mpImage, + timestampMs: 0) + XCTAssertNil(imagClassifierResult) + } catch { + assertEqualErrorDescriptions( + error, + expectedLocalizedDescription: """ + The vision task is not initialized with video mode. Current Running \ + Mode: Live Stream + """) + } + } + + func testClassifyWithVideoModeSucceeds() throws { + let imageClassifierOptions = + try XCTUnwrap( + imageClassifierOptionsWithModelPath( + ImageClassifierTests.floatModelPath)) + + imageClassifierOptions.runningMode = .video + + let maxResults = 3; + imageClassifierOptions.maxResults = maxResults + + let imageClassifier = try XCTUnwrap(ImageClassifier(options: + imageClassifierOptions)) + + let mpImage = try XCTUnwrap( + MPImage.imageFromBundle( + withClass: type(of: self), + filename: ImageClassifierTests.burgerImage.name, + type: ImageClassifierTests.burgerImage.type)) + + for i in 0..<3 { + let imageClassifierResult = try XCTUnwrap( + imageClassifier.classify( + videoFrame: mpImage, + timestampMs: i)) + try assertImageClassifierResult( + imageClassifierResult, + hasCategoryCount: maxResults, + andCategories: ImageClassifierTests.expectedResultsClassifyBurgerImageWithFloatModel + ) + } + } + + func testClassifyWithLiveStreamModeSucceeds() throws { + let imageClassifierOptions = + try XCTUnwrap( + imageClassifierOptionsWithModelPath( + ImageClassifierTests.floatModelPath)) + + imageClassifierOptions.runningMode = .liveStream + + let maxResults = 3 + imageClassifierOptions.maxResults = maxResults + + let expectation = expectation(description: "liveStreamClassify") + + imageClassifierOptions.completion = {(result: ImageClassifierResult?, error: Error?) -> () in + do { + try self.assertImageClassifierResult( + try XCTUnwrap(result), + hasCategoryCount: maxResults, + andCategories: ImageClassifierTests.expectedResultsClassifyBurgerImageWithFloatModel) + } + catch { + + } + expectation.fulfill() + } + + let imageClassifier = try XCTUnwrap(ImageClassifier(options: + imageClassifierOptions)) + + let mpImage = try XCTUnwrap( + MPImage.imageFromBundle( + withClass: type(of: self), + filename: ImageClassifierTests.burgerImage.name, + type: ImageClassifierTests.burgerImage.type)) + + for i in 0..<3 { + XCTAssertNoThrow( + try imageClassifier.classifyAsync( + image: mpImage, + timestampMs: i)) + } + + wait(for: [expectation], timeout: 10) + + } + + // func testClassifyWithMaxResultsSucceeds() throws { + // let textClassifierOptions = + // try XCTUnwrap( + // textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath)) + // textClassifierOptions.maxResults = 1 + + // let textClassifier = + // try XCTUnwrap(TextClassifier(options: textClassifierOptions)) + + // try assertResultsForClassify( + // text: TextClassifierTests.negativeText, + // using: textClassifier, + // equals: TextClassifierTests.bertNegativeTextResultsForEdgeTestCases) + // } + + // func testClassifyWithCategoryAllowlistSucceeds() throws { + // let textClassifierOptions = + // try XCTUnwrap( + // textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath)) + // textClassifierOptions.categoryAllowlist = ["negative"] + + // let textClassifier = + // try XCTUnwrap(TextClassifier(options: textClassifierOptions)) + + // try assertResultsForClassify( + // text: TextClassifierTests.negativeText, + // using: textClassifier, + // equals: TextClassifierTests.bertNegativeTextResultsForEdgeTestCases) + // } + + // func testClassifyWithCategoryDenylistSucceeds() throws { + // let textClassifierOptions = + // try XCTUnwrap( + // textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath)) + // textClassifierOptions.categoryDenylist = ["positive"] + + // let textClassifier = + // try XCTUnwrap(TextClassifier(options: textClassifierOptions)) + + // try assertResultsForClassify( + // text: TextClassifierTests.negativeText, + // using: textClassifier, + // equals: TextClassifierTests.bertNegativeTextResultsForEdgeTestCases) + // } +} diff --git a/mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h b/mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h index 9dfe29fd3..bf225cd16 100644 --- a/mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h +++ b/mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h @@ -37,7 +37,7 @@ NS_ASSUME_NONNULL_BEGIN + (nullable MPPImage *)imageFromBundleWithClass:(Class)classObject fileName:(NSString *)name ofType:(NSString *)type - NS_SWIFT_NAME(imageFromBundle(class:filename:type:)); + NS_SWIFT_NAME(imageFromBundle(withClass:filename:type:)); /** * Loads an image from a file in an app bundle into a `MPPImage` object with the specified @@ -56,7 +56,7 @@ NS_ASSUME_NONNULL_BEGIN fileName:(NSString *)name ofType:(NSString *)type orientation:(UIImageOrientation)imageOrientation - NS_SWIFT_NAME(imageFromBundle(class:filename:type:orientation:)); + NS_SWIFT_NAME(imageFromBundle(withClass:filename:type:orientation:)); @end