Added swift and objective tests for iOS text classifier
This commit is contained in:
		
							parent
							
								
									9e0b85c9b5
								
							
						
					
					
						commit
						2a53d78ae4
					
				
							
								
								
									
										82
									
								
								mediapipe/tasks/ios/test/text/text_classifier/BUILD
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										82
									
								
								mediapipe/tasks/ios/test/text/text_classifier/BUILD
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,82 @@
 | 
			
		|||
load(
 | 
			
		||||
    "@build_bazel_rules_apple//apple:ios.bzl",
 | 
			
		||||
    "ios_unit_test",
 | 
			
		||||
)
 | 
			
		||||
load(
 | 
			
		||||
    "@org_tensorflow//tensorflow/lite:special_rules.bzl",
 | 
			
		||||
    "tflite_ios_lab_runner"
 | 
			
		||||
)
 | 
			
		||||
load(
 | 
			
		||||
    "@build_bazel_rules_swift//swift:swift.bzl", 
 | 
			
		||||
    "swift_library"
 | 
			
		||||
)
 | 
			
		||||
load(
 | 
			
		||||
    "//mediapipe/tasks:ios/ios.bzl", 
 | 
			
		||||
    "MPP_TASK_MINIMUM_OS_VERSION"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
package(default_visibility = ["//mediapipe/tasks:internal"])
 | 
			
		||||
 | 
			
		||||
licenses(["notice"])
 | 
			
		||||
 | 
			
		||||
# Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
 | 
			
		||||
TFL_DEFAULT_TAGS = [
 | 
			
		||||
    "apple",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
# Following sanitizer tests are not supported by iOS test targets.
 | 
			
		||||
TFL_DISABLED_SANITIZER_TAGS = [
 | 
			
		||||
    "noasan",
 | 
			
		||||
    "nomsan",
 | 
			
		||||
    "notsan",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
objc_library(
 | 
			
		||||
    name = "MPPTextClassifierObjcTestLibrary",
 | 
			
		||||
    testonly = 1,
 | 
			
		||||
    srcs = ["MPPTextClassifierTests.m"],
 | 
			
		||||
    data = [
 | 
			
		||||
        "//mediapipe/tasks/testdata/text:bert_text_classifier_models",
 | 
			
		||||
        "//mediapipe/tasks/testdata/text:text_classifier_models",
 | 
			
		||||
    ],
 | 
			
		||||
    tags = [],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier",
 | 
			
		||||
    ],
 | 
			
		||||
    
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
ios_unit_test(
 | 
			
		||||
    name = "MPPTextClassifierObjcTest",
 | 
			
		||||
    minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
 | 
			
		||||
    runner = tflite_ios_lab_runner("IOS_LATEST"),
 | 
			
		||||
    tags =[],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":MPPTextClassifierObjcTestLibrary",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
swift_library(
 | 
			
		||||
    name = "MPPTextClassifierSwiftTestLibrary",
 | 
			
		||||
    testonly = 1,
 | 
			
		||||
    srcs = ["TextClassifierTests.swift"],
 | 
			
		||||
    data = [
 | 
			
		||||
        "//mediapipe/tasks/testdata/text:bert_text_classifier_models",
 | 
			
		||||
        "//mediapipe/tasks/testdata/text:text_classifier_models",
 | 
			
		||||
    ],
 | 
			
		||||
    tags = TFL_DEFAULT_TAGS,
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/tasks/ios/common:MPPCommon",
 | 
			
		||||
        "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
ios_unit_test(
 | 
			
		||||
    name = "MPPTextClassifierSwiftTest",
 | 
			
		||||
    minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
 | 
			
		||||
    runner = tflite_ios_lab_runner("IOS_LATEST"),
 | 
			
		||||
    tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":MPPTextClassifierSwiftTestLibrary",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,281 @@
 | 
			
		|||
// Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
// you may not use this file except in compliance with the License.
 | 
			
		||||
// You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
//      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
// See the License for the specific language governing permissions and
 | 
			
		||||
// limitations under the License.
 | 
			
		||||
 | 
			
		||||
#import <XCTest/XCTest.h>
 | 
			
		||||
 | 
			
		||||
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
 | 
			
		||||
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h"
 | 
			
		||||
 | 
			
		||||
static NSString *const kBertTextClassifierModelName = @"bert_text_classifier";
 | 
			
		||||
static NSString *const kRegexTextClassifierModelName =
 | 
			
		||||
    @"test_model_text_classifier_with_regex_tokenizer";
 | 
			
		||||
static NSString *const kNegativeText = @"unflinchingly bleak and desperate";
 | 
			
		||||
static NSString *const kPositiveText = @"it's a charming and often affecting journey";
 | 
			
		||||
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
 | 
			
		||||
 | 
			
		||||
#define AssertEqualErrors(error, expectedError)                                               \
 | 
			
		||||
  XCTAssertNotNil(error);                                                                     \
 | 
			
		||||
  XCTAssertEqualObjects(error.domain, expectedError.domain);                                  \
 | 
			
		||||
  XCTAssertEqual(error.code, expectedError.code);                                             \
 | 
			
		||||
  XCTAssertNotEqual(                                                                          \
 | 
			
		||||
      [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
 | 
			
		||||
      NSNotFound)
 | 
			
		||||
 | 
			
		||||
#define AssertEqualCategoryArrays(categories, expectedCategories)                          \
 | 
			
		||||
  XCTAssertEqual(categories.count, expectedCategories.count);                              \
 | 
			
		||||
  for (int i = 0; i < categories.count; i++) {                                             \
 | 
			
		||||
    XCTAssertEqual(categories[i].index, expectedCategories[i].index);                      \
 | 
			
		||||
    XCTAssertEqualWithAccuracy(categories[i].score, expectedCategories[i].score, 1e-6);    \
 | 
			
		||||
    XCTAssertEqualObjects(categories[i].categoryName, expectedCategories[i].categoryName); \
 | 
			
		||||
    XCTAssertEqualObjects(categories[i].displayName, expectedCategories[i].displayName);   \
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
#define AssertTextClassifierResultHasOneHead(textClassifierResult)                    \
 | 
			
		||||
  XCTAssertNotNil(textClassifierResult);                                              \
 | 
			
		||||
  \      
 | 
			
		||||
  XCTAssertNotNil(textClassifierResult.classificationResult);                         \
 | 
			
		||||
  XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1); \
 | 
			
		||||
  XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0);
 | 
			
		||||
 | 
			
		||||
@interface MPPTextClassifierTests : XCTestCase
 | 
			
		||||
@end
 | 
			
		||||
 | 
			
		||||
@implementation MPPTextClassifierTests
 | 
			
		||||
 | 
			
		||||
- (void)setUp {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)tearDown {
 | 
			
		||||
  // Put teardown code here. This method is called after the invocation of each test method in the
 | 
			
		||||
  // class.
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
+ (NSArray<MPPCategory *> *)expectedBertResultCategoriesForNegativeText {
 | 
			
		||||
  return @[
 | 
			
		||||
    [[MPPCategory alloc] initWithIndex:0 score:0.956187f categoryName:@"negative" displayName:nil],
 | 
			
		||||
    [[MPPCategory alloc] initWithIndex:1 score:0.043812f categoryName:@"positive" displayName:nil]
 | 
			
		||||
  ];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
+ (NSArray<MPPCategory *> *)expectedBertResultCategoriesForPositiveText {
 | 
			
		||||
  return @[
 | 
			
		||||
    [[MPPCategory alloc] initWithIndex:1 score:0.999945f categoryName:@"positive" displayName:nil],
 | 
			
		||||
    [[MPPCategory alloc] initWithIndex:0 score:0.000055f categoryName:@"negative" displayName:nil]
 | 
			
		||||
  ];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
+ (NSArray<MPPCategory *> *)expectedRegexResultCategoriesForNegativeText {
 | 
			
		||||
  return @[
 | 
			
		||||
    [[MPPCategory alloc] initWithIndex:0 score:0.6647746f categoryName:@"Negative" displayName:nil],
 | 
			
		||||
    [[MPPCategory alloc] initWithIndex:1 score:0.33522537 categoryName:@"Positive" displayName:nil]
 | 
			
		||||
  ];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
+ (NSArray<MPPCategory *> *)expectedRegexResultCategoriesForPositiveText {
 | 
			
		||||
  return @[
 | 
			
		||||
    [[MPPCategory alloc] initWithIndex:0 score:0.5120041f categoryName:@"Negative" displayName:nil],
 | 
			
		||||
    [[MPPCategory alloc] initWithIndex:1 score:0.48799595 categoryName:@"Positive" displayName:nil]
 | 
			
		||||
  ];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
+ (NSArray<MPPCategory *> *)expectedBertResultCategoriesForEdgeCaseTests {
 | 
			
		||||
  return @[ [[MPPCategory alloc] initWithIndex:0
 | 
			
		||||
                                         score:0.956187f
 | 
			
		||||
                                  categoryName:@"negative"
 | 
			
		||||
                                   displayName:nil] ];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension {
 | 
			
		||||
  NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName
 | 
			
		||||
                                                                      ofType:extension];
 | 
			
		||||
  return filePath;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (MPPTextClassifierOptions *)textClassifierOptionsWithModelName:(NSString *)modelName {
 | 
			
		||||
  NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
 | 
			
		||||
  MPPTextClassifierOptions *textClassifierOptions = [[MPPTextClassifierOptions alloc] init];
 | 
			
		||||
  textClassifierOptions.baseOptions.modelAssetPath = modelPath;
 | 
			
		||||
 | 
			
		||||
  return textClassifierOptions;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (MPPTextClassifier *)textClassifierFromModelFileWithName:(NSString *)modelName {
 | 
			
		||||
  NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
 | 
			
		||||
  MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithModelPath:modelPath
 | 
			
		||||
                                                                             error:nil];
 | 
			
		||||
  XCTAssertNotNil(textClassifier);
 | 
			
		||||
 | 
			
		||||
  return textClassifier;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)assertCreateTextClassifierWithOptions:(MPPTextClassifierOptions *)textClassifierOptions
 | 
			
		||||
                       failsWithExpectedError:(NSError *)expectedError {
 | 
			
		||||
  NSError *error = nil;
 | 
			
		||||
  MPPTextClassifier *textClassifier =
 | 
			
		||||
      [[MPPTextClassifier alloc] initWithOptions:textClassifierOptions error:&error];
 | 
			
		||||
  XCTAssertNil(textClassifier);
 | 
			
		||||
  AssertEqualErrors(error, expectedError);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)assertResultsOfClassifyText:(NSString *)text
 | 
			
		||||
                usingTextClassifier:(MPPTextClassifier *)textClassifier
 | 
			
		||||
                   equalsCategories:(NSArray<MPPCategory *> *)expectedCategories {
 | 
			
		||||
  MPPTextClassifierResult *negativeResult = [textClassifier classifyText:text error:nil];
 | 
			
		||||
  AssertTextClassifierResultHasOneHead(negativeResult);
 | 
			
		||||
  AssertEqualCategoryArrays(negativeResult.classificationResult.classifications[0].categories,
 | 
			
		||||
                            expectedCategories);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testCreateTextClassifierFailsWithMissingModelPath {
 | 
			
		||||
  NSString *modelPath = [self filePathWithName:@"" extension:@""];
 | 
			
		||||
 | 
			
		||||
  NSError *error = nil;
 | 
			
		||||
  MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithModelPath:modelPath
 | 
			
		||||
                                                                             error:&error];
 | 
			
		||||
  XCTAssertNil(textClassifier);
 | 
			
		||||
 | 
			
		||||
  NSError *expectedError = [NSError
 | 
			
		||||
      errorWithDomain:kExpectedErrorDomain
 | 
			
		||||
                 code:MPPTasksErrorCodeInvalidArgumentError
 | 
			
		||||
             userInfo:@{
 | 
			
		||||
               NSLocalizedDescriptionKey :
 | 
			
		||||
                   @"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', "
 | 
			
		||||
                   @"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."
 | 
			
		||||
             }];
 | 
			
		||||
  AssertEqualErrors(error, expectedError);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testCreateTextClassifierFailsWithBothAllowListAndDenyList {
 | 
			
		||||
  MPPTextClassifierOptions *options =
 | 
			
		||||
      [self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
 | 
			
		||||
  options.categoryAllowlist = @[ @"positive" ];
 | 
			
		||||
  options.categoryDenylist = @[ @"negative" ];
 | 
			
		||||
 | 
			
		||||
  [self assertCreateTextClassifierWithOptions:options
 | 
			
		||||
                       failsWithExpectedError:
 | 
			
		||||
                           [NSError
 | 
			
		||||
                               errorWithDomain:kExpectedErrorDomain
 | 
			
		||||
                                          code:MPPTasksErrorCodeInvalidArgumentError
 | 
			
		||||
                                      userInfo:@{
 | 
			
		||||
                                        NSLocalizedDescriptionKey :
 | 
			
		||||
                                            @"INVALID_ARGUMENT: `category_allowlist` and "
 | 
			
		||||
                                            @"`category_denylist` are mutually exclusive options."
 | 
			
		||||
                                      }]];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testCreateTextClassifierFailsWithInvalidMaxResults {
 | 
			
		||||
  MPPTextClassifierOptions *options =
 | 
			
		||||
      [self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
 | 
			
		||||
  options.maxResults = 0;
 | 
			
		||||
 | 
			
		||||
  [self assertCreateTextClassifierWithOptions:options
 | 
			
		||||
                       failsWithExpectedError:
 | 
			
		||||
                           [NSError errorWithDomain:kExpectedErrorDomain
 | 
			
		||||
                                               code:MPPTasksErrorCodeInvalidArgumentError
 | 
			
		||||
                                           userInfo:@{
 | 
			
		||||
                                             NSLocalizedDescriptionKey :
 | 
			
		||||
                                                 @"INVALID_ARGUMENT: Invalid `max_results` option: "
 | 
			
		||||
                                                 @"value must be != 0."
 | 
			
		||||
                                           }]];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testClassifyWithBertSucceeds {
 | 
			
		||||
  MPPTextClassifier *textClassifier =
 | 
			
		||||
      [self textClassifierFromModelFileWithName:kBertTextClassifierModelName];
 | 
			
		||||
 | 
			
		||||
  [self assertResultsOfClassifyText:kNegativeText
 | 
			
		||||
                usingTextClassifier:textClassifier
 | 
			
		||||
                   equalsCategories:[MPPTextClassifierTests
 | 
			
		||||
                                        expectedBertResultCategoriesForNegativeText]];
 | 
			
		||||
 | 
			
		||||
  [self assertResultsOfClassifyText:kPositiveText
 | 
			
		||||
                usingTextClassifier:textClassifier
 | 
			
		||||
                   equalsCategories:[MPPTextClassifierTests
 | 
			
		||||
                                        expectedBertResultCategoriesForPositiveText]];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testClassifyWithRegexSucceeds {
 | 
			
		||||
  MPPTextClassifier *textClassifier =
 | 
			
		||||
      [self textClassifierFromModelFileWithName:kRegexTextClassifierModelName];
 | 
			
		||||
 | 
			
		||||
  [self assertResultsOfClassifyText:kNegativeText
 | 
			
		||||
                usingTextClassifier:textClassifier
 | 
			
		||||
                   equalsCategories:[MPPTextClassifierTests
 | 
			
		||||
                                        expectedRegexResultCategoriesForNegativeText]];
 | 
			
		||||
  [self assertResultsOfClassifyText:kPositiveText
 | 
			
		||||
                usingTextClassifier:textClassifier
 | 
			
		||||
                   equalsCategories:[MPPTextClassifierTests
 | 
			
		||||
                                        expectedRegexResultCategoriesForPositiveText]];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testClassifyWithMaxResultsSucceeds {
 | 
			
		||||
  MPPTextClassifierOptions *options =
 | 
			
		||||
      [self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
 | 
			
		||||
  options.maxResults = 1;
 | 
			
		||||
 | 
			
		||||
  MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil];
 | 
			
		||||
  XCTAssertNotNil(textClassifier);
 | 
			
		||||
 | 
			
		||||
  [self assertResultsOfClassifyText:kNegativeText
 | 
			
		||||
                usingTextClassifier:textClassifier
 | 
			
		||||
                   equalsCategories:[MPPTextClassifierTests
 | 
			
		||||
                                        expectedBertResultCategoriesForEdgeCaseTests]];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testClassifyWithCategoryAllowListSucceeds {
 | 
			
		||||
  MPPTextClassifierOptions *options =
 | 
			
		||||
      [self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
 | 
			
		||||
  options.categoryAllowlist = @[ @"negative" ];
 | 
			
		||||
 | 
			
		||||
  NSError *error = nil;
 | 
			
		||||
  MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options
 | 
			
		||||
                                                                           error:&error];
 | 
			
		||||
  XCTAssertNotNil(textClassifier);
 | 
			
		||||
  XCTAssertNil(error);
 | 
			
		||||
 | 
			
		||||
  [self assertResultsOfClassifyText:kNegativeText
 | 
			
		||||
                usingTextClassifier:textClassifier
 | 
			
		||||
                   equalsCategories:[MPPTextClassifierTests
 | 
			
		||||
                                        expectedBertResultCategoriesForEdgeCaseTests]];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testClassifyWithCategoryDenyListSucceeds {
 | 
			
		||||
  MPPTextClassifierOptions *options =
 | 
			
		||||
      [self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
 | 
			
		||||
  options.categoryDenylist = @[ @"positive" ];
 | 
			
		||||
 | 
			
		||||
  MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil];
 | 
			
		||||
  XCTAssertNotNil(textClassifier);
 | 
			
		||||
 | 
			
		||||
  [self assertResultsOfClassifyText:kNegativeText
 | 
			
		||||
                usingTextClassifier:textClassifier
 | 
			
		||||
                   equalsCategories:[MPPTextClassifierTests
 | 
			
		||||
                                        expectedBertResultCategoriesForEdgeCaseTests]];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testClassifyWithScoreThresholdSucceeds {
 | 
			
		||||
  MPPTextClassifierOptions *options =
 | 
			
		||||
      [self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
 | 
			
		||||
  options.scoreThreshold = 0.5f;
 | 
			
		||||
 | 
			
		||||
  MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil];
 | 
			
		||||
  XCTAssertNotNil(textClassifier);
 | 
			
		||||
 | 
			
		||||
  [self assertResultsOfClassifyText:kNegativeText
 | 
			
		||||
                usingTextClassifier:textClassifier
 | 
			
		||||
                   equalsCategories:[MPPTextClassifierTests
 | 
			
		||||
                                        expectedBertResultCategoriesForEdgeCaseTests]];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@end
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,237 @@
 | 
			
		|||
// Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
// you may not use this file except in compliance with the License.
 | 
			
		||||
// You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
//      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
// See the License for the specific language governing permissions and
 | 
			
		||||
// limitations under the License.
 | 
			
		||||
 | 
			
		||||
import XCTest
 | 
			
		||||
 | 
			
		||||
import MPPCommon
 | 
			
		||||
 | 
			
		||||
@testable import MPPTextClassifier
 | 
			
		||||
 | 
			
		||||
class TextClassifierTests: XCTestCase {
 | 
			
		||||
 | 
			
		||||
  static let bundle = Bundle(for: TextClassifierTests.self)
 | 
			
		||||
  
 | 
			
		||||
  static let kBertModelPath = bundle.path(
 | 
			
		||||
    forResource: "bert_text_classifier",
 | 
			
		||||
    ofType: "tflite")
 | 
			
		||||
  
 | 
			
		||||
  static let kPositiveText = "it's a charming and often affecting journey"
 | 
			
		||||
 | 
			
		||||
  static let kNegativeText = "unflinchingly bleak and desperate"
 | 
			
		||||
 | 
			
		||||
  static let kBertNegativeTextResults = [
 | 
			
		||||
      ResultCategory(
 | 
			
		||||
        index: 0, 
 | 
			
		||||
        score: 0.956187, 
 | 
			
		||||
        categoryName: "negative", 
 | 
			
		||||
        displayName: nil),
 | 
			
		||||
      ResultCategory(
 | 
			
		||||
        index: 1, 
 | 
			
		||||
        score: 0.043812, 
 | 
			
		||||
        categoryName: "positive", 
 | 
			
		||||
        displayName: nil)
 | 
			
		||||
      ]
 | 
			
		||||
 | 
			
		||||
  static let kBertNegativeTextResultsForEdgeTestCases = [
 | 
			
		||||
      ResultCategory(
 | 
			
		||||
        index: 0, 
 | 
			
		||||
        score: 0.956187, 
 | 
			
		||||
        categoryName: "negative", 
 | 
			
		||||
        displayName: nil),
 | 
			
		||||
      ]
 | 
			
		||||
 | 
			
		||||
  func assertEqualErrorDescriptions(
 | 
			
		||||
    _ error: Error, expectedLocalizedDescription:String) {
 | 
			
		||||
   XCTAssertEqual(
 | 
			
		||||
      error.localizedDescription,
 | 
			
		||||
      expectedLocalizedDescription)
 | 
			
		||||
  }
 | 
			
		||||
  
 | 
			
		||||
  func assertCategoriesAreEqual(
 | 
			
		||||
    category: ResultCategory, 
 | 
			
		||||
    expectedCategory: ResultCategory) {
 | 
			
		||||
     XCTAssertEqual(
 | 
			
		||||
      category.index,
 | 
			
		||||
      expectedCategory.index)
 | 
			
		||||
    XCTAssertEqual(
 | 
			
		||||
      category.score,
 | 
			
		||||
      expectedCategory.score,
 | 
			
		||||
      accuracy:1e-6)
 | 
			
		||||
    XCTAssertEqual(
 | 
			
		||||
      category.categoryName,
 | 
			
		||||
      expectedCategory.categoryName)
 | 
			
		||||
    XCTAssertEqual(
 | 
			
		||||
      category.displayName,
 | 
			
		||||
      expectedCategory.displayName)
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  func assertEqualCategoryArrays(
 | 
			
		||||
    categoryArray: [ResultCategory], 
 | 
			
		||||
    expectedCategoryArray:[ResultCategory]) {
 | 
			
		||||
 | 
			
		||||
    XCTAssertEqual(categoryArray.count, expectedCategoryArray.count)
 | 
			
		||||
 | 
			
		||||
    for (category, expectedCategory) in 
 | 
			
		||||
      zip(categoryArray, expectedCategoryArray)  {
 | 
			
		||||
      assertCategoriesAreEqual(
 | 
			
		||||
        category:category, 
 | 
			
		||||
        expectedCategory:expectedCategory)
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  
 | 
			
		||||
  func assertTextClassifierResultHasOneHead(
 | 
			
		||||
    _ textClassifierResult: TextClassifierResult) {
 | 
			
		||||
    XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1);
 | 
			
		||||
    XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  func textClassifierOptionsWithModelPath(
 | 
			
		||||
    _ modelPath: String?) throws -> TextClassifierOptions {
 | 
			
		||||
    let modelPath = try XCTUnwrap(modelPath)
 | 
			
		||||
 | 
			
		||||
    let textClassifierOptions = TextClassifierOptions();
 | 
			
		||||
    textClassifierOptions.baseOptions.modelAssetPath = modelPath;
 | 
			
		||||
 | 
			
		||||
    return textClassifierOptions
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  func assertCreateTextClassifierThrowsError(
 | 
			
		||||
    textClassifierOptions: TextClassifierOptions,
 | 
			
		||||
    expectedErrorDescription: String) {
 | 
			
		||||
    do {
 | 
			
		||||
      let textClassifier = try TextClassifier(options:textClassifierOptions)
 | 
			
		||||
      XCTAssertNil(textClassifier)
 | 
			
		||||
    }
 | 
			
		||||
    catch {
 | 
			
		||||
      assertEqualErrorDescriptions(
 | 
			
		||||
        error, 
 | 
			
		||||
        expectedLocalizedDescription: expectedErrorDescription)
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  func assertResultsForClassify(
 | 
			
		||||
    text: String, 
 | 
			
		||||
    using textClassifier: TextClassifier,
 | 
			
		||||
    equals expectedCategories: [ResultCategory]) throws {
 | 
			
		||||
    let textClassifierResult = 
 | 
			
		||||
      try XCTUnwrap(
 | 
			
		||||
        textClassifier.classify(text: text));
 | 
			
		||||
    assertTextClassifierResultHasOneHead(textClassifierResult);
 | 
			
		||||
    assertEqualCategoryArrays(
 | 
			
		||||
      categoryArray:
 | 
			
		||||
        textClassifierResult.classificationResult.classifications[0].categories,
 | 
			
		||||
      expectedCategoryArray: expectedCategories);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  func testCreateTextClassifierWithInvalidMaxResultsFails() throws {
 | 
			
		||||
    let textClassifierOptions = 
 | 
			
		||||
      try XCTUnwrap(
 | 
			
		||||
        textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath))
 | 
			
		||||
    textClassifierOptions.maxResults = 0
 | 
			
		||||
 | 
			
		||||
    assertCreateTextClassifierThrowsError(
 | 
			
		||||
      textClassifierOptions: textClassifierOptions,
 | 
			
		||||
      expectedErrorDescription: """
 | 
			
		||||
          INVALID_ARGUMENT: Invalid `max_results` option: value must be != 0.
 | 
			
		||||
          """)
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  func testCreateTextClassifierWithCategoryAllowlistandDenylistFails() throws {
 | 
			
		||||
 | 
			
		||||
    let textClassifierOptions = 
 | 
			
		||||
      try XCTUnwrap(
 | 
			
		||||
        textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath))
 | 
			
		||||
    textClassifierOptions.categoryAllowlist = ["positive"]
 | 
			
		||||
    textClassifierOptions.categoryDenylist = ["positive"]
 | 
			
		||||
 | 
			
		||||
    assertCreateTextClassifierThrowsError(
 | 
			
		||||
      textClassifierOptions: textClassifierOptions,
 | 
			
		||||
      expectedErrorDescription: """
 | 
			
		||||
          INVALID_ARGUMENT: `category_allowlist` and `category_denylist` are \
 | 
			
		||||
          mutually exclusive options.
 | 
			
		||||
          """)
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  func testClassifyWithBertSucceeds() throws {
 | 
			
		||||
 | 
			
		||||
    let modelPath = try XCTUnwrap(TextClassifierTests.kBertModelPath)
 | 
			
		||||
    let textClassifier = try XCTUnwrap(TextClassifier(modelPath: modelPath))
 | 
			
		||||
    
 | 
			
		||||
    try assertResultsForClassify(
 | 
			
		||||
        text: TextClassifierTests.kNegativeText,
 | 
			
		||||
        using: textClassifier,
 | 
			
		||||
        equals: TextClassifierTests.kBertNegativeTextResults)
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  func testClassifyWithMaxResultsSucceeds() throws {
 | 
			
		||||
    let textClassifierOptions = 
 | 
			
		||||
      try XCTUnwrap(
 | 
			
		||||
        textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath))
 | 
			
		||||
    textClassifierOptions.maxResults = 1
 | 
			
		||||
 | 
			
		||||
    let textClassifier = 
 | 
			
		||||
      try XCTUnwrap(TextClassifier(options: textClassifierOptions))
 | 
			
		||||
 | 
			
		||||
    try assertResultsForClassify(
 | 
			
		||||
        text: TextClassifierTests.kNegativeText,
 | 
			
		||||
        using: textClassifier,
 | 
			
		||||
        equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases)
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  func testClassifyWithCategoryAllowlistSucceeds() throws {
 | 
			
		||||
    let textClassifierOptions = 
 | 
			
		||||
      try XCTUnwrap(
 | 
			
		||||
        textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath))
 | 
			
		||||
    textClassifierOptions.categoryAllowlist = ["negative"];
 | 
			
		||||
 | 
			
		||||
    let textClassifier = 
 | 
			
		||||
      try XCTUnwrap(TextClassifier(options: textClassifierOptions))
 | 
			
		||||
    
 | 
			
		||||
    try assertResultsForClassify(
 | 
			
		||||
        text: TextClassifierTests.kNegativeText,
 | 
			
		||||
        using: textClassifier,
 | 
			
		||||
        equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases)
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  func testClassifyWithCategoryDenylistSucceeds() throws {
 | 
			
		||||
    let textClassifierOptions = 
 | 
			
		||||
      try XCTUnwrap(
 | 
			
		||||
        textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath))
 | 
			
		||||
    textClassifierOptions.categoryDenylist = ["positive"];
 | 
			
		||||
 | 
			
		||||
    let textClassifier = 
 | 
			
		||||
      try XCTUnwrap(TextClassifier(options: textClassifierOptions))
 | 
			
		||||
    
 | 
			
		||||
    try assertResultsForClassify(
 | 
			
		||||
        text: TextClassifierTests.kNegativeText,
 | 
			
		||||
        using: textClassifier,
 | 
			
		||||
        equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases)
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  func testClassifyWithScoreThresholdSucceeds() throws {
 | 
			
		||||
    let textClassifierOptions = 
 | 
			
		||||
      try XCTUnwrap(
 | 
			
		||||
        textClassifierOptionsWithModelPath(TextClassifierTests.kBertModelPath))
 | 
			
		||||
    textClassifierOptions.scoreThreshold = 0.5;
 | 
			
		||||
 | 
			
		||||
    let textClassifier = 
 | 
			
		||||
      try XCTUnwrap(TextClassifier(options: textClassifierOptions))
 | 
			
		||||
    
 | 
			
		||||
    try assertResultsForClassify(
 | 
			
		||||
        text: TextClassifierTests.kNegativeText,
 | 
			
		||||
        using: textClassifier,
 | 
			
		||||
        equals: TextClassifierTests.kBertNegativeTextResultsForEdgeTestCases)
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user