Added swift and objective tests for iOS text classifier

This commit is contained in:
Prianka Liz Kariat 2023-01-13 21:05:44 +05:30
parent 9e0b85c9b5
commit 2a53d78ae4
3 changed files with 600 additions and 0 deletions

View File

@ -0,0 +1,82 @@
load(
"@build_bazel_rules_apple//apple:ios.bzl",
"ios_unit_test",
)
load(
"@org_tensorflow//tensorflow/lite:special_rules.bzl",
"tflite_ios_lab_runner"
)
load(
"@build_bazel_rules_swift//swift:swift.bzl",
"swift_library"
)
load(
"//mediapipe/tasks:ios/ios.bzl",
"MPP_TASK_MINIMUM_OS_VERSION"
)
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
TFL_DEFAULT_TAGS = [
"apple",
]
# Following sanitizer tests are not supported by iOS test targets.
TFL_DISABLED_SANITIZER_TAGS = [
"noasan",
"nomsan",
"notsan",
]
objc_library(
name = "MPPTextClassifierObjcTestLibrary",
testonly = 1,
srcs = ["MPPTextClassifierTests.m"],
data = [
"//mediapipe/tasks/testdata/text:bert_text_classifier_models",
"//mediapipe/tasks/testdata/text:text_classifier_models",
],
tags = [],
deps = [
"//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier",
],
)
ios_unit_test(
name = "MPPTextClassifierObjcTest",
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags =[],
deps = [
":MPPTextClassifierObjcTestLibrary",
],
)
swift_library(
name = "MPPTextClassifierSwiftTestLibrary",
testonly = 1,
srcs = ["TextClassifierTests.swift"],
data = [
"//mediapipe/tasks/testdata/text:bert_text_classifier_models",
"//mediapipe/tasks/testdata/text:text_classifier_models",
],
tags = TFL_DEFAULT_TAGS,
deps = [
"//mediapipe/tasks/ios/common:MPPCommon",
"//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier",
],
)
ios_unit_test(
name = "MPPTextClassifierSwiftTest",
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
deps = [
":MPPTextClassifierSwiftTestLibrary",
],
)

View File

@ -0,0 +1,281 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <XCTest/XCTest.h>
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h"
static NSString *const kBertTextClassifierModelName = @"bert_text_classifier";
static NSString *const kRegexTextClassifierModelName =
@"test_model_text_classifier_with_regex_tokenizer";
static NSString *const kNegativeText = @"unflinchingly bleak and desperate";
static NSString *const kPositiveText = @"it's a charming and often affecting journey";
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
#define AssertEqualErrors(error, expectedError) \
XCTAssertNotNil(error); \
XCTAssertEqualObjects(error.domain, expectedError.domain); \
XCTAssertEqual(error.code, expectedError.code); \
XCTAssertNotEqual( \
[error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
NSNotFound)
#define AssertEqualCategoryArrays(categories, expectedCategories) \
XCTAssertEqual(categories.count, expectedCategories.count); \
for (int i = 0; i < categories.count; i++) { \
XCTAssertEqual(categories[i].index, expectedCategories[i].index); \
XCTAssertEqualWithAccuracy(categories[i].score, expectedCategories[i].score, 1e-6); \
XCTAssertEqualObjects(categories[i].categoryName, expectedCategories[i].categoryName); \
XCTAssertEqualObjects(categories[i].displayName, expectedCategories[i].displayName); \
}
#define AssertTextClassifierResultHasOneHead(textClassifierResult) \
XCTAssertNotNil(textClassifierResult); \
\
XCTAssertNotNil(textClassifierResult.classificationResult); \
XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1); \
XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0);
@interface MPPTextClassifierTests : XCTestCase
@end
@implementation MPPTextClassifierTests
- (void)setUp {
}
- (void)tearDown {
// Put teardown code here. This method is called after the invocation of each test method in the
// class.
}
+ (NSArray<MPPCategory *> *)expectedBertResultCategoriesForNegativeText {
return @[
[[MPPCategory alloc] initWithIndex:0 score:0.956187f categoryName:@"negative" displayName:nil],
[[MPPCategory alloc] initWithIndex:1 score:0.043812f categoryName:@"positive" displayName:nil]
];
}
+ (NSArray<MPPCategory *> *)expectedBertResultCategoriesForPositiveText {
return @[
[[MPPCategory alloc] initWithIndex:1 score:0.999945f categoryName:@"positive" displayName:nil],
[[MPPCategory alloc] initWithIndex:0 score:0.000055f categoryName:@"negative" displayName:nil]
];
}
+ (NSArray<MPPCategory *> *)expectedRegexResultCategoriesForNegativeText {
return @[
[[MPPCategory alloc] initWithIndex:0 score:0.6647746f categoryName:@"Negative" displayName:nil],
[[MPPCategory alloc] initWithIndex:1 score:0.33522537 categoryName:@"Positive" displayName:nil]
];
}
+ (NSArray<MPPCategory *> *)expectedRegexResultCategoriesForPositiveText {
return @[
[[MPPCategory alloc] initWithIndex:0 score:0.5120041f categoryName:@"Negative" displayName:nil],
[[MPPCategory alloc] initWithIndex:1 score:0.48799595 categoryName:@"Positive" displayName:nil]
];
}
+ (NSArray<MPPCategory *> *)expectedBertResultCategoriesForEdgeCaseTests {
return @[ [[MPPCategory alloc] initWithIndex:0
score:0.956187f
categoryName:@"negative"
displayName:nil] ];
}
- (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension {
NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName
ofType:extension];
return filePath;
}
- (MPPTextClassifierOptions *)textClassifierOptionsWithModelName:(NSString *)modelName {
NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
MPPTextClassifierOptions *textClassifierOptions = [[MPPTextClassifierOptions alloc] init];
textClassifierOptions.baseOptions.modelAssetPath = modelPath;
return textClassifierOptions;
}
- (MPPTextClassifier *)textClassifierFromModelFileWithName:(NSString *)modelName {
NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithModelPath:modelPath
error:nil];
XCTAssertNotNil(textClassifier);
return textClassifier;
}
- (void)assertCreateTextClassifierWithOptions:(MPPTextClassifierOptions *)textClassifierOptions
failsWithExpectedError:(NSError *)expectedError {
NSError *error = nil;
MPPTextClassifier *textClassifier =
[[MPPTextClassifier alloc] initWithOptions:textClassifierOptions error:&error];
XCTAssertNil(textClassifier);
AssertEqualErrors(error, expectedError);
}
- (void)assertResultsOfClassifyText:(NSString *)text
usingTextClassifier:(MPPTextClassifier *)textClassifier
equalsCategories:(NSArray<MPPCategory *> *)expectedCategories {
MPPTextClassifierResult *negativeResult = [textClassifier classifyText:text error:nil];
AssertTextClassifierResultHasOneHead(negativeResult);
AssertEqualCategoryArrays(negativeResult.classificationResult.classifications[0].categories,
expectedCategories);
}
- (void)testCreateTextClassifierFailsWithMissingModelPath {
NSString *modelPath = [self filePathWithName:@"" extension:@""];
NSError *error = nil;
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithModelPath:modelPath
error:&error];
XCTAssertNil(textClassifier);
NSError *expectedError = [NSError
errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey :
@"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', "
@"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."
}];
AssertEqualErrors(error, expectedError);
}
- (void)testCreateTextClassifierFailsWithBothAllowListAndDenyList {
MPPTextClassifierOptions *options =
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
options.categoryAllowlist = @[ @"positive" ];
options.categoryDenylist = @[ @"negative" ];
[self assertCreateTextClassifierWithOptions:options
failsWithExpectedError:
[NSError
errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey :
@"INVALID_ARGUMENT: `category_allowlist` and "
@"`category_denylist` are mutually exclusive options."
}]];
}
- (void)testCreateTextClassifierFailsWithInvalidMaxResults {
MPPTextClassifierOptions *options =
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
options.maxResults = 0;
[self assertCreateTextClassifierWithOptions:options
failsWithExpectedError:
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey :
@"INVALID_ARGUMENT: Invalid `max_results` option: "
@"value must be != 0."
}]];
}
- (void)testClassifyWithBertSucceeds {
MPPTextClassifier *textClassifier =
[self textClassifierFromModelFileWithName:kBertTextClassifierModelName];
[self assertResultsOfClassifyText:kNegativeText
usingTextClassifier:textClassifier
equalsCategories:[MPPTextClassifierTests
expectedBertResultCategoriesForNegativeText]];
[self assertResultsOfClassifyText:kPositiveText
usingTextClassifier:textClassifier
equalsCategories:[MPPTextClassifierTests
expectedBertResultCategoriesForPositiveText]];
}
- (void)testClassifyWithRegexSucceeds {
MPPTextClassifier *textClassifier =
[self textClassifierFromModelFileWithName:kRegexTextClassifierModelName];
[self assertResultsOfClassifyText:kNegativeText
usingTextClassifier:textClassifier
equalsCategories:[MPPTextClassifierTests
expectedRegexResultCategoriesForNegativeText]];
[self assertResultsOfClassifyText:kPositiveText
usingTextClassifier:textClassifier
equalsCategories:[MPPTextClassifierTests
expectedRegexResultCategoriesForPositiveText]];
}
- (void)testClassifyWithMaxResultsSucceeds {
MPPTextClassifierOptions *options =
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
options.maxResults = 1;
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil];
XCTAssertNotNil(textClassifier);
[self assertResultsOfClassifyText:kNegativeText
usingTextClassifier:textClassifier
equalsCategories:[MPPTextClassifierTests
expectedBertResultCategoriesForEdgeCaseTests]];
}
- (void)testClassifyWithCategoryAllowListSucceeds {
MPPTextClassifierOptions *options =
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
options.categoryAllowlist = @[ @"negative" ];
NSError *error = nil;
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options
error:&error];
XCTAssertNotNil(textClassifier);
XCTAssertNil(error);
[self assertResultsOfClassifyText:kNegativeText
usingTextClassifier:textClassifier
equalsCategories:[MPPTextClassifierTests
expectedBertResultCategoriesForEdgeCaseTests]];
}
- (void)testClassifyWithCategoryDenyListSucceeds {
MPPTextClassifierOptions *options =
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
options.categoryDenylist = @[ @"positive" ];
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil];
XCTAssertNotNil(textClassifier);
[self assertResultsOfClassifyText:kNegativeText
usingTextClassifier:textClassifier
equalsCategories:[MPPTextClassifierTests
expectedBertResultCategoriesForEdgeCaseTests]];
}
- (void)testClassifyWithScoreThresholdSucceeds {
MPPTextClassifierOptions *options =
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
options.scoreThreshold = 0.5f;
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil];
XCTAssertNotNil(textClassifier);
[self assertResultsOfClassifyText:kNegativeText
usingTextClassifier:textClassifier
equalsCategories:[MPPTextClassifierTests
expectedBertResultCategoriesForEdgeCaseTests]];
}
@end

View File

@ -0,0 +1,237 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import XCTest
import MPPCommon
@testable import MPPTextClassifier
class TextClassifierTests: XCTestCase {
static let bundle = Bundle(for: TextClassifierTests.self)
static let kBertModelPath = bundle.path(
forResource: "bert_text_classifier",
ofType: "tflite")
static let kPositiveText = "it's a charming and often affecting journey"
static let kNegativeText = "unflinchingly bleak and desperate"
static let kBertNegativeTextResults = [
ResultCategory(
index: 0,
score: 0.956187,
categoryName: "negative",
displayName: nil),
ResultCategory(
index: 1,
score: 0.043812,
categoryName: "positive",
displayName: nil)
]
static let kBertNegativeTextResultsForEdgeTestCases = [
ResultCategory(
index: 0,
score: 0.956187,
categoryName: "negative",
displayName: nil),
]
func assertEqualErrorDescriptions(
_ error: Error, expectedLocalizedDescription:String) {
XCTAssertEqual(
error.localizedDescription,
expectedLocalizedDescription)
}
func assertCategoriesAreEqual(
category: ResultCategory,
expectedCategory: ResultCategory) {
XCTAssertEqual(
category.index,
expectedCategory.index)
XCTAssertEqual(
category.score,
expectedCategory.score,
accuracy:1e-6)
XCTAssertEqual(
category.categoryName,
expectedCategory.categoryName)
XCTAssertEqual(
category.displayName,
expectedCategory.displayName)
}
func assertEqualCategoryArrays(
categoryArray: [ResultCategory],
expectedCategoryArray:[ResultCategory]) {
XCTAssertEqual(categoryArray.count, expectedCategoryArray.count)
for (category, expectedCategory) in
zip(categoryArray, expectedCategoryArray) {
assertCategoriesAreEqual(
category:category,
expectedCategory:expectedCategory)
}
}
func assertTextClassifierResultHasOneHead(
_ textClassifierResult: TextClassifierResult) {
XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1);
XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0);
}
func textClassifierOptionsWithModelPath(
_ modelPath: String?) throws -> TextClassifierOptions {
let modelPath = try XCTUnwrap(modelPath)
let textClassifierOptions = TextClassifierOptions();
textClassifierOptions.baseOptions.modelAssetPath = modelPath;
return textClassifierOptions
}
func assertCreateTextClassifierThrowsError(
textClassifierOptions: TextClassifierOptions,
expectedErrorDescription: String) {
do {
let textClassifier = try TextClassifier(options:textClassifierOptions)
XCTAssertNil(textClassifier)
}
catch {
assertEqualErrorDescriptions(
error,
expectedLocalizedDescription: expectedErrorDescription)
}
}
func assertResultsForClassify(
text: String,
using textClassifier: TextClassifier,
equals expectedCategories: [ResultCategory]) throws {
let textClassifierResult =
try XCTUnwrap(
textClassifier.classify(text: text));
assertTextClassifierResultHasOneHead(textClassifierResult);
assertEqualCategoryArrays(
categoryArray:
textClassifierResult.classificationResult.classifications[0].categories,
expectedCategoryArray: expectedCategories);
}
func testCreateTextClassifierWithInvalidMaxResultsFails() throws {
let textClassifierOptions =
try XCTUnwrap(
textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath))
textClassifierOptions.maxResults = 0
assertCreateTextClassifierThrowsError(
textClassifierOptions: textClassifierOptions,
expectedErrorDescription: """
INVALID_ARGUMENT: Invalid `max_results` option: value must be != 0.
""")
}
func testCreateTextClassifierWithCategoryAllowlistandDenylistFails() throws {
let textClassifierOptions =
try XCTUnwrap(
textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath))
textClassifierOptions.categoryAllowlist = ["positive"]
textClassifierOptions.categoryDenylist = ["positive"]
assertCreateTextClassifierThrowsError(
textClassifierOptions: textClassifierOptions,
expectedErrorDescription: """
INVALID_ARGUMENT: `category_allowlist` and `category_denylist` are \
mutually exclusive options.
""")
}
func testClassifyWithBertSucceeds() throws {
let modelPath = try XCTUnwrap(TextClassifierTests.kBertModelPath)
let textClassifier = try XCTUnwrap(TextClassifier(modelPath: modelPath))
try assertResultsForClassify(
text: TextClassifierTests.kNegativeText,
using: textClassifier,
equals: TextClassifierTests.kBertNegativeTextResults)
}
func testClassifyWithMaxResultsSucceeds() throws {
let textClassifierOptions =
try XCTUnwrap(
textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath))
textClassifierOptions.maxResults = 1
let textClassifier =
try XCTUnwrap(TextClassifier(options: textClassifierOptions))
try assertResultsForClassify(
text: TextClassifierTests.kNegativeText,
using: textClassifier,
equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases)
}
func testClassifyWithCategoryAllowlistSucceeds() throws {
let textClassifierOptions =
try XCTUnwrap(
textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath))
textClassifierOptions.categoryAllowlist = ["negative"];
let textClassifier =
try XCTUnwrap(TextClassifier(options: textClassifierOptions))
try assertResultsForClassify(
text: TextClassifierTests.kNegativeText,
using: textClassifier,
equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases)
}
func testClassifyWithCategoryDenylistSucceeds() throws {
let textClassifierOptions =
try XCTUnwrap(
textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath))
textClassifierOptions.categoryDenylist = ["positive"];
let textClassifier =
try XCTUnwrap(TextClassifier(options: textClassifierOptions))
try assertResultsForClassify(
text: TextClassifierTests.kNegativeText,
using: textClassifier,
equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases)
}
func testClassifyWithScoreThresholdSucceeds() throws {
let textClassifierOptions =
try XCTUnwrap(
textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath))
textClassifierOptions.scoreThreshold = 0.5;
let textClassifier =
try XCTUnwrap(TextClassifier(options: textClassifierOptions))
try assertResultsForClassify(
text: TextClassifierTests.kNegativeText,
using: textClassifier,
equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases)
}
}