Merge pull request #3995 from priankakariatyml:ios-text-classifier-tests

PiperOrigin-RevId: 503242486
This commit is contained in:
Copybara-Service 2023-01-19 12:59:49 -08:00
commit 4b9a52dc34
12 changed files with 764 additions and 220 deletions

View File

@ -18,160 +18,92 @@ NS_ASSUME_NONNULL_BEGIN
/** /**
* @enum MPPTasksErrorCode * @enum MPPTasksErrorCode
* This enum specifies error codes for MediaPipe Task Library. * This enum specifies error codes for errors thrown by iOS MediaPipe Task Library.
* It maintains a 1:1 mapping to MediaPipeTasksStatus of the C ++libray.
*/ */
typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) { typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) {
// Generic error codes. // Generic error codes.
// Unspecified error. /** Indicates the operation was cancelled, typically by the caller. */
MPPTasksErrorCodeError = 1, MPPTasksErrorCodeCancelledError = 1,
// Invalid argument specified.
MPPTasksErrorCodeInvalidArgumentError = 2,
// Invalid FlatBuffer file or buffer specified.
MPPTasksErrorCodeInvalidFlatBufferError = 3,
// Model contains a builtin op that isn't supported by the OpResolver or
// delegates.
MPPTasksErrorCodeUnsupportedBuiltinOp = 4,
// Model contains a custom op that isn't supported by the OpResolver or
// delegates.
MPPTasksErrorCodeUnsupportedCustomOp = 5,
// File I/O error codes. /** Indicates an unknown error occurred. */
MPPTasksErrorCodeUnknownError = 2,
// No such file. /** Indicates the caller specified an invalid argument, such as a malformed filename. */
MPPTasksErrorCodeFileNotFoundError = 100, MPPTasksErrorCodeInvalidArgumentError = 3,
// Permission issue.
MPPTasksErrorCodeFilePermissionDeniedError,
// I/O error when reading file.
MPPTasksErrorCodeFileReadError,
// I/O error when mmap-ing file.
MPPTasksErrorCodeFileMmapError,
// ZIP I/O error when unpacking the zip file.
MPPTasksErrorCodeFileZipError,
// TensorFlow Lite metadata error codes. /** Indicates a deadline expired before the operation could complete. */
MPPTasksErrorCodeDeadlineExceededError = 4,
// Unexpected schema version (aka file_identifier) in the Metadata FlatBuffer. /** Indicates some requested entity (such as a file or directory) was not found. */
MPPTasksErrorCodeMetadataInvalidSchemaVersionError = 200, MPPTasksErrorCodeNotFoundError = 5,
// No such associated file within metadata, or file has not been packed.
MPPTasksErrorCodeMetadataAssociatedFileNotFoundError,
// ZIP I/O error when unpacking an associated file.
MPPTasksErrorCodeMetadataAssociatedFileZipError,
// Inconsistency error between the metadata and actual TF Lite model.
// E.g.: number of labels and output tensor values differ.
MPPTasksErrorCodeMetadataInconsistencyError,
// Invalid process units specified.
// E.g.: multiple ProcessUnits with the same type for a given tensor.
MPPTasksErrorCodeMetadataInvalidProcessUnitsError,
// Inconsistency error with the number of labels.
// E.g.: label files for different locales have a different number of labels.
MPPTasksErrorCodeMetadataNumLabelsMismatchError,
// Score calibration parameters parsing error.
// E.g.: too many parameters provided in the corresponding associated file.
MPPTasksErrorCodeMetadataMalformedScoreCalibrationError,
// Unexpected number of subgraphs for the current task.
// E.g.: image classification expects a single subgraph.
MPPTasksErrorCodeMetadataInvalidNumSubgraphsError,
// A given tensor requires NormalizationOptions but none were found.
// E.g.: float input tensor requires normalization to preprocess input images.
MPPTasksErrorCodeMetadataMissingNormalizationOptionsError,
// Invalid ContentProperties specified.
// E.g. expected ImageProperties, got BoundingBoxProperties.
MPPTasksErrorCodeMetadataInvalidContentPropertiesError,
// Metadata is mandatory but was not found.
// E.g. current task requires TFLite Model Metadata but none was found.
MPPTasksErrorCodeMetadataNotFoundError,
// Associated TENSOR_AXIS_LABELS or TENSOR_VALUE_LABELS file is mandatory but
// none was found or it was empty.
// E.g. current task requires labels but none were found.
MPPTasksErrorCodeMetadataMissingLabelsError,
// The ProcessingUnit for tokenizer is not correctly configured.
// E.g BertTokenizer doesn't have a valid vocab file associated.
MPPTasksErrorCodeMetadataInvalidTokenizerError,
// Input tensor(s) error codes. /**
* Indicates that the entity a caller attempted to create (such as a file or directory) is
* already present.
*/
MPPTasksErrorCodeAlreadyExistsError = 6,
// Unexpected number of input tensors for the current task. /** Indicates that the caller does not have permission to execute the specified operation. */
// E.g. current task expects a single input tensor. MPPTasksErrorCodePermissionDeniedError = 7,
MPPTasksErrorCodeInvalidNumInputTensorsError = 300,
// Unexpected input tensor dimensions for the current task.
// E.g.: only 4D input tensors supported.
MPPTasksErrorCodeInvalidInputTensorDimensionsError,
// Unexpected input tensor type for the current task.
// E.g.: current task expects a uint8 pixel image as input.
MPPTasksErrorCodeInvalidInputTensorTypeError,
// Unexpected input tensor bytes size.
// E.g.: size in bytes does not correspond to the expected number of pixels.
MPPTasksErrorCodeInvalidInputTensorSizeError,
// No correct input tensor found for the model.
// E.g.: input tensor name is not part of the text model's input tensors.
MPPTasksErrorCodeInputTensorNotFoundError,
// Output tensor(s) error codes. /**
* Indicates some resource has been exhausted, perhaps a per-user quota, or perhaps the entire
* file system is out of space.
*/
MPPTasksErrorCodeResourceExhaustedError = 8,
// Unexpected output tensor dimensions for the current task. /**
// E.g.: only a batch size of 1 is supported. * Indicates that the operation was rejected because the system is not in a state required for
MPPTasksErrorCodeInvalidOutputTensorDimensionsError = 400, * the operation's execution. For example, a directory to be deleted may be non-empty, an "rmdir"
// Unexpected input tensor type for the current task. * operation is applied to a non-directory, etc.
// E.g.: multi-head model with different output tensor types. */
MPPTasksErrorCodeInvalidOutputTensorTypeError, MPPTasksErrorCodeFailedPreconditionError = 9,
// No correct output tensor found for the model.
// E.g.: output tensor name is not part of the text model's output tensors.
MPPTasksErrorCodeOutputTensorNotFoundError,
// Unexpected number of output tensors for the current task.
// E.g.: current task expects a single output tensor.
MPPTasksErrorCodeInvalidNumOutputTensorsError,
// Image processing error codes. /**
* Indicates the operation was aborted, typically due to a concurrency issue such as a sequencer
* check failure or a failed transaction.
*/
MPPTasksErrorCodeAbortedError = 10,
// Unspecified image processing failures. /**
MPPTasksErrorCodeImageProcessingError = 500, * Indicates the operation was attempted past the valid range, such as seeking or reading past an
// Unexpected input or output buffer metadata. * end-of-file.
// E.g.: rotate RGBA buffer to Grayscale buffer by 90 degrees. */
MPPTasksErrorCodeImageProcessingInvalidArgumentError, MPPTasksErrorCodeOutOfRangeError = 11,
// Image processing operation failures.
// E.g. libyuv rotation failed for an unknown reason.
MPPTasksErrorCodeImageProcessingBackendError,
// Task runner error codes. /**
MPPTasksErrorCodeRunnerError = 600, * Indicates the operation is not implemented or supported in this service. In this case, the
// Task runner is not initialized. * operation should not be re-attempted.
MPPTasksErrorCodeRunnerInitializationError, */
// Task runner is not started successfully. MPPTasksErrorCodeUnimplementedError = 12,
MPPTasksErrorCodeRunnerFailsToStartError,
// Task runner is not started.
MPPTasksErrorCodeRunnerNotStartedError,
// Task runner API is called in the wrong processing mode.
MPPTasksErrorCodeRunnerApiCalledInWrongModeError,
// Task runner receives/produces invalid MediaPipe packet timestamp.
MPPTasksErrorCodeRunnerInvalidTimestampError,
// Task runner receives unexpected MediaPipe graph input packet.
// E.g. The packet type doesn't match the graph input stream's data type.
MPPTasksErrorCodeRunnerUnexpectedInputError,
// Task runner produces unexpected MediaPipe graph output packet.
// E.g. The number of output packets is not equal to the number of graph
// output streams.
MPPTasksErrorCodeRunnerUnexpectedOutputError,
// Task runner is not closed successfully.
MPPTasksErrorCodeRunnerFailsToCloseError,
// Task runner's model resources cache service is unavailable or the
// targeting model resources bundle is not found.
MPPTasksErrorCodeRunnerModelResourcesCacheServiceError,
// Task graph error codes. /**
MPPTasksErrorCodeGraphError = 700, * Indicates an internal error has occurred and some invariants expected by the underlying system
// Task graph is not implemented. * have not been satisfied. This error code is reserved for serious errors.
MPPTasksErrorCodeTaskGraphNotImplementedError, */
// Task graph config is invalid. MPPTasksErrorCodeInternalError = 13,
MPPTasksErrorCodeInvalidTaskGraphConfigError,
// The first error code in MPPTasksErrorCode (for internal use only). /**
MPPTasksErrorCodeFirst = MPPTasksErrorCodeError, * Indicates the service is currently unavailable and that this is most likely a transient
* condition.
*/
MPPTasksErrorCodeUnavailableError = 14,
// The last error code in MPPTasksErrorCode (for internal use only). /** Indicates that unrecoverable data loss or corruption has occurred. */
MPPTasksErrorCodeLast = MPPTasksErrorCodeInvalidTaskGraphConfigError, MPPTasksErrorCodeDataLossError = 15,
/**
* Indicates that the request does not have valid authentication credentials for the operation.
*/
MPPTasksErrorCodeUnauthenticatedError = 16,
/** The first error code in MPPTasksErrorCode (for internal use only). */
MPPTasksErrorCodeFirst = MPPTasksErrorCodeCancelledError,
/** The last error code in MPPTasksErrorCode (for internal use only). */
MPPTasksErrorCodeLast = MPPTasksErrorCodeUnauthenticatedError,
} NS_SWIFT_NAME(TasksErrorCode); } NS_SWIFT_NAME(TasksErrorCode);

View File

@ -25,6 +25,10 @@
/** Error domain of MediaPipe task library errors. */ /** Error domain of MediaPipe task library errors. */
NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks"; NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks";
namespace {
using absl::StatusCode;
}
@implementation MPPCommonUtils @implementation MPPCommonUtils
+ (void)createCustomError:(NSError **)error + (void)createCustomError:(NSError **)error
@ -67,68 +71,69 @@ NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks";
if (status.ok()) { if (status.ok()) {
return YES; return YES;
} }
// Payload of absl::Status created by the MediaPipe task library stores an appropriate value of
// the enum MediaPipeTasksStatus. The integer value corresponding to the MediaPipeTasksStatus enum
// stored in the payload is extracted here to later map to the appropriate error code to be
// returned. In cases where the enum is not stored in (payload is NULL or the payload string
// cannot be converted to an integer), we set the error code value to be 1
// (MPPTasksErrorCodeError of MPPTasksErrorCode used in the iOS library to signify
// any errors not falling into other categories.) Since payload is of type absl::Cord that can be
// type cast into an absl::optional<std::string>, we use the std::stoi function to convert it into
// an integer code if possible.
NSUInteger genericErrorCode = MPPTasksErrorCodeError;
NSUInteger errorCode;
try {
// Try converting payload to integer if payload is not empty. Otherwise convert a string
// signifying generic error code MPPTasksErrorCodeError to integer.
errorCode =
(NSUInteger)std::stoi(static_cast<absl::optional<std::string>>(
status.GetPayload(mediapipe::tasks::kMediaPipeTasksPayload))
.value_or(std::to_string(genericErrorCode)));
} catch (std::invalid_argument &e) {
// If non empty payload string cannot be converted to an integer. Set error code to 1(kError).
errorCode = MPPTasksErrorCodeError;
}
// If errorCode is outside the range of enum values possible or is // Converts the absl status message to an NSString.
// MPPTasksErrorCodeError, we try to map the absl::Status::code() to assign
// appropriate MPPTasksErrorCode in default cases. Note:
// The mapping to absl::Status::code() is done to generate a more specific error code than
// MPPTasksErrorCodeError in cases when the payload can't be mapped to
// MPPTasksErrorCode. This can happen when absl::Status returned by TFLite library are in turn
// returned without modification by MediaPipe cc library methods.
if (errorCode > MPPTasksErrorCodeLast || errorCode <= MPPTasksErrorCodeFirst) {
switch (status.code()) {
case absl::StatusCode::kInternal:
errorCode = MPPTasksErrorCodeError;
break;
case absl::StatusCode::kInvalidArgument:
errorCode = MPPTasksErrorCodeInvalidArgumentError;
break;
case absl::StatusCode::kNotFound:
errorCode = MPPTasksErrorCodeError;
break;
default:
errorCode = MPPTasksErrorCodeError;
break;
}
}
// Creates the NSEror with the appropriate error
// MPPTasksErrorCode and message. MPPTasksErrorCode has a one to one
// mapping with MediaPipeTasksStatus starting from the value 1(MPPTasksErrorCodeError)
// and hence will be correctly initialized if directly cast from the integer code derived from
// MediaPipeTasksStatus stored in its payload. MPPTasksErrorCode omits kOk = 0 of
// MediaPipeTasksStatusx.
//
// Stores a string including absl status code and message(if non empty) as the
// error message See
// https://github.com/abseil/abseil-cpp/blob/master/absl/status/status.h#L514
// for explanation. absl::Status::message() can also be used but not always
// guaranteed to be non empty.
NSString *description = [NSString NSString *description = [NSString
stringWithCString:status.ToString(absl::StatusToStringMode::kWithNoExtraData).c_str() stringWithCString:status.ToString(absl::StatusToStringMode::kWithNoExtraData).c_str()
encoding:NSUTF8StringEncoding]; encoding:NSUTF8StringEncoding];
MPPTasksErrorCode errorCode = MPPTasksErrorCodeUnknownError;
// Maps the absl::StatusCode to the appropriate MPPTasksErrorCode. Note: MPPTasksErrorCode omits
// absl::StatusCode::kOk.
switch (status.code()) {
case StatusCode::kCancelled:
errorCode = MPPTasksErrorCodeCancelledError;
break;
case StatusCode::kUnknown:
errorCode = MPPTasksErrorCodeUnknownError;
break;
case StatusCode::kInvalidArgument:
errorCode = MPPTasksErrorCodeInvalidArgumentError;
break;
case StatusCode::kDeadlineExceeded:
errorCode = MPPTasksErrorCodeDeadlineExceededError;
break;
case StatusCode::kNotFound:
errorCode = MPPTasksErrorCodeNotFoundError;
break;
case StatusCode::kAlreadyExists:
errorCode = MPPTasksErrorCodeAlreadyExistsError;
break;
case StatusCode::kPermissionDenied:
errorCode = MPPTasksErrorCodePermissionDeniedError;
break;
case StatusCode::kResourceExhausted:
errorCode = MPPTasksErrorCodeResourceExhaustedError;
break;
case StatusCode::kFailedPrecondition:
errorCode = MPPTasksErrorCodeFailedPreconditionError;
break;
case StatusCode::kAborted:
errorCode = MPPTasksErrorCodeAbortedError;
break;
case StatusCode::kOutOfRange:
errorCode = MPPTasksErrorCodeOutOfRangeError;
break;
case StatusCode::kUnimplemented:
errorCode = MPPTasksErrorCodeUnimplementedError;
break;
case StatusCode::kInternal:
errorCode = MPPTasksErrorCodeInternalError;
break;
case StatusCode::kUnavailable:
errorCode = MPPTasksErrorCodeUnavailableError;
break;
case StatusCode::kDataLoss:
errorCode = MPPTasksErrorCodeDataLossError;
break;
case StatusCode::kUnauthenticated:
errorCode = MPPTasksErrorCodeUnauthenticatedError;
break;
default:
break;
}
[MPPCommonUtils createCustomError:error withCode:errorCode description:description]; [MPPCommonUtils createCustomError:error withCode:errorCode description:description];
return NO; return NO;
} }

View File

@ -20,15 +20,11 @@ objc_library(
name = "MPPCategory", name = "MPPCategory",
srcs = ["sources/MPPCategory.m"], srcs = ["sources/MPPCategory.m"],
hdrs = ["sources/MPPCategory.h"], hdrs = ["sources/MPPCategory.h"],
deps = ["//third_party/apple_frameworks:Foundation"],
) )
objc_library( objc_library(
name = "MPPClassificationResult", name = "MPPClassificationResult",
srcs = ["sources/MPPClassificationResult.m"], srcs = ["sources/MPPClassificationResult.m"],
hdrs = ["sources/MPPClassificationResult.h"], hdrs = ["sources/MPPClassificationResult.h"],
deps = [ deps = [":MPPCategory"],
":MPPCategory",
"//third_party/apple_frameworks:Foundation",
],
) )

View File

@ -21,7 +21,7 @@ NS_ASSUME_NONNULL_BEGIN
* index of the label in the corresponding label file. Typically it's used as the result of * index of the label in the corresponding label file. Typically it's used as the result of
* classification tasks. * classification tasks.
*/ */
NS_SWIFT_NAME(ClassificationCategory) NS_SWIFT_NAME(ResultCategory)
@interface MPPCategory : NSObject @interface MPPCategory : NSObject
/** /**

View File

@ -26,9 +26,7 @@ objc_library(
name = "MPPTaskOptions", name = "MPPTaskOptions",
srcs = ["sources/MPPTaskOptions.m"], srcs = ["sources/MPPTaskOptions.m"],
hdrs = ["sources/MPPTaskOptions.h"], hdrs = ["sources/MPPTaskOptions.h"],
deps = [ deps = [":MPPBaseOptions"],
":MPPBaseOptions",
],
) )
objc_library( objc_library(
@ -40,9 +38,7 @@ objc_library(
objc_library( objc_library(
name = "MPPTaskOptionsProtocol", name = "MPPTaskOptionsProtocol",
hdrs = ["sources/MPPTaskOptionsProtocol.h"], hdrs = ["sources/MPPTaskOptionsProtocol.h"],
deps = [ deps = ["//mediapipe/framework:calculator_options_cc_proto"],
"//mediapipe/framework:calculator_options_cc_proto",
],
) )
objc_library( objc_library(
@ -92,6 +88,5 @@ objc_library(
"//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver", "//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver",
"//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
"//third_party/apple_frameworks:Foundation",
], ],
) )

View File

@ -0,0 +1,3 @@
"""MediaPipe Task Library Helper Rules for iOS"""
MPP_TASK_MINIMUM_OS_VERSION = "11.0"

View File

@ -0,0 +1,80 @@
load(
"@build_bazel_rules_apple//apple:ios.bzl",
"ios_unit_test",
)
load(
"@build_bazel_rules_swift//swift:swift.bzl",
"swift_library",
)
load(
"//mediapipe/tasks:ios/ios.bzl",
"MPP_TASK_MINIMUM_OS_VERSION",
)
load(
"@org_tensorflow//tensorflow/lite:special_rules.bzl",
"tflite_ios_lab_runner",
)
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
TFL_DEFAULT_TAGS = [
"apple",
]
# Following sanitizer tests are not supported by iOS test targets.
TFL_DISABLED_SANITIZER_TAGS = [
"noasan",
"nomsan",
"notsan",
]
objc_library(
name = "MPPTextClassifierObjcTestLibrary",
testonly = 1,
srcs = ["MPPTextClassifierTests.m"],
data = [
"//mediapipe/tasks/testdata/text:bert_text_classifier_models",
"//mediapipe/tasks/testdata/text:text_classifier_models",
],
deps = [
"//mediapipe/tasks/ios/common:MPPCommon",
"//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier",
],
)
ios_unit_test(
name = "MPPTextClassifierObjcTest",
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
runner = tflite_ios_lab_runner("IOS_LATEST"),
deps = [
":MPPTextClassifierObjcTestLibrary",
],
)
swift_library(
name = "MPPTextClassifierSwiftTestLibrary",
testonly = 1,
srcs = ["TextClassifierTests.swift"],
data = [
"//mediapipe/tasks/testdata/text:bert_text_classifier_models",
"//mediapipe/tasks/testdata/text:text_classifier_models",
],
tags = TFL_DEFAULT_TAGS,
deps = [
"//mediapipe/tasks/ios/common:MPPCommon",
"//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier",
],
)
ios_unit_test(
name = "MPPTextClassifierSwiftTest",
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
deps = [
":MPPTextClassifierSwiftTestLibrary",
],
)

View File

@ -0,0 +1,275 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <XCTest/XCTest.h>
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h"
static NSString *const kBertTextClassifierModelName = @"bert_text_classifier";
static NSString *const kRegexTextClassifierModelName =
@"test_model_text_classifier_with_regex_tokenizer";
static NSString *const kNegativeText = @"unflinchingly bleak and desperate";
static NSString *const kPositiveText = @"it's a charming and often affecting journey";
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
#define AssertEqualErrors(error, expectedError) \
XCTAssertNotNil(error); \
XCTAssertEqualObjects(error.domain, expectedError.domain); \
XCTAssertEqual(error.code, expectedError.code); \
XCTAssertNotEqual( \
[error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
NSNotFound)
#define AssertEqualCategoryArrays(categories, expectedCategories) \
XCTAssertEqual(categories.count, expectedCategories.count); \
for (int i = 0; i < categories.count; i++) { \
XCTAssertEqual(categories[i].index, expectedCategories[i].index, @"index i = %d", i); \
XCTAssertEqualWithAccuracy(categories[i].score, expectedCategories[i].score, 1e-3, \
@"index i = %d", i); \
XCTAssertEqualObjects(categories[i].categoryName, expectedCategories[i].categoryName, \
@"index i = %d", i); \
XCTAssertEqualObjects(categories[i].displayName, expectedCategories[i].displayName, \
@"index i = %d", i); \
}
#define AssertTextClassifierResultHasOneHead(textClassifierResult) \
XCTAssertNotNil(textClassifierResult); \
XCTAssertNotNil(textClassifierResult.classificationResult); \
XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1); \
XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0);
@interface MPPTextClassifierTests : XCTestCase
@end
@implementation MPPTextClassifierTests
+ (NSArray<MPPCategory *> *)expectedBertResultCategoriesForNegativeText {
return @[
[[MPPCategory alloc] initWithIndex:0 score:0.956187f categoryName:@"negative" displayName:nil],
[[MPPCategory alloc] initWithIndex:1 score:0.043812f categoryName:@"positive" displayName:nil]
];
}
+ (NSArray<MPPCategory *> *)expectedBertResultCategoriesForPositiveText {
return @[
[[MPPCategory alloc] initWithIndex:1 score:0.999945f categoryName:@"positive" displayName:nil],
[[MPPCategory alloc] initWithIndex:0 score:0.000055f categoryName:@"negative" displayName:nil]
];
}
+ (NSArray<MPPCategory *> *)expectedRegexResultCategoriesForNegativeText {
return @[
[[MPPCategory alloc] initWithIndex:0 score:0.6647746f categoryName:@"Negative" displayName:nil],
[[MPPCategory alloc] initWithIndex:1 score:0.33522537 categoryName:@"Positive" displayName:nil]
];
}
+ (NSArray<MPPCategory *> *)expectedRegexResultCategoriesForPositiveText {
return @[
[[MPPCategory alloc] initWithIndex:0 score:0.5120041f categoryName:@"Negative" displayName:nil],
[[MPPCategory alloc] initWithIndex:1 score:0.48799595 categoryName:@"Positive" displayName:nil]
];
}
+ (NSArray<MPPCategory *> *)expectedBertResultCategoriesForEdgeCaseTests {
return @[ [[MPPCategory alloc] initWithIndex:0
score:0.956187f
categoryName:@"negative"
displayName:nil] ];
}
- (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension {
NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName
ofType:extension];
return filePath;
}
- (MPPTextClassifierOptions *)textClassifierOptionsWithModelName:(NSString *)modelName {
NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
MPPTextClassifierOptions *textClassifierOptions = [[MPPTextClassifierOptions alloc] init];
textClassifierOptions.baseOptions.modelAssetPath = modelPath;
return textClassifierOptions;
}
- (MPPTextClassifier *)textClassifierFromModelFileWithName:(NSString *)modelName {
NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithModelPath:modelPath
error:nil];
XCTAssertNotNil(textClassifier);
return textClassifier;
}
- (void)assertCreateTextClassifierWithOptions:(MPPTextClassifierOptions *)textClassifierOptions
failsWithExpectedError:(NSError *)expectedError {
NSError *error = nil;
MPPTextClassifier *textClassifier =
[[MPPTextClassifier alloc] initWithOptions:textClassifierOptions error:&error];
XCTAssertNil(textClassifier);
AssertEqualErrors(error, expectedError);
}
- (void)assertResultsOfClassifyText:(NSString *)text
usingTextClassifier:(MPPTextClassifier *)textClassifier
equalsCategories:(NSArray<MPPCategory *> *)expectedCategories {
MPPTextClassifierResult *negativeResult = [textClassifier classifyText:text error:nil];
AssertTextClassifierResultHasOneHead(negativeResult);
AssertEqualCategoryArrays(negativeResult.classificationResult.classifications[0].categories,
expectedCategories);
}
- (void)testCreateTextClassifierFailsWithMissingModelPath {
NSString *modelPath = [self filePathWithName:@"" extension:@""];
NSError *error = nil;
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithModelPath:modelPath
error:&error];
XCTAssertNil(textClassifier);
NSError *expectedError = [NSError
errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey :
@"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', "
@"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."
}];
AssertEqualErrors(error, expectedError);
}
- (void)testCreateTextClassifierFailsWithBothAllowlistAndDenylist {
MPPTextClassifierOptions *options =
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
options.categoryAllowlist = @[ @"positive" ];
options.categoryDenylist = @[ @"negative" ];
[self assertCreateTextClassifierWithOptions:options
failsWithExpectedError:
[NSError
errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey :
@"INVALID_ARGUMENT: `category_allowlist` and "
@"`category_denylist` are mutually exclusive options."
}]];
}
- (void)testCreateTextClassifierFailsWithInvalidMaxResults {
MPPTextClassifierOptions *options =
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
options.maxResults = 0;
[self assertCreateTextClassifierWithOptions:options
failsWithExpectedError:
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey :
@"INVALID_ARGUMENT: Invalid `max_results` option: "
@"value must be != 0."
}]];
}
- (void)testClassifyWithBertSucceeds {
MPPTextClassifier *textClassifier =
[self textClassifierFromModelFileWithName:kBertTextClassifierModelName];
[self assertResultsOfClassifyText:kNegativeText
usingTextClassifier:textClassifier
equalsCategories:[MPPTextClassifierTests
expectedBertResultCategoriesForNegativeText]];
[self assertResultsOfClassifyText:kPositiveText
usingTextClassifier:textClassifier
equalsCategories:[MPPTextClassifierTests
expectedBertResultCategoriesForPositiveText]];
}
- (void)testClassifyWithRegexSucceeds {
MPPTextClassifier *textClassifier =
[self textClassifierFromModelFileWithName:kRegexTextClassifierModelName];
[self assertResultsOfClassifyText:kNegativeText
usingTextClassifier:textClassifier
equalsCategories:[MPPTextClassifierTests
expectedRegexResultCategoriesForNegativeText]];
[self assertResultsOfClassifyText:kPositiveText
usingTextClassifier:textClassifier
equalsCategories:[MPPTextClassifierTests
expectedRegexResultCategoriesForPositiveText]];
}
- (void)testClassifyWithMaxResultsSucceeds {
MPPTextClassifierOptions *options =
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
options.maxResults = 1;
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil];
XCTAssertNotNil(textClassifier);
[self assertResultsOfClassifyText:kNegativeText
usingTextClassifier:textClassifier
equalsCategories:[MPPTextClassifierTests
expectedBertResultCategoriesForEdgeCaseTests]];
}
- (void)testClassifyWithCategoryAllowlistSucceeds {
MPPTextClassifierOptions *options =
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
options.categoryAllowlist = @[ @"negative" ];
NSError *error = nil;
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options
error:&error];
XCTAssertNotNil(textClassifier);
XCTAssertNil(error);
[self assertResultsOfClassifyText:kNegativeText
usingTextClassifier:textClassifier
equalsCategories:[MPPTextClassifierTests
expectedBertResultCategoriesForEdgeCaseTests]];
}
- (void)testClassifyWithCategoryDenylistSucceeds {
MPPTextClassifierOptions *options =
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
options.categoryDenylist = @[ @"positive" ];
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil];
XCTAssertNotNil(textClassifier);
[self assertResultsOfClassifyText:kNegativeText
usingTextClassifier:textClassifier
equalsCategories:[MPPTextClassifierTests
expectedBertResultCategoriesForEdgeCaseTests]];
}
- (void)testClassifyWithScoreThresholdSucceeds {
MPPTextClassifierOptions *options =
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
options.scoreThreshold = 0.5f;
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil];
XCTAssertNotNil(textClassifier);
[self assertResultsOfClassifyText:kNegativeText
usingTextClassifier:textClassifier
equalsCategories:[MPPTextClassifierTests
expectedBertResultCategoriesForEdgeCaseTests]];
}
@end

View File

@ -0,0 +1,264 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import MPPCommon
import XCTest
@testable import MPPTextClassifier
class TextClassifierTests: XCTestCase {
static let bundle = Bundle(for: TextClassifierTests.self)
static let bertModelPath = bundle.path(
forResource: "bert_text_classifier",
ofType: "tflite")
static let positiveText = "it's a charming and often affecting journey"
static let negativeText = "unflinchingly bleak and desperate"
static let bertNegativeTextResults = [
ResultCategory(
index: 0,
score: 0.956187,
categoryName: "negative",
displayName: nil),
ResultCategory(
index: 1,
score: 0.043812,
categoryName: "positive",
displayName: nil),
]
static let bertNegativeTextResultsForEdgeTestCases = [
ResultCategory(
index: 0,
score: 0.956187,
categoryName: "negative",
displayName: nil)
]
func assertEqualErrorDescriptions(
_ error: Error, expectedLocalizedDescription: String
) {
XCTAssertEqual(
error.localizedDescription,
expectedLocalizedDescription)
}
func assertCategoriesAreEqual(
category: ResultCategory,
expectedCategory: ResultCategory,
indexInCategoryList: Int
) {
XCTAssertEqual(
category.index,
expectedCategory.index,
String(
format: """
category[%d].index and expectedCategory[%d].index are not equal.
""", indexInCategoryList))
XCTAssertEqual(
category.score,
expectedCategory.score,
accuracy: 1e-3,
String(
format: """
category[%d].score and expectedCategory[%d].score are not equal.
""", indexInCategoryList))
XCTAssertEqual(
category.categoryName,
expectedCategory.categoryName,
String(
format: """
category[%d].categoryName and expectedCategory[%d].categoryName are \
not equal.
""", indexInCategoryList))
XCTAssertEqual(
category.displayName,
expectedCategory.displayName,
String(
format: """
category[%d].displayName and expectedCategory[%d].displayName are \
not equal.
""", indexInCategoryList))
}
func assertEqualCategoryArrays(
categoryArray: [ResultCategory],
expectedCategoryArray: [ResultCategory]
) {
XCTAssertEqual(
categoryArray.count,
expectedCategoryArray.count)
for (index, (category, expectedCategory)) in zip(categoryArray, expectedCategoryArray)
.enumerated()
{
assertCategoriesAreEqual(
category: category,
expectedCategory: expectedCategory,
indexInCategoryList: index)
}
}
func assertTextClassifierResultHasOneHead(
_ textClassifierResult: TextClassifierResult
) {
XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1)
XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0)
}
func textClassifierOptionsWithModelPath(
_ modelPath: String?
) throws -> TextClassifierOptions {
let modelPath = try XCTUnwrap(modelPath)
let textClassifierOptions = TextClassifierOptions()
textClassifierOptions.baseOptions.modelAssetPath = modelPath
return textClassifierOptions
}
func assertCreateTextClassifierThrowsError(
textClassifierOptions: TextClassifierOptions,
expectedErrorDescription: String
) {
do {
let textClassifier = try TextClassifier(options: textClassifierOptions)
XCTAssertNil(textClassifier)
} catch {
assertEqualErrorDescriptions(
error,
expectedLocalizedDescription: expectedErrorDescription)
}
}
func assertResultsForClassify(
text: String,
using textClassifier: TextClassifier,
equals expectedCategories: [ResultCategory]
) throws {
let textClassifierResult =
try XCTUnwrap(
textClassifier.classify(text: text))
assertTextClassifierResultHasOneHead(textClassifierResult)
assertEqualCategoryArrays(
categoryArray:
textClassifierResult.classificationResult.classifications[0].categories,
expectedCategoryArray: expectedCategories)
}
func testCreateTextClassifierWithInvalidMaxResultsFails() throws {
let textClassifierOptions =
try XCTUnwrap(
textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath))
textClassifierOptions.maxResults = 0
assertCreateTextClassifierThrowsError(
textClassifierOptions: textClassifierOptions,
expectedErrorDescription: """
INVALID_ARGUMENT: Invalid `max_results` option: value must be != 0.
""")
}
func testCreateTextClassifierWithCategoryAllowlistAndDenylistFails() throws {
let textClassifierOptions =
try XCTUnwrap(
textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath))
textClassifierOptions.categoryAllowlist = ["positive"]
textClassifierOptions.categoryDenylist = ["positive"]
assertCreateTextClassifierThrowsError(
textClassifierOptions: textClassifierOptions,
expectedErrorDescription: """
INVALID_ARGUMENT: `category_allowlist` and `category_denylist` are \
mutually exclusive options.
""")
}
func testClassifyWithBertSucceeds() throws {
let modelPath = try XCTUnwrap(TextClassifierTests.bertModelPath)
let textClassifier = try XCTUnwrap(TextClassifier(modelPath: modelPath))
try assertResultsForClassify(
text: TextClassifierTests.negativeText,
using: textClassifier,
equals: TextClassifierTests.bertNegativeTextResults)
}
func testClassifyWithMaxResultsSucceeds() throws {
let textClassifierOptions =
try XCTUnwrap(
textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath))
textClassifierOptions.maxResults = 1
let textClassifier =
try XCTUnwrap(TextClassifier(options: textClassifierOptions))
try assertResultsForClassify(
text: TextClassifierTests.negativeText,
using: textClassifier,
equals: TextClassifierTests.bertNegativeTextResultsForEdgeTestCases)
}
func testClassifyWithCategoryAllowlistSucceeds() throws {
let textClassifierOptions =
try XCTUnwrap(
textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath))
textClassifierOptions.categoryAllowlist = ["negative"]
let textClassifier =
try XCTUnwrap(TextClassifier(options: textClassifierOptions))
try assertResultsForClassify(
text: TextClassifierTests.negativeText,
using: textClassifier,
equals: TextClassifierTests.bertNegativeTextResultsForEdgeTestCases)
}
func testClassifyWithCategoryDenylistSucceeds() throws {
let textClassifierOptions =
try XCTUnwrap(
textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath))
textClassifierOptions.categoryDenylist = ["positive"]
let textClassifier =
try XCTUnwrap(TextClassifier(options: textClassifierOptions))
try assertResultsForClassify(
text: TextClassifierTests.negativeText,
using: textClassifier,
equals: TextClassifierTests.bertNegativeTextResultsForEdgeTestCases)
}
func testClassifyWithScoreThresholdSucceeds() throws {
let textClassifierOptions =
try XCTUnwrap(
textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath))
textClassifierOptions.scoreThreshold = 0.5
let textClassifier =
try XCTUnwrap(TextClassifier(options: textClassifierOptions))
try assertResultsForClassify(
text: TextClassifierTests.negativeText,
using: textClassifier,
equals: TextClassifierTests.bertNegativeTextResultsForEdgeTestCases)
}
}

View File

@ -24,8 +24,5 @@ objc_library(
"-ObjC++", "-ObjC++",
"-std=c++17", "-std=c++17",
], ],
deps = [ deps = ["//mediapipe/tasks/ios/core:MPPTaskRunner"],
"//mediapipe/tasks/ios/core:MPPTaskRunner",
"//third_party/apple_frameworks:Foundation",
],
) )

View File

@ -20,10 +20,7 @@ objc_library(
name = "MPPTextClassifierOptions", name = "MPPTextClassifierOptions",
srcs = ["sources/MPPTextClassifierOptions.m"], srcs = ["sources/MPPTextClassifierOptions.m"],
hdrs = ["sources/MPPTextClassifierOptions.h"], hdrs = ["sources/MPPTextClassifierOptions.h"],
deps = [ deps = ["//mediapipe/tasks/ios/core:MPPTaskOptions"],
"//mediapipe/tasks/ios/core:MPPTaskOptions",
"//third_party/apple_frameworks:Foundation",
],
) )
objc_library( objc_library(
@ -33,7 +30,6 @@ objc_library(
deps = [ deps = [
"//mediapipe/tasks/ios/components/containers:MPPClassificationResult", "//mediapipe/tasks/ios/components/containers:MPPClassificationResult",
"//mediapipe/tasks/ios/core:MPPTaskResult", "//mediapipe/tasks/ios/core:MPPTaskResult",
"//third_party/apple_frameworks:Foundation",
], ],
) )
@ -46,6 +42,7 @@ objc_library(
"-std=c++17", "-std=c++17",
"-x objective-c++", "-x objective-c++",
], ],
module_name = "MPPTextClassifier",
deps = [ deps = [
":MPPTextClassifierOptions", ":MPPTextClassifierOptions",
":MPPTextClassifierResult", ":MPPTextClassifierResult",
@ -59,7 +56,6 @@ objc_library(
"//mediapipe/tasks/ios/text/core:MPPTextTaskRunner", "//mediapipe/tasks/ios/text/core:MPPTextTaskRunner",
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers", "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers",
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers", "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers",
"//third_party/apple_frameworks:Foundation",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
], ],
) )

View File

@ -65,7 +65,7 @@ NS_SWIFT_NAME(TextClassifier)
* @return A new instance of `MPPTextClassifier` with the given model path. `nil` if there is an * @return A new instance of `MPPTextClassifier` with the given model path. `nil` if there is an
* error in initializing the text classifier. * error in initializing the text classifier.
*/ */
- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error; - (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error;
/** /**
* Creates a new instance of `MPPTextClassifier` from the given `MPPTextClassifierOptions`. * Creates a new instance of `MPPTextClassifier` from the given `MPPTextClassifierOptions`.
@ -78,7 +78,7 @@ NS_SWIFT_NAME(TextClassifier)
* @return A new instance of `MPPTextClassifier` with the given options. `nil` if there is an * @return A new instance of `MPPTextClassifier` with the given options. `nil` if there is an
* error in initializing the text classifier. * error in initializing the text classifier.
*/ */
- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options - (nullable instancetype)initWithOptions:(MPPTextClassifierOptions *)options
error:(NSError **)error NS_DESIGNATED_INITIALIZER; error:(NSError **)error NS_DESIGNATED_INITIALIZER;
/** /**
@ -90,7 +90,8 @@ NS_SWIFT_NAME(TextClassifier)
* *
* @return A `MPPTextClassifierResult` object that contains a list of text classifications. * @return A `MPPTextClassifierResult` object that contains a list of text classifications.
*/ */
- (nullable MPPTextClassifierResult *)classifyText:(NSString *)text error:(NSError **)error; - (nullable MPPTextClassifierResult *)classifyText:(NSString *)text
error:(NSError **)error NS_SWIFT_NAME(classify(text:));
- (instancetype)init NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE;