diff --git a/mediapipe/tasks/ios/test/text/text_classifier/BUILD b/mediapipe/tasks/ios/test/text/text_classifier/BUILD new file mode 100644 index 000000000..b69202b64 --- /dev/null +++ b/mediapipe/tasks/ios/test/text/text_classifier/BUILD @@ -0,0 +1,82 @@ +load( + "@build_bazel_rules_apple//apple:ios.bzl", + "ios_unit_test", +) +load( + "@org_tensorflow//tensorflow/lite:special_rules.bzl", + "tflite_ios_lab_runner" +) +load( + "@build_bazel_rules_swift//swift:swift.bzl", + "swift_library" +) +load( + "//mediapipe/tasks:ios/ios.bzl", + "MPP_TASK_MINIMUM_OS_VERSION" +) + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# Default tags for filtering iOS targets. Targets are restricted to Apple platforms. +TFL_DEFAULT_TAGS = [ + "apple", +] + +# Following sanitizer tests are not supported by iOS test targets. +TFL_DISABLED_SANITIZER_TAGS = [ + "noasan", + "nomsan", + "notsan", +] + +objc_library( + name = "MPPTextClassifierObjcTestLibrary", + testonly = 1, + srcs = ["MPPTextClassifierTests.m"], + data = [ + "//mediapipe/tasks/testdata/text:bert_text_classifier_models", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + tags = [], + deps = [ + "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier", + ], + +) + +ios_unit_test( + name = "MPPTextClassifierObjcTest", + minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags =[], + deps = [ + ":MPPTextClassifierObjcTestLibrary", + ], +) + +swift_library( + name = "MPPTextClassifierSwiftTestLibrary", + testonly = 1, + srcs = ["TextClassifierTests.swift"], + data = [ + "//mediapipe/tasks/testdata/text:bert_text_classifier_models", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + tags = TFL_DEFAULT_TAGS, + deps = [ + "//mediapipe/tasks/ios/common:MPPCommon", + "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier", + ], +) + +ios_unit_test( + name = "MPPTextClassifierSwiftTest", + minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS, + deps = [ + ":MPPTextClassifierSwiftTestLibrary", + ], +) diff --git a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m new file mode 100644 index 000000000..3e2fe4bef --- /dev/null +++ b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m @@ -0,0 +1,281 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "mediapipe/tasks/ios/common/sources/MPPCommon.h" +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" + +static NSString *const kBertTextClassifierModelName = @"bert_text_classifier"; +static NSString *const kRegexTextClassifierModelName = + @"test_model_text_classifier_with_regex_tokenizer"; +static NSString *const kNegativeText = @"unflinchingly bleak and desperate"; +static NSString *const kPositiveText = @"it's a charming and often affecting journey"; +static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; + +#define AssertEqualErrors(error, expectedError) \ + XCTAssertNotNil(error); \ + XCTAssertEqualObjects(error.domain, expectedError.domain); \ + XCTAssertEqual(error.code, expectedError.code); \ + XCTAssertNotEqual( \ + [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \ + NSNotFound) + +#define AssertEqualCategoryArrays(categories, expectedCategories) \ + XCTAssertEqual(categories.count, expectedCategories.count); \ + for (int i = 0; i < categories.count; i++) { \ + XCTAssertEqual(categories[i].index, expectedCategories[i].index); \ + XCTAssertEqualWithAccuracy(categories[i].score, expectedCategories[i].score, 1e-6); \ + XCTAssertEqualObjects(categories[i].categoryName, expectedCategories[i].categoryName); \ + XCTAssertEqualObjects(categories[i].displayName, expectedCategories[i].displayName); \ + } + +#define AssertTextClassifierResultHasOneHead(textClassifierResult) \ + XCTAssertNotNil(textClassifierResult); \ + \ + XCTAssertNotNil(textClassifierResult.classificationResult); \ + XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1); \ + XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0); + +@interface MPPTextClassifierTests : XCTestCase +@end + +@implementation MPPTextClassifierTests + +- (void)setUp { +} + +- (void)tearDown { + // Put teardown code here. This method is called after the invocation of each test method in the + // class. +} + ++ (NSArray *)expectedBertResultCategoriesForNegativeText { + return @[ + [[MPPCategory alloc] initWithIndex:0 score:0.956187f categoryName:@"negative" displayName:nil], + [[MPPCategory alloc] initWithIndex:1 score:0.043812f categoryName:@"positive" displayName:nil] + ]; +} + ++ (NSArray *)expectedBertResultCategoriesForPositiveText { + return @[ + [[MPPCategory alloc] initWithIndex:1 score:0.999945f categoryName:@"positive" displayName:nil], + [[MPPCategory alloc] initWithIndex:0 score:0.000055f categoryName:@"negative" displayName:nil] + ]; +} + ++ (NSArray *)expectedRegexResultCategoriesForNegativeText { + return @[ + [[MPPCategory alloc] initWithIndex:0 score:0.6647746f categoryName:@"Negative" displayName:nil], + [[MPPCategory alloc] initWithIndex:1 score:0.33522537 categoryName:@"Positive" displayName:nil] + ]; +} + ++ (NSArray *)expectedRegexResultCategoriesForPositiveText { + return @[ + [[MPPCategory alloc] initWithIndex:0 score:0.5120041f categoryName:@"Negative" displayName:nil], + [[MPPCategory alloc] initWithIndex:1 score:0.48799595 categoryName:@"Positive" displayName:nil] + ]; +} + ++ (NSArray *)expectedBertResultCategoriesForEdgeCaseTests { + return @[ [[MPPCategory alloc] initWithIndex:0 + score:0.956187f + categoryName:@"negative" + displayName:nil] ]; +} + +- (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension { + NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName + ofType:extension]; + return filePath; +} + +- (MPPTextClassifierOptions *)textClassifierOptionsWithModelName:(NSString *)modelName { + NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"]; + MPPTextClassifierOptions *textClassifierOptions = [[MPPTextClassifierOptions alloc] init]; + textClassifierOptions.baseOptions.modelAssetPath = modelPath; + + return textClassifierOptions; +} + +- (MPPTextClassifier *)textClassifierFromModelFileWithName:(NSString *)modelName { + NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"]; + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithModelPath:modelPath + error:nil]; + XCTAssertNotNil(textClassifier); + + return textClassifier; +} + +- (void)assertCreateTextClassifierWithOptions:(MPPTextClassifierOptions *)textClassifierOptions + failsWithExpectedError:(NSError *)expectedError { + NSError *error = nil; + MPPTextClassifier *textClassifier = + [[MPPTextClassifier alloc] initWithOptions:textClassifierOptions error:&error]; + XCTAssertNil(textClassifier); + AssertEqualErrors(error, expectedError); +} + +- (void)assertResultsOfClassifyText:(NSString *)text + usingTextClassifier:(MPPTextClassifier *)textClassifier + equalsCategories:(NSArray *)expectedCategories { + MPPTextClassifierResult *negativeResult = [textClassifier classifyText:text error:nil]; + AssertTextClassifierResultHasOneHead(negativeResult); + AssertEqualCategoryArrays(negativeResult.classificationResult.classifications[0].categories, + expectedCategories); +} + +- (void)testCreateTextClassifierFailsWithMissingModelPath { + NSString *modelPath = [self filePathWithName:@"" extension:@""]; + + NSError *error = nil; + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithModelPath:modelPath + error:&error]; + XCTAssertNil(textClassifier); + + NSError *expectedError = [NSError + errorWithDomain:kExpectedErrorDomain + code:MPPTasksErrorCodeInvalidArgumentError + userInfo:@{ + NSLocalizedDescriptionKey : + @"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', " + @"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'." + }]; + AssertEqualErrors(error, expectedError); +} + +- (void)testCreateTextClassifierFailsWithBothAllowListAndDenyList { + MPPTextClassifierOptions *options = + [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.categoryAllowlist = @[ @"positive" ]; + options.categoryDenylist = @[ @"negative" ]; + + [self assertCreateTextClassifierWithOptions:options + failsWithExpectedError: + [NSError + errorWithDomain:kExpectedErrorDomain + code:MPPTasksErrorCodeInvalidArgumentError + userInfo:@{ + NSLocalizedDescriptionKey : + @"INVALID_ARGUMENT: `category_allowlist` and " + @"`category_denylist` are mutually exclusive options." + }]]; +} + +- (void)testCreateTextClassifierFailsWithInvalidMaxResults { + MPPTextClassifierOptions *options = + [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.maxResults = 0; + + [self assertCreateTextClassifierWithOptions:options + failsWithExpectedError: + [NSError errorWithDomain:kExpectedErrorDomain + code:MPPTasksErrorCodeInvalidArgumentError + userInfo:@{ + NSLocalizedDescriptionKey : + @"INVALID_ARGUMENT: Invalid `max_results` option: " + @"value must be != 0." + }]]; +} + +- (void)testClassifyWithBertSucceeds { + MPPTextClassifier *textClassifier = + [self textClassifierFromModelFileWithName:kBertTextClassifierModelName]; + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForNegativeText]]; + + [self assertResultsOfClassifyText:kPositiveText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForPositiveText]]; +} + +- (void)testClassifyWithRegexSucceeds { + MPPTextClassifier *textClassifier = + [self textClassifierFromModelFileWithName:kRegexTextClassifierModelName]; + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedRegexResultCategoriesForNegativeText]]; + [self assertResultsOfClassifyText:kPositiveText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedRegexResultCategoriesForPositiveText]]; +} + +- (void)testClassifyWithMaxResultsSucceeds { + MPPTextClassifierOptions *options = + [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.maxResults = 1; + + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil]; + XCTAssertNotNil(textClassifier); + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForEdgeCaseTests]]; +} + +- (void)testClassifyWithCategoryAllowListSucceeds { + MPPTextClassifierOptions *options = + [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.categoryAllowlist = @[ @"negative" ]; + + NSError *error = nil; + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options + error:&error]; + XCTAssertNotNil(textClassifier); + XCTAssertNil(error); + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForEdgeCaseTests]]; +} + +- (void)testClassifyWithCategoryDenyListSucceeds { + MPPTextClassifierOptions *options = + [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.categoryDenylist = @[ @"positive" ]; + + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil]; + XCTAssertNotNil(textClassifier); + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForEdgeCaseTests]]; +} + +- (void)testClassifyWithScoreThresholdSucceeds { + MPPTextClassifierOptions *options = + [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + options.scoreThreshold = 0.5f; + + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil]; + XCTAssertNotNil(textClassifier); + + [self assertResultsOfClassifyText:kNegativeText + usingTextClassifier:textClassifier + equalsCategories:[MPPTextClassifierTests + expectedBertResultCategoriesForEdgeCaseTests]]; +} + +@end diff --git a/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift b/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift new file mode 100644 index 000000000..d2d433c22 --- /dev/null +++ b/mediapipe/tasks/ios/test/text/text_classifier/TextClassifierTests.swift @@ -0,0 +1,237 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import XCTest + +import MPPCommon + +@testable import MPPTextClassifier + +class TextClassifierTests: XCTestCase { + + static let bundle = Bundle(for: TextClassifierTests.self) + + static let kBertModelPath = bundle.path( + forResource: "bert_text_classifier", + ofType: "tflite") + + static let kPositiveText = "it's a charming and often affecting journey" + + static let kNegativeText = "unflinchingly bleak and desperate" + + static let kBertNegativeTextResults = [ + ResultCategory( + index: 0, + score: 0.956187, + categoryName: "negative", + displayName: nil), + ResultCategory( + index: 1, + score: 0.043812, + categoryName: "positive", + displayName: nil) + ] + + static let kBertNegativeTextResultsForEdgeTestCases = [ + ResultCategory( + index: 0, + score: 0.956187, + categoryName: "negative", + displayName: nil), + ] + + func assertEqualErrorDescriptions( + _ error: Error, expectedLocalizedDescription:String) { + XCTAssertEqual( + error.localizedDescription, + expectedLocalizedDescription) + } + + func assertCategoriesAreEqual( + category: ResultCategory, + expectedCategory: ResultCategory) { + XCTAssertEqual( + category.index, + expectedCategory.index) + XCTAssertEqual( + category.score, + expectedCategory.score, + accuracy:1e-6) + XCTAssertEqual( + category.categoryName, + expectedCategory.categoryName) + XCTAssertEqual( + category.displayName, + expectedCategory.displayName) + } + + func assertEqualCategoryArrays( + categoryArray: [ResultCategory], + expectedCategoryArray:[ResultCategory]) { + + XCTAssertEqual(categoryArray.count, expectedCategoryArray.count) + + for (category, expectedCategory) in + zip(categoryArray, expectedCategoryArray) { + assertCategoriesAreEqual( + category:category, + expectedCategory:expectedCategory) + } + } + + func assertTextClassifierResultHasOneHead( + _ textClassifierResult: TextClassifierResult) { + XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1); + XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0); + } + + func textClassifierOptionsWithModelPath( + _ modelPath: String?) throws -> TextClassifierOptions { + let modelPath = try XCTUnwrap(modelPath) + + let textClassifierOptions = TextClassifierOptions(); + textClassifierOptions.baseOptions.modelAssetPath = modelPath; + + return textClassifierOptions + } + + func assertCreateTextClassifierThrowsError( + textClassifierOptions: TextClassifierOptions, + expectedErrorDescription: String) { + do { + let textClassifier = try TextClassifier(options:textClassifierOptions) + XCTAssertNil(textClassifier) + } + catch { + assertEqualErrorDescriptions( + error, + expectedLocalizedDescription: expectedErrorDescription) + } + } + + func assertResultsForClassify( + text: String, + using textClassifier: TextClassifier, + equals expectedCategories: [ResultCategory]) throws { + let textClassifierResult = + try XCTUnwrap( + textClassifier.classify(text: text)); + assertTextClassifierResultHasOneHead(textClassifierResult); + assertEqualCategoryArrays( + categoryArray: + textClassifierResult.classificationResult.classifications[0].categories, + expectedCategoryArray: expectedCategories); + } + + func testCreateTextClassifierWithInvalidMaxResultsFails() throws { + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath)) + textClassifierOptions.maxResults = 0 + + assertCreateTextClassifierThrowsError( + textClassifierOptions: textClassifierOptions, + expectedErrorDescription: """ + INVALID_ARGUMENT: Invalid `max_results` option: value must be != 0. + """) + } + + func testCreateTextClassifierWithCategoryAllowlistandDenylistFails() throws { + + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath)) + textClassifierOptions.categoryAllowlist = ["positive"] + textClassifierOptions.categoryDenylist = ["positive"] + + assertCreateTextClassifierThrowsError( + textClassifierOptions: textClassifierOptions, + expectedErrorDescription: """ + INVALID_ARGUMENT: `category_allowlist` and `category_denylist` are \ + mutually exclusive options. + """) + } + + func testClassifyWithBertSucceeds() throws { + + let modelPath = try XCTUnwrap(TextClassifierTests.kBertModelPath) + let textClassifier = try XCTUnwrap(TextClassifier(modelPath: modelPath)) + + try assertResultsForClassify( + text: TextClassifierTests.kNegativeText, + using: textClassifier, + equals: TextClassifierTests.kBertNegativeTextResults) + } + + func testClassifyWithMaxResultsSucceeds() throws { + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath)) + textClassifierOptions.maxResults = 1 + + let textClassifier = + try XCTUnwrap(TextClassifier(options: textClassifierOptions)) + + try assertResultsForClassify( + text: TextClassifierTests.kNegativeText, + using: textClassifier, + equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases) + } + + func testClassifyWithCategoryAllowlistSucceeds() throws { + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath)) + textClassifierOptions.categoryAllowlist = ["negative"]; + + let textClassifier = + try XCTUnwrap(TextClassifier(options: textClassifierOptions)) + + try assertResultsForClassify( + text: TextClassifierTests.kNegativeText, + using: textClassifier, + equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases) + } + + func testClassifyWithCategoryDenylistSucceeds() throws { + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath)) + textClassifierOptions.categoryDenylist = ["positive"]; + + let textClassifier = + try XCTUnwrap(TextClassifier(options: textClassifierOptions)) + + try assertResultsForClassify( + text: TextClassifierTests.kNegativeText, + using: textClassifier, + equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases) + } + + func testClassifyWithScoreThresholdSucceeds() throws { + let textClassifierOptions = + try XCTUnwrap( + textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath)) + textClassifierOptions.scoreThreshold = 0.5; + + let textClassifier = + try XCTUnwrap(TextClassifier(options: textClassifierOptions)) + + try assertResultsForClassify( + text: TextClassifierTests.kNegativeText, + using: textClassifier, + equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases) + } + +}