Added swift and objective tests for iOS text classifier
This commit is contained in:
parent
9e0b85c9b5
commit
2a53d78ae4
82
mediapipe/tasks/ios/test/text/text_classifier/BUILD
Normal file
82
mediapipe/tasks/ios/test/text/text_classifier/BUILD
Normal file
|
@ -0,0 +1,82 @@
|
|||
load(
|
||||
"@build_bazel_rules_apple//apple:ios.bzl",
|
||||
"ios_unit_test",
|
||||
)
|
||||
load(
|
||||
"@org_tensorflow//tensorflow/lite:special_rules.bzl",
|
||||
"tflite_ios_lab_runner"
|
||||
)
|
||||
load(
|
||||
"@build_bazel_rules_swift//swift:swift.bzl",
|
||||
"swift_library"
|
||||
)
|
||||
load(
|
||||
"//mediapipe/tasks:ios/ios.bzl",
|
||||
"MPP_TASK_MINIMUM_OS_VERSION"
|
||||
)
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
# Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
|
||||
TFL_DEFAULT_TAGS = [
|
||||
"apple",
|
||||
]
|
||||
|
||||
# Following sanitizer tests are not supported by iOS test targets.
|
||||
TFL_DISABLED_SANITIZER_TAGS = [
|
||||
"noasan",
|
||||
"nomsan",
|
||||
"notsan",
|
||||
]
|
||||
|
||||
objc_library(
|
||||
name = "MPPTextClassifierObjcTestLibrary",
|
||||
testonly = 1,
|
||||
srcs = ["MPPTextClassifierTests.m"],
|
||||
data = [
|
||||
"//mediapipe/tasks/testdata/text:bert_text_classifier_models",
|
||||
"//mediapipe/tasks/testdata/text:text_classifier_models",
|
||||
],
|
||||
tags = [],
|
||||
deps = [
|
||||
"//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier",
|
||||
],
|
||||
|
||||
)
|
||||
|
||||
ios_unit_test(
|
||||
name = "MPPTextClassifierObjcTest",
|
||||
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
|
||||
runner = tflite_ios_lab_runner("IOS_LATEST"),
|
||||
tags =[],
|
||||
deps = [
|
||||
":MPPTextClassifierObjcTestLibrary",
|
||||
],
|
||||
)
|
||||
|
||||
swift_library(
|
||||
name = "MPPTextClassifierSwiftTestLibrary",
|
||||
testonly = 1,
|
||||
srcs = ["TextClassifierTests.swift"],
|
||||
data = [
|
||||
"//mediapipe/tasks/testdata/text:bert_text_classifier_models",
|
||||
"//mediapipe/tasks/testdata/text:text_classifier_models",
|
||||
],
|
||||
tags = TFL_DEFAULT_TAGS,
|
||||
deps = [
|
||||
"//mediapipe/tasks/ios/common:MPPCommon",
|
||||
"//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier",
|
||||
],
|
||||
)
|
||||
|
||||
ios_unit_test(
|
||||
name = "MPPTextClassifierSwiftTest",
|
||||
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
|
||||
runner = tflite_ios_lab_runner("IOS_LATEST"),
|
||||
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
|
||||
deps = [
|
||||
":MPPTextClassifierSwiftTestLibrary",
|
||||
],
|
||||
)
|
|
@ -0,0 +1,281 @@
|
|||
// Copyright 2023 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import <XCTest/XCTest.h>
|
||||
|
||||
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
|
||||
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h"
|
||||
|
||||
static NSString *const kBertTextClassifierModelName = @"bert_text_classifier";
|
||||
static NSString *const kRegexTextClassifierModelName =
|
||||
@"test_model_text_classifier_with_regex_tokenizer";
|
||||
static NSString *const kNegativeText = @"unflinchingly bleak and desperate";
|
||||
static NSString *const kPositiveText = @"it's a charming and often affecting journey";
|
||||
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
||||
|
||||
#define AssertEqualErrors(error, expectedError) \
|
||||
XCTAssertNotNil(error); \
|
||||
XCTAssertEqualObjects(error.domain, expectedError.domain); \
|
||||
XCTAssertEqual(error.code, expectedError.code); \
|
||||
XCTAssertNotEqual( \
|
||||
[error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
|
||||
NSNotFound)
|
||||
|
||||
#define AssertEqualCategoryArrays(categories, expectedCategories) \
|
||||
XCTAssertEqual(categories.count, expectedCategories.count); \
|
||||
for (int i = 0; i < categories.count; i++) { \
|
||||
XCTAssertEqual(categories[i].index, expectedCategories[i].index); \
|
||||
XCTAssertEqualWithAccuracy(categories[i].score, expectedCategories[i].score, 1e-6); \
|
||||
XCTAssertEqualObjects(categories[i].categoryName, expectedCategories[i].categoryName); \
|
||||
XCTAssertEqualObjects(categories[i].displayName, expectedCategories[i].displayName); \
|
||||
}
|
||||
|
||||
#define AssertTextClassifierResultHasOneHead(textClassifierResult) \
|
||||
XCTAssertNotNil(textClassifierResult); \
|
||||
\
|
||||
XCTAssertNotNil(textClassifierResult.classificationResult); \
|
||||
XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1); \
|
||||
XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0);
|
||||
|
||||
@interface MPPTextClassifierTests : XCTestCase
|
||||
@end
|
||||
|
||||
@implementation MPPTextClassifierTests
|
||||
|
||||
- (void)setUp {
|
||||
}
|
||||
|
||||
- (void)tearDown {
|
||||
// Put teardown code here. This method is called after the invocation of each test method in the
|
||||
// class.
|
||||
}
|
||||
|
||||
+ (NSArray<MPPCategory *> *)expectedBertResultCategoriesForNegativeText {
|
||||
return @[
|
||||
[[MPPCategory alloc] initWithIndex:0 score:0.956187f categoryName:@"negative" displayName:nil],
|
||||
[[MPPCategory alloc] initWithIndex:1 score:0.043812f categoryName:@"positive" displayName:nil]
|
||||
];
|
||||
}
|
||||
|
||||
+ (NSArray<MPPCategory *> *)expectedBertResultCategoriesForPositiveText {
|
||||
return @[
|
||||
[[MPPCategory alloc] initWithIndex:1 score:0.999945f categoryName:@"positive" displayName:nil],
|
||||
[[MPPCategory alloc] initWithIndex:0 score:0.000055f categoryName:@"negative" displayName:nil]
|
||||
];
|
||||
}
|
||||
|
||||
+ (NSArray<MPPCategory *> *)expectedRegexResultCategoriesForNegativeText {
|
||||
return @[
|
||||
[[MPPCategory alloc] initWithIndex:0 score:0.6647746f categoryName:@"Negative" displayName:nil],
|
||||
[[MPPCategory alloc] initWithIndex:1 score:0.33522537 categoryName:@"Positive" displayName:nil]
|
||||
];
|
||||
}
|
||||
|
||||
+ (NSArray<MPPCategory *> *)expectedRegexResultCategoriesForPositiveText {
|
||||
return @[
|
||||
[[MPPCategory alloc] initWithIndex:0 score:0.5120041f categoryName:@"Negative" displayName:nil],
|
||||
[[MPPCategory alloc] initWithIndex:1 score:0.48799595 categoryName:@"Positive" displayName:nil]
|
||||
];
|
||||
}
|
||||
|
||||
+ (NSArray<MPPCategory *> *)expectedBertResultCategoriesForEdgeCaseTests {
|
||||
return @[ [[MPPCategory alloc] initWithIndex:0
|
||||
score:0.956187f
|
||||
categoryName:@"negative"
|
||||
displayName:nil] ];
|
||||
}
|
||||
|
||||
- (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension {
|
||||
NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName
|
||||
ofType:extension];
|
||||
return filePath;
|
||||
}
|
||||
|
||||
- (MPPTextClassifierOptions *)textClassifierOptionsWithModelName:(NSString *)modelName {
|
||||
NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
|
||||
MPPTextClassifierOptions *textClassifierOptions = [[MPPTextClassifierOptions alloc] init];
|
||||
textClassifierOptions.baseOptions.modelAssetPath = modelPath;
|
||||
|
||||
return textClassifierOptions;
|
||||
}
|
||||
|
||||
- (MPPTextClassifier *)textClassifierFromModelFileWithName:(NSString *)modelName {
|
||||
NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
|
||||
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithModelPath:modelPath
|
||||
error:nil];
|
||||
XCTAssertNotNil(textClassifier);
|
||||
|
||||
return textClassifier;
|
||||
}
|
||||
|
||||
- (void)assertCreateTextClassifierWithOptions:(MPPTextClassifierOptions *)textClassifierOptions
|
||||
failsWithExpectedError:(NSError *)expectedError {
|
||||
NSError *error = nil;
|
||||
MPPTextClassifier *textClassifier =
|
||||
[[MPPTextClassifier alloc] initWithOptions:textClassifierOptions error:&error];
|
||||
XCTAssertNil(textClassifier);
|
||||
AssertEqualErrors(error, expectedError);
|
||||
}
|
||||
|
||||
- (void)assertResultsOfClassifyText:(NSString *)text
|
||||
usingTextClassifier:(MPPTextClassifier *)textClassifier
|
||||
equalsCategories:(NSArray<MPPCategory *> *)expectedCategories {
|
||||
MPPTextClassifierResult *negativeResult = [textClassifier classifyText:text error:nil];
|
||||
AssertTextClassifierResultHasOneHead(negativeResult);
|
||||
AssertEqualCategoryArrays(negativeResult.classificationResult.classifications[0].categories,
|
||||
expectedCategories);
|
||||
}
|
||||
|
||||
- (void)testCreateTextClassifierFailsWithMissingModelPath {
|
||||
NSString *modelPath = [self filePathWithName:@"" extension:@""];
|
||||
|
||||
NSError *error = nil;
|
||||
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithModelPath:modelPath
|
||||
error:&error];
|
||||
XCTAssertNil(textClassifier);
|
||||
|
||||
NSError *expectedError = [NSError
|
||||
errorWithDomain:kExpectedErrorDomain
|
||||
code:MPPTasksErrorCodeInvalidArgumentError
|
||||
userInfo:@{
|
||||
NSLocalizedDescriptionKey :
|
||||
@"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', "
|
||||
@"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."
|
||||
}];
|
||||
AssertEqualErrors(error, expectedError);
|
||||
}
|
||||
|
||||
- (void)testCreateTextClassifierFailsWithBothAllowListAndDenyList {
|
||||
MPPTextClassifierOptions *options =
|
||||
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
|
||||
options.categoryAllowlist = @[ @"positive" ];
|
||||
options.categoryDenylist = @[ @"negative" ];
|
||||
|
||||
[self assertCreateTextClassifierWithOptions:options
|
||||
failsWithExpectedError:
|
||||
[NSError
|
||||
errorWithDomain:kExpectedErrorDomain
|
||||
code:MPPTasksErrorCodeInvalidArgumentError
|
||||
userInfo:@{
|
||||
NSLocalizedDescriptionKey :
|
||||
@"INVALID_ARGUMENT: `category_allowlist` and "
|
||||
@"`category_denylist` are mutually exclusive options."
|
||||
}]];
|
||||
}
|
||||
|
||||
- (void)testCreateTextClassifierFailsWithInvalidMaxResults {
|
||||
MPPTextClassifierOptions *options =
|
||||
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
|
||||
options.maxResults = 0;
|
||||
|
||||
[self assertCreateTextClassifierWithOptions:options
|
||||
failsWithExpectedError:
|
||||
[NSError errorWithDomain:kExpectedErrorDomain
|
||||
code:MPPTasksErrorCodeInvalidArgumentError
|
||||
userInfo:@{
|
||||
NSLocalizedDescriptionKey :
|
||||
@"INVALID_ARGUMENT: Invalid `max_results` option: "
|
||||
@"value must be != 0."
|
||||
}]];
|
||||
}
|
||||
|
||||
- (void)testClassifyWithBertSucceeds {
|
||||
MPPTextClassifier *textClassifier =
|
||||
[self textClassifierFromModelFileWithName:kBertTextClassifierModelName];
|
||||
|
||||
[self assertResultsOfClassifyText:kNegativeText
|
||||
usingTextClassifier:textClassifier
|
||||
equalsCategories:[MPPTextClassifierTests
|
||||
expectedBertResultCategoriesForNegativeText]];
|
||||
|
||||
[self assertResultsOfClassifyText:kPositiveText
|
||||
usingTextClassifier:textClassifier
|
||||
equalsCategories:[MPPTextClassifierTests
|
||||
expectedBertResultCategoriesForPositiveText]];
|
||||
}
|
||||
|
||||
- (void)testClassifyWithRegexSucceeds {
|
||||
MPPTextClassifier *textClassifier =
|
||||
[self textClassifierFromModelFileWithName:kRegexTextClassifierModelName];
|
||||
|
||||
[self assertResultsOfClassifyText:kNegativeText
|
||||
usingTextClassifier:textClassifier
|
||||
equalsCategories:[MPPTextClassifierTests
|
||||
expectedRegexResultCategoriesForNegativeText]];
|
||||
[self assertResultsOfClassifyText:kPositiveText
|
||||
usingTextClassifier:textClassifier
|
||||
equalsCategories:[MPPTextClassifierTests
|
||||
expectedRegexResultCategoriesForPositiveText]];
|
||||
}
|
||||
|
||||
- (void)testClassifyWithMaxResultsSucceeds {
|
||||
MPPTextClassifierOptions *options =
|
||||
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
|
||||
options.maxResults = 1;
|
||||
|
||||
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil];
|
||||
XCTAssertNotNil(textClassifier);
|
||||
|
||||
[self assertResultsOfClassifyText:kNegativeText
|
||||
usingTextClassifier:textClassifier
|
||||
equalsCategories:[MPPTextClassifierTests
|
||||
expectedBertResultCategoriesForEdgeCaseTests]];
|
||||
}
|
||||
|
||||
- (void)testClassifyWithCategoryAllowListSucceeds {
|
||||
MPPTextClassifierOptions *options =
|
||||
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
|
||||
options.categoryAllowlist = @[ @"negative" ];
|
||||
|
||||
NSError *error = nil;
|
||||
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options
|
||||
error:&error];
|
||||
XCTAssertNotNil(textClassifier);
|
||||
XCTAssertNil(error);
|
||||
|
||||
[self assertResultsOfClassifyText:kNegativeText
|
||||
usingTextClassifier:textClassifier
|
||||
equalsCategories:[MPPTextClassifierTests
|
||||
expectedBertResultCategoriesForEdgeCaseTests]];
|
||||
}
|
||||
|
||||
- (void)testClassifyWithCategoryDenyListSucceeds {
|
||||
MPPTextClassifierOptions *options =
|
||||
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
|
||||
options.categoryDenylist = @[ @"positive" ];
|
||||
|
||||
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil];
|
||||
XCTAssertNotNil(textClassifier);
|
||||
|
||||
[self assertResultsOfClassifyText:kNegativeText
|
||||
usingTextClassifier:textClassifier
|
||||
equalsCategories:[MPPTextClassifierTests
|
||||
expectedBertResultCategoriesForEdgeCaseTests]];
|
||||
}
|
||||
|
||||
- (void)testClassifyWithScoreThresholdSucceeds {
|
||||
MPPTextClassifierOptions *options =
|
||||
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
|
||||
options.scoreThreshold = 0.5f;
|
||||
|
||||
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil];
|
||||
XCTAssertNotNil(textClassifier);
|
||||
|
||||
[self assertResultsOfClassifyText:kNegativeText
|
||||
usingTextClassifier:textClassifier
|
||||
equalsCategories:[MPPTextClassifierTests
|
||||
expectedBertResultCategoriesForEdgeCaseTests]];
|
||||
}
|
||||
|
||||
@end
|
|
@ -0,0 +1,237 @@
|
|||
// Copyright 2023 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import XCTest
|
||||
|
||||
import MPPCommon
|
||||
|
||||
@testable import MPPTextClassifier
|
||||
|
||||
class TextClassifierTests: XCTestCase {
|
||||
|
||||
static let bundle = Bundle(for: TextClassifierTests.self)
|
||||
|
||||
static let kBertModelPath = bundle.path(
|
||||
forResource: "bert_text_classifier",
|
||||
ofType: "tflite")
|
||||
|
||||
static let kPositiveText = "it's a charming and often affecting journey"
|
||||
|
||||
static let kNegativeText = "unflinchingly bleak and desperate"
|
||||
|
||||
static let kBertNegativeTextResults = [
|
||||
ResultCategory(
|
||||
index: 0,
|
||||
score: 0.956187,
|
||||
categoryName: "negative",
|
||||
displayName: nil),
|
||||
ResultCategory(
|
||||
index: 1,
|
||||
score: 0.043812,
|
||||
categoryName: "positive",
|
||||
displayName: nil)
|
||||
]
|
||||
|
||||
static let kBertNegativeTextResultsForEdgeTestCases = [
|
||||
ResultCategory(
|
||||
index: 0,
|
||||
score: 0.956187,
|
||||
categoryName: "negative",
|
||||
displayName: nil),
|
||||
]
|
||||
|
||||
func assertEqualErrorDescriptions(
|
||||
_ error: Error, expectedLocalizedDescription:String) {
|
||||
XCTAssertEqual(
|
||||
error.localizedDescription,
|
||||
expectedLocalizedDescription)
|
||||
}
|
||||
|
||||
func assertCategoriesAreEqual(
|
||||
category: ResultCategory,
|
||||
expectedCategory: ResultCategory) {
|
||||
XCTAssertEqual(
|
||||
category.index,
|
||||
expectedCategory.index)
|
||||
XCTAssertEqual(
|
||||
category.score,
|
||||
expectedCategory.score,
|
||||
accuracy:1e-6)
|
||||
XCTAssertEqual(
|
||||
category.categoryName,
|
||||
expectedCategory.categoryName)
|
||||
XCTAssertEqual(
|
||||
category.displayName,
|
||||
expectedCategory.displayName)
|
||||
}
|
||||
|
||||
func assertEqualCategoryArrays(
|
||||
categoryArray: [ResultCategory],
|
||||
expectedCategoryArray:[ResultCategory]) {
|
||||
|
||||
XCTAssertEqual(categoryArray.count, expectedCategoryArray.count)
|
||||
|
||||
for (category, expectedCategory) in
|
||||
zip(categoryArray, expectedCategoryArray) {
|
||||
assertCategoriesAreEqual(
|
||||
category:category,
|
||||
expectedCategory:expectedCategory)
|
||||
}
|
||||
}
|
||||
|
||||
func assertTextClassifierResultHasOneHead(
|
||||
_ textClassifierResult: TextClassifierResult) {
|
||||
XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1);
|
||||
XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0);
|
||||
}
|
||||
|
||||
func textClassifierOptionsWithModelPath(
|
||||
_ modelPath: String?) throws -> TextClassifierOptions {
|
||||
let modelPath = try XCTUnwrap(modelPath)
|
||||
|
||||
let textClassifierOptions = TextClassifierOptions();
|
||||
textClassifierOptions.baseOptions.modelAssetPath = modelPath;
|
||||
|
||||
return textClassifierOptions
|
||||
}
|
||||
|
||||
func assertCreateTextClassifierThrowsError(
|
||||
textClassifierOptions: TextClassifierOptions,
|
||||
expectedErrorDescription: String) {
|
||||
do {
|
||||
let textClassifier = try TextClassifier(options:textClassifierOptions)
|
||||
XCTAssertNil(textClassifier)
|
||||
}
|
||||
catch {
|
||||
assertEqualErrorDescriptions(
|
||||
error,
|
||||
expectedLocalizedDescription: expectedErrorDescription)
|
||||
}
|
||||
}
|
||||
|
||||
func assertResultsForClassify(
|
||||
text: String,
|
||||
using textClassifier: TextClassifier,
|
||||
equals expectedCategories: [ResultCategory]) throws {
|
||||
let textClassifierResult =
|
||||
try XCTUnwrap(
|
||||
textClassifier.classify(text: text));
|
||||
assertTextClassifierResultHasOneHead(textClassifierResult);
|
||||
assertEqualCategoryArrays(
|
||||
categoryArray:
|
||||
textClassifierResult.classificationResult.classifications[0].categories,
|
||||
expectedCategoryArray: expectedCategories);
|
||||
}
|
||||
|
||||
func testCreateTextClassifierWithInvalidMaxResultsFails() throws {
|
||||
let textClassifierOptions =
|
||||
try XCTUnwrap(
|
||||
textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath))
|
||||
textClassifierOptions.maxResults = 0
|
||||
|
||||
assertCreateTextClassifierThrowsError(
|
||||
textClassifierOptions: textClassifierOptions,
|
||||
expectedErrorDescription: """
|
||||
INVALID_ARGUMENT: Invalid `max_results` option: value must be != 0.
|
||||
""")
|
||||
}
|
||||
|
||||
func testCreateTextClassifierWithCategoryAllowlistandDenylistFails() throws {
|
||||
|
||||
let textClassifierOptions =
|
||||
try XCTUnwrap(
|
||||
textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath))
|
||||
textClassifierOptions.categoryAllowlist = ["positive"]
|
||||
textClassifierOptions.categoryDenylist = ["positive"]
|
||||
|
||||
assertCreateTextClassifierThrowsError(
|
||||
textClassifierOptions: textClassifierOptions,
|
||||
expectedErrorDescription: """
|
||||
INVALID_ARGUMENT: `category_allowlist` and `category_denylist` are \
|
||||
mutually exclusive options.
|
||||
""")
|
||||
}
|
||||
|
||||
func testClassifyWithBertSucceeds() throws {
|
||||
|
||||
let modelPath = try XCTUnwrap(TextClassifierTests.kBertModelPath)
|
||||
let textClassifier = try XCTUnwrap(TextClassifier(modelPath: modelPath))
|
||||
|
||||
try assertResultsForClassify(
|
||||
text: TextClassifierTests.kNegativeText,
|
||||
using: textClassifier,
|
||||
equals: TextClassifierTests.kBertNegativeTextResults)
|
||||
}
|
||||
|
||||
func testClassifyWithMaxResultsSucceeds() throws {
|
||||
let textClassifierOptions =
|
||||
try XCTUnwrap(
|
||||
textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath))
|
||||
textClassifierOptions.maxResults = 1
|
||||
|
||||
let textClassifier =
|
||||
try XCTUnwrap(TextClassifier(options: textClassifierOptions))
|
||||
|
||||
try assertResultsForClassify(
|
||||
text: TextClassifierTests.kNegativeText,
|
||||
using: textClassifier,
|
||||
equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases)
|
||||
}
|
||||
|
||||
func testClassifyWithCategoryAllowlistSucceeds() throws {
|
||||
let textClassifierOptions =
|
||||
try XCTUnwrap(
|
||||
textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath))
|
||||
textClassifierOptions.categoryAllowlist = ["negative"];
|
||||
|
||||
let textClassifier =
|
||||
try XCTUnwrap(TextClassifier(options: textClassifierOptions))
|
||||
|
||||
try assertResultsForClassify(
|
||||
text: TextClassifierTests.kNegativeText,
|
||||
using: textClassifier,
|
||||
equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases)
|
||||
}
|
||||
|
||||
func testClassifyWithCategoryDenylistSucceeds() throws {
|
||||
let textClassifierOptions =
|
||||
try XCTUnwrap(
|
||||
textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath))
|
||||
textClassifierOptions.categoryDenylist = ["positive"];
|
||||
|
||||
let textClassifier =
|
||||
try XCTUnwrap(TextClassifier(options: textClassifierOptions))
|
||||
|
||||
try assertResultsForClassify(
|
||||
text: TextClassifierTests.kNegativeText,
|
||||
using: textClassifier,
|
||||
equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases)
|
||||
}
|
||||
|
||||
func testClassifyWithScoreThresholdSucceeds() throws {
|
||||
let textClassifierOptions =
|
||||
try XCTUnwrap(
|
||||
textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath))
|
||||
textClassifierOptions.scoreThreshold = 0.5;
|
||||
|
||||
let textClassifier =
|
||||
try XCTUnwrap(TextClassifier(options: textClassifierOptions))
|
||||
|
||||
try assertResultsForClassify(
|
||||
text: TextClassifierTests.kNegativeText,
|
||||
using: textClassifier,
|
||||
equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases)
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user