Merge pull request #3995 from priankakariatyml:ios-text-classifier-tests
PiperOrigin-RevId: 503242486
This commit is contained in:
commit
4b9a52dc34
|
@ -18,160 +18,92 @@ NS_ASSUME_NONNULL_BEGIN
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @enum MPPTasksErrorCode
|
* @enum MPPTasksErrorCode
|
||||||
* This enum specifies error codes for MediaPipe Task Library.
|
* This enum specifies error codes for errors thrown by iOS MediaPipe Task Library.
|
||||||
* It maintains a 1:1 mapping to MediaPipeTasksStatus of the C ++libray.
|
|
||||||
*/
|
*/
|
||||||
typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) {
|
typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) {
|
||||||
|
|
||||||
// Generic error codes.
|
// Generic error codes.
|
||||||
|
|
||||||
// Unspecified error.
|
/** Indicates the operation was cancelled, typically by the caller. */
|
||||||
MPPTasksErrorCodeError = 1,
|
MPPTasksErrorCodeCancelledError = 1,
|
||||||
// Invalid argument specified.
|
|
||||||
MPPTasksErrorCodeInvalidArgumentError = 2,
|
|
||||||
// Invalid FlatBuffer file or buffer specified.
|
|
||||||
MPPTasksErrorCodeInvalidFlatBufferError = 3,
|
|
||||||
// Model contains a builtin op that isn't supported by the OpResolver or
|
|
||||||
// delegates.
|
|
||||||
MPPTasksErrorCodeUnsupportedBuiltinOp = 4,
|
|
||||||
// Model contains a custom op that isn't supported by the OpResolver or
|
|
||||||
// delegates.
|
|
||||||
MPPTasksErrorCodeUnsupportedCustomOp = 5,
|
|
||||||
|
|
||||||
// File I/O error codes.
|
/** Indicates an unknown error occurred. */
|
||||||
|
MPPTasksErrorCodeUnknownError = 2,
|
||||||
|
|
||||||
// No such file.
|
/** Indicates the caller specified an invalid argument, such as a malformed filename. */
|
||||||
MPPTasksErrorCodeFileNotFoundError = 100,
|
MPPTasksErrorCodeInvalidArgumentError = 3,
|
||||||
// Permission issue.
|
|
||||||
MPPTasksErrorCodeFilePermissionDeniedError,
|
|
||||||
// I/O error when reading file.
|
|
||||||
MPPTasksErrorCodeFileReadError,
|
|
||||||
// I/O error when mmap-ing file.
|
|
||||||
MPPTasksErrorCodeFileMmapError,
|
|
||||||
// ZIP I/O error when unpacking the zip file.
|
|
||||||
MPPTasksErrorCodeFileZipError,
|
|
||||||
|
|
||||||
// TensorFlow Lite metadata error codes.
|
/** Indicates a deadline expired before the operation could complete. */
|
||||||
|
MPPTasksErrorCodeDeadlineExceededError = 4,
|
||||||
|
|
||||||
// Unexpected schema version (aka file_identifier) in the Metadata FlatBuffer.
|
/** Indicates some requested entity (such as a file or directory) was not found. */
|
||||||
MPPTasksErrorCodeMetadataInvalidSchemaVersionError = 200,
|
MPPTasksErrorCodeNotFoundError = 5,
|
||||||
// No such associated file within metadata, or file has not been packed.
|
|
||||||
MPPTasksErrorCodeMetadataAssociatedFileNotFoundError,
|
|
||||||
// ZIP I/O error when unpacking an associated file.
|
|
||||||
MPPTasksErrorCodeMetadataAssociatedFileZipError,
|
|
||||||
// Inconsistency error between the metadata and actual TF Lite model.
|
|
||||||
// E.g.: number of labels and output tensor values differ.
|
|
||||||
MPPTasksErrorCodeMetadataInconsistencyError,
|
|
||||||
// Invalid process units specified.
|
|
||||||
// E.g.: multiple ProcessUnits with the same type for a given tensor.
|
|
||||||
MPPTasksErrorCodeMetadataInvalidProcessUnitsError,
|
|
||||||
// Inconsistency error with the number of labels.
|
|
||||||
// E.g.: label files for different locales have a different number of labels.
|
|
||||||
MPPTasksErrorCodeMetadataNumLabelsMismatchError,
|
|
||||||
// Score calibration parameters parsing error.
|
|
||||||
// E.g.: too many parameters provided in the corresponding associated file.
|
|
||||||
MPPTasksErrorCodeMetadataMalformedScoreCalibrationError,
|
|
||||||
// Unexpected number of subgraphs for the current task.
|
|
||||||
// E.g.: image classification expects a single subgraph.
|
|
||||||
MPPTasksErrorCodeMetadataInvalidNumSubgraphsError,
|
|
||||||
// A given tensor requires NormalizationOptions but none were found.
|
|
||||||
// E.g.: float input tensor requires normalization to preprocess input images.
|
|
||||||
MPPTasksErrorCodeMetadataMissingNormalizationOptionsError,
|
|
||||||
// Invalid ContentProperties specified.
|
|
||||||
// E.g. expected ImageProperties, got BoundingBoxProperties.
|
|
||||||
MPPTasksErrorCodeMetadataInvalidContentPropertiesError,
|
|
||||||
// Metadata is mandatory but was not found.
|
|
||||||
// E.g. current task requires TFLite Model Metadata but none was found.
|
|
||||||
MPPTasksErrorCodeMetadataNotFoundError,
|
|
||||||
// Associated TENSOR_AXIS_LABELS or TENSOR_VALUE_LABELS file is mandatory but
|
|
||||||
// none was found or it was empty.
|
|
||||||
// E.g. current task requires labels but none were found.
|
|
||||||
MPPTasksErrorCodeMetadataMissingLabelsError,
|
|
||||||
// The ProcessingUnit for tokenizer is not correctly configured.
|
|
||||||
// E.g BertTokenizer doesn't have a valid vocab file associated.
|
|
||||||
MPPTasksErrorCodeMetadataInvalidTokenizerError,
|
|
||||||
|
|
||||||
// Input tensor(s) error codes.
|
/**
|
||||||
|
* Indicates that the entity a caller attempted to create (such as a file or directory) is
|
||||||
|
* already present.
|
||||||
|
*/
|
||||||
|
MPPTasksErrorCodeAlreadyExistsError = 6,
|
||||||
|
|
||||||
// Unexpected number of input tensors for the current task.
|
/** Indicates that the caller does not have permission to execute the specified operation. */
|
||||||
// E.g. current task expects a single input tensor.
|
MPPTasksErrorCodePermissionDeniedError = 7,
|
||||||
MPPTasksErrorCodeInvalidNumInputTensorsError = 300,
|
|
||||||
// Unexpected input tensor dimensions for the current task.
|
|
||||||
// E.g.: only 4D input tensors supported.
|
|
||||||
MPPTasksErrorCodeInvalidInputTensorDimensionsError,
|
|
||||||
// Unexpected input tensor type for the current task.
|
|
||||||
// E.g.: current task expects a uint8 pixel image as input.
|
|
||||||
MPPTasksErrorCodeInvalidInputTensorTypeError,
|
|
||||||
// Unexpected input tensor bytes size.
|
|
||||||
// E.g.: size in bytes does not correspond to the expected number of pixels.
|
|
||||||
MPPTasksErrorCodeInvalidInputTensorSizeError,
|
|
||||||
// No correct input tensor found for the model.
|
|
||||||
// E.g.: input tensor name is not part of the text model's input tensors.
|
|
||||||
MPPTasksErrorCodeInputTensorNotFoundError,
|
|
||||||
|
|
||||||
// Output tensor(s) error codes.
|
/**
|
||||||
|
* Indicates some resource has been exhausted, perhaps a per-user quota, or perhaps the entire
|
||||||
|
* file system is out of space.
|
||||||
|
*/
|
||||||
|
MPPTasksErrorCodeResourceExhaustedError = 8,
|
||||||
|
|
||||||
// Unexpected output tensor dimensions for the current task.
|
/**
|
||||||
// E.g.: only a batch size of 1 is supported.
|
* Indicates that the operation was rejected because the system is not in a state required for
|
||||||
MPPTasksErrorCodeInvalidOutputTensorDimensionsError = 400,
|
* the operation's execution. For example, a directory to be deleted may be non-empty, an "rmdir"
|
||||||
// Unexpected input tensor type for the current task.
|
* operation is applied to a non-directory, etc.
|
||||||
// E.g.: multi-head model with different output tensor types.
|
*/
|
||||||
MPPTasksErrorCodeInvalidOutputTensorTypeError,
|
MPPTasksErrorCodeFailedPreconditionError = 9,
|
||||||
// No correct output tensor found for the model.
|
|
||||||
// E.g.: output tensor name is not part of the text model's output tensors.
|
|
||||||
MPPTasksErrorCodeOutputTensorNotFoundError,
|
|
||||||
// Unexpected number of output tensors for the current task.
|
|
||||||
// E.g.: current task expects a single output tensor.
|
|
||||||
MPPTasksErrorCodeInvalidNumOutputTensorsError,
|
|
||||||
|
|
||||||
// Image processing error codes.
|
/**
|
||||||
|
* Indicates the operation was aborted, typically due to a concurrency issue such as a sequencer
|
||||||
|
* check failure or a failed transaction.
|
||||||
|
*/
|
||||||
|
MPPTasksErrorCodeAbortedError = 10,
|
||||||
|
|
||||||
// Unspecified image processing failures.
|
/**
|
||||||
MPPTasksErrorCodeImageProcessingError = 500,
|
* Indicates the operation was attempted past the valid range, such as seeking or reading past an
|
||||||
// Unexpected input or output buffer metadata.
|
* end-of-file.
|
||||||
// E.g.: rotate RGBA buffer to Grayscale buffer by 90 degrees.
|
*/
|
||||||
MPPTasksErrorCodeImageProcessingInvalidArgumentError,
|
MPPTasksErrorCodeOutOfRangeError = 11,
|
||||||
// Image processing operation failures.
|
|
||||||
// E.g. libyuv rotation failed for an unknown reason.
|
|
||||||
MPPTasksErrorCodeImageProcessingBackendError,
|
|
||||||
|
|
||||||
// Task runner error codes.
|
/**
|
||||||
MPPTasksErrorCodeRunnerError = 600,
|
* Indicates the operation is not implemented or supported in this service. In this case, the
|
||||||
// Task runner is not initialized.
|
* operation should not be re-attempted.
|
||||||
MPPTasksErrorCodeRunnerInitializationError,
|
*/
|
||||||
// Task runner is not started successfully.
|
MPPTasksErrorCodeUnimplementedError = 12,
|
||||||
MPPTasksErrorCodeRunnerFailsToStartError,
|
|
||||||
// Task runner is not started.
|
|
||||||
MPPTasksErrorCodeRunnerNotStartedError,
|
|
||||||
// Task runner API is called in the wrong processing mode.
|
|
||||||
MPPTasksErrorCodeRunnerApiCalledInWrongModeError,
|
|
||||||
// Task runner receives/produces invalid MediaPipe packet timestamp.
|
|
||||||
MPPTasksErrorCodeRunnerInvalidTimestampError,
|
|
||||||
// Task runner receives unexpected MediaPipe graph input packet.
|
|
||||||
// E.g. The packet type doesn't match the graph input stream's data type.
|
|
||||||
MPPTasksErrorCodeRunnerUnexpectedInputError,
|
|
||||||
// Task runner produces unexpected MediaPipe graph output packet.
|
|
||||||
// E.g. The number of output packets is not equal to the number of graph
|
|
||||||
// output streams.
|
|
||||||
MPPTasksErrorCodeRunnerUnexpectedOutputError,
|
|
||||||
// Task runner is not closed successfully.
|
|
||||||
MPPTasksErrorCodeRunnerFailsToCloseError,
|
|
||||||
// Task runner's model resources cache service is unavailable or the
|
|
||||||
// targeting model resources bundle is not found.
|
|
||||||
MPPTasksErrorCodeRunnerModelResourcesCacheServiceError,
|
|
||||||
|
|
||||||
// Task graph error codes.
|
/**
|
||||||
MPPTasksErrorCodeGraphError = 700,
|
* Indicates an internal error has occurred and some invariants expected by the underlying system
|
||||||
// Task graph is not implemented.
|
* have not been satisfied. This error code is reserved for serious errors.
|
||||||
MPPTasksErrorCodeTaskGraphNotImplementedError,
|
*/
|
||||||
// Task graph config is invalid.
|
MPPTasksErrorCodeInternalError = 13,
|
||||||
MPPTasksErrorCodeInvalidTaskGraphConfigError,
|
|
||||||
|
|
||||||
// The first error code in MPPTasksErrorCode (for internal use only).
|
/**
|
||||||
MPPTasksErrorCodeFirst = MPPTasksErrorCodeError,
|
* Indicates the service is currently unavailable and that this is most likely a transient
|
||||||
|
* condition.
|
||||||
|
*/
|
||||||
|
MPPTasksErrorCodeUnavailableError = 14,
|
||||||
|
|
||||||
// The last error code in MPPTasksErrorCode (for internal use only).
|
/** Indicates that unrecoverable data loss or corruption has occurred. */
|
||||||
MPPTasksErrorCodeLast = MPPTasksErrorCodeInvalidTaskGraphConfigError,
|
MPPTasksErrorCodeDataLossError = 15,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Indicates that the request does not have valid authentication credentials for the operation.
|
||||||
|
*/
|
||||||
|
MPPTasksErrorCodeUnauthenticatedError = 16,
|
||||||
|
|
||||||
|
/** The first error code in MPPTasksErrorCode (for internal use only). */
|
||||||
|
MPPTasksErrorCodeFirst = MPPTasksErrorCodeCancelledError,
|
||||||
|
|
||||||
|
/** The last error code in MPPTasksErrorCode (for internal use only). */
|
||||||
|
MPPTasksErrorCodeLast = MPPTasksErrorCodeUnauthenticatedError,
|
||||||
|
|
||||||
} NS_SWIFT_NAME(TasksErrorCode);
|
} NS_SWIFT_NAME(TasksErrorCode);
|
||||||
|
|
||||||
|
|
|
@ -25,6 +25,10 @@
|
||||||
/** Error domain of MediaPipe task library errors. */
|
/** Error domain of MediaPipe task library errors. */
|
||||||
NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks";
|
NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks";
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
using absl::StatusCode;
|
||||||
|
}
|
||||||
|
|
||||||
@implementation MPPCommonUtils
|
@implementation MPPCommonUtils
|
||||||
|
|
||||||
+ (void)createCustomError:(NSError **)error
|
+ (void)createCustomError:(NSError **)error
|
||||||
|
@ -67,68 +71,69 @@ NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks";
|
||||||
if (status.ok()) {
|
if (status.ok()) {
|
||||||
return YES;
|
return YES;
|
||||||
}
|
}
|
||||||
// Payload of absl::Status created by the MediaPipe task library stores an appropriate value of
|
|
||||||
// the enum MediaPipeTasksStatus. The integer value corresponding to the MediaPipeTasksStatus enum
|
|
||||||
// stored in the payload is extracted here to later map to the appropriate error code to be
|
|
||||||
// returned. In cases where the enum is not stored in (payload is NULL or the payload string
|
|
||||||
// cannot be converted to an integer), we set the error code value to be 1
|
|
||||||
// (MPPTasksErrorCodeError of MPPTasksErrorCode used in the iOS library to signify
|
|
||||||
// any errors not falling into other categories.) Since payload is of type absl::Cord that can be
|
|
||||||
// type cast into an absl::optional<std::string>, we use the std::stoi function to convert it into
|
|
||||||
// an integer code if possible.
|
|
||||||
NSUInteger genericErrorCode = MPPTasksErrorCodeError;
|
|
||||||
NSUInteger errorCode;
|
|
||||||
try {
|
|
||||||
// Try converting payload to integer if payload is not empty. Otherwise convert a string
|
|
||||||
// signifying generic error code MPPTasksErrorCodeError to integer.
|
|
||||||
errorCode =
|
|
||||||
(NSUInteger)std::stoi(static_cast<absl::optional<std::string>>(
|
|
||||||
status.GetPayload(mediapipe::tasks::kMediaPipeTasksPayload))
|
|
||||||
.value_or(std::to_string(genericErrorCode)));
|
|
||||||
} catch (std::invalid_argument &e) {
|
|
||||||
// If non empty payload string cannot be converted to an integer. Set error code to 1(kError).
|
|
||||||
errorCode = MPPTasksErrorCodeError;
|
|
||||||
}
|
|
||||||
|
|
||||||
// If errorCode is outside the range of enum values possible or is
|
// Converts the absl status message to an NSString.
|
||||||
// MPPTasksErrorCodeError, we try to map the absl::Status::code() to assign
|
|
||||||
// appropriate MPPTasksErrorCode in default cases. Note:
|
|
||||||
// The mapping to absl::Status::code() is done to generate a more specific error code than
|
|
||||||
// MPPTasksErrorCodeError in cases when the payload can't be mapped to
|
|
||||||
// MPPTasksErrorCode. This can happen when absl::Status returned by TFLite library are in turn
|
|
||||||
// returned without modification by MediaPipe cc library methods.
|
|
||||||
if (errorCode > MPPTasksErrorCodeLast || errorCode <= MPPTasksErrorCodeFirst) {
|
|
||||||
switch (status.code()) {
|
|
||||||
case absl::StatusCode::kInternal:
|
|
||||||
errorCode = MPPTasksErrorCodeError;
|
|
||||||
break;
|
|
||||||
case absl::StatusCode::kInvalidArgument:
|
|
||||||
errorCode = MPPTasksErrorCodeInvalidArgumentError;
|
|
||||||
break;
|
|
||||||
case absl::StatusCode::kNotFound:
|
|
||||||
errorCode = MPPTasksErrorCodeError;
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
errorCode = MPPTasksErrorCodeError;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Creates the NSEror with the appropriate error
|
|
||||||
// MPPTasksErrorCode and message. MPPTasksErrorCode has a one to one
|
|
||||||
// mapping with MediaPipeTasksStatus starting from the value 1(MPPTasksErrorCodeError)
|
|
||||||
// and hence will be correctly initialized if directly cast from the integer code derived from
|
|
||||||
// MediaPipeTasksStatus stored in its payload. MPPTasksErrorCode omits kOk = 0 of
|
|
||||||
// MediaPipeTasksStatusx.
|
|
||||||
//
|
|
||||||
// Stores a string including absl status code and message(if non empty) as the
|
|
||||||
// error message See
|
|
||||||
// https://github.com/abseil/abseil-cpp/blob/master/absl/status/status.h#L514
|
|
||||||
// for explanation. absl::Status::message() can also be used but not always
|
|
||||||
// guaranteed to be non empty.
|
|
||||||
NSString *description = [NSString
|
NSString *description = [NSString
|
||||||
stringWithCString:status.ToString(absl::StatusToStringMode::kWithNoExtraData).c_str()
|
stringWithCString:status.ToString(absl::StatusToStringMode::kWithNoExtraData).c_str()
|
||||||
encoding:NSUTF8StringEncoding];
|
encoding:NSUTF8StringEncoding];
|
||||||
|
|
||||||
|
MPPTasksErrorCode errorCode = MPPTasksErrorCodeUnknownError;
|
||||||
|
|
||||||
|
// Maps the absl::StatusCode to the appropriate MPPTasksErrorCode. Note: MPPTasksErrorCode omits
|
||||||
|
// absl::StatusCode::kOk.
|
||||||
|
switch (status.code()) {
|
||||||
|
case StatusCode::kCancelled:
|
||||||
|
errorCode = MPPTasksErrorCodeCancelledError;
|
||||||
|
break;
|
||||||
|
case StatusCode::kUnknown:
|
||||||
|
errorCode = MPPTasksErrorCodeUnknownError;
|
||||||
|
break;
|
||||||
|
case StatusCode::kInvalidArgument:
|
||||||
|
errorCode = MPPTasksErrorCodeInvalidArgumentError;
|
||||||
|
break;
|
||||||
|
case StatusCode::kDeadlineExceeded:
|
||||||
|
errorCode = MPPTasksErrorCodeDeadlineExceededError;
|
||||||
|
break;
|
||||||
|
case StatusCode::kNotFound:
|
||||||
|
errorCode = MPPTasksErrorCodeNotFoundError;
|
||||||
|
break;
|
||||||
|
case StatusCode::kAlreadyExists:
|
||||||
|
errorCode = MPPTasksErrorCodeAlreadyExistsError;
|
||||||
|
break;
|
||||||
|
case StatusCode::kPermissionDenied:
|
||||||
|
errorCode = MPPTasksErrorCodePermissionDeniedError;
|
||||||
|
break;
|
||||||
|
case StatusCode::kResourceExhausted:
|
||||||
|
errorCode = MPPTasksErrorCodeResourceExhaustedError;
|
||||||
|
break;
|
||||||
|
case StatusCode::kFailedPrecondition:
|
||||||
|
errorCode = MPPTasksErrorCodeFailedPreconditionError;
|
||||||
|
break;
|
||||||
|
case StatusCode::kAborted:
|
||||||
|
errorCode = MPPTasksErrorCodeAbortedError;
|
||||||
|
break;
|
||||||
|
case StatusCode::kOutOfRange:
|
||||||
|
errorCode = MPPTasksErrorCodeOutOfRangeError;
|
||||||
|
break;
|
||||||
|
case StatusCode::kUnimplemented:
|
||||||
|
errorCode = MPPTasksErrorCodeUnimplementedError;
|
||||||
|
break;
|
||||||
|
case StatusCode::kInternal:
|
||||||
|
errorCode = MPPTasksErrorCodeInternalError;
|
||||||
|
break;
|
||||||
|
case StatusCode::kUnavailable:
|
||||||
|
errorCode = MPPTasksErrorCodeUnavailableError;
|
||||||
|
break;
|
||||||
|
case StatusCode::kDataLoss:
|
||||||
|
errorCode = MPPTasksErrorCodeDataLossError;
|
||||||
|
break;
|
||||||
|
case StatusCode::kUnauthenticated:
|
||||||
|
errorCode = MPPTasksErrorCodeUnauthenticatedError;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
[MPPCommonUtils createCustomError:error withCode:errorCode description:description];
|
[MPPCommonUtils createCustomError:error withCode:errorCode description:description];
|
||||||
return NO;
|
return NO;
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,15 +20,11 @@ objc_library(
|
||||||
name = "MPPCategory",
|
name = "MPPCategory",
|
||||||
srcs = ["sources/MPPCategory.m"],
|
srcs = ["sources/MPPCategory.m"],
|
||||||
hdrs = ["sources/MPPCategory.h"],
|
hdrs = ["sources/MPPCategory.h"],
|
||||||
deps = ["//third_party/apple_frameworks:Foundation"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
objc_library(
|
objc_library(
|
||||||
name = "MPPClassificationResult",
|
name = "MPPClassificationResult",
|
||||||
srcs = ["sources/MPPClassificationResult.m"],
|
srcs = ["sources/MPPClassificationResult.m"],
|
||||||
hdrs = ["sources/MPPClassificationResult.h"],
|
hdrs = ["sources/MPPClassificationResult.h"],
|
||||||
deps = [
|
deps = [":MPPCategory"],
|
||||||
":MPPCategory",
|
|
||||||
"//third_party/apple_frameworks:Foundation",
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -21,7 +21,7 @@ NS_ASSUME_NONNULL_BEGIN
|
||||||
* index of the label in the corresponding label file. Typically it's used as the result of
|
* index of the label in the corresponding label file. Typically it's used as the result of
|
||||||
* classification tasks.
|
* classification tasks.
|
||||||
*/
|
*/
|
||||||
NS_SWIFT_NAME(ClassificationCategory)
|
NS_SWIFT_NAME(ResultCategory)
|
||||||
@interface MPPCategory : NSObject
|
@interface MPPCategory : NSObject
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -26,9 +26,7 @@ objc_library(
|
||||||
name = "MPPTaskOptions",
|
name = "MPPTaskOptions",
|
||||||
srcs = ["sources/MPPTaskOptions.m"],
|
srcs = ["sources/MPPTaskOptions.m"],
|
||||||
hdrs = ["sources/MPPTaskOptions.h"],
|
hdrs = ["sources/MPPTaskOptions.h"],
|
||||||
deps = [
|
deps = [":MPPBaseOptions"],
|
||||||
":MPPBaseOptions",
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
objc_library(
|
objc_library(
|
||||||
|
@ -40,9 +38,7 @@ objc_library(
|
||||||
objc_library(
|
objc_library(
|
||||||
name = "MPPTaskOptionsProtocol",
|
name = "MPPTaskOptionsProtocol",
|
||||||
hdrs = ["sources/MPPTaskOptionsProtocol.h"],
|
hdrs = ["sources/MPPTaskOptionsProtocol.h"],
|
||||||
deps = [
|
deps = ["//mediapipe/framework:calculator_options_cc_proto"],
|
||||||
"//mediapipe/framework:calculator_options_cc_proto",
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
objc_library(
|
objc_library(
|
||||||
|
@ -92,6 +88,5 @@ objc_library(
|
||||||
"//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver",
|
"//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver",
|
||||||
"//mediapipe/tasks/cc/core:task_runner",
|
"//mediapipe/tasks/cc/core:task_runner",
|
||||||
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
||||||
"//third_party/apple_frameworks:Foundation",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
3
mediapipe/tasks/ios/ios.bzl
Normal file
3
mediapipe/tasks/ios/ios.bzl
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
"""MediaPipe Task Library Helper Rules for iOS"""
|
||||||
|
|
||||||
|
MPP_TASK_MINIMUM_OS_VERSION = "11.0"
|
80
mediapipe/tasks/ios/test/text/text_classifier/BUILD
Normal file
80
mediapipe/tasks/ios/test/text/text_classifier/BUILD
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
load(
|
||||||
|
"@build_bazel_rules_apple//apple:ios.bzl",
|
||||||
|
"ios_unit_test",
|
||||||
|
)
|
||||||
|
load(
|
||||||
|
"@build_bazel_rules_swift//swift:swift.bzl",
|
||||||
|
"swift_library",
|
||||||
|
)
|
||||||
|
load(
|
||||||
|
"//mediapipe/tasks:ios/ios.bzl",
|
||||||
|
"MPP_TASK_MINIMUM_OS_VERSION",
|
||||||
|
)
|
||||||
|
load(
|
||||||
|
"@org_tensorflow//tensorflow/lite:special_rules.bzl",
|
||||||
|
"tflite_ios_lab_runner",
|
||||||
|
)
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
# Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
|
||||||
|
TFL_DEFAULT_TAGS = [
|
||||||
|
"apple",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Following sanitizer tests are not supported by iOS test targets.
|
||||||
|
TFL_DISABLED_SANITIZER_TAGS = [
|
||||||
|
"noasan",
|
||||||
|
"nomsan",
|
||||||
|
"notsan",
|
||||||
|
]
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "MPPTextClassifierObjcTestLibrary",
|
||||||
|
testonly = 1,
|
||||||
|
srcs = ["MPPTextClassifierTests.m"],
|
||||||
|
data = [
|
||||||
|
"//mediapipe/tasks/testdata/text:bert_text_classifier_models",
|
||||||
|
"//mediapipe/tasks/testdata/text:text_classifier_models",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/ios/common:MPPCommon",
|
||||||
|
"//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
ios_unit_test(
|
||||||
|
name = "MPPTextClassifierObjcTest",
|
||||||
|
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
|
||||||
|
runner = tflite_ios_lab_runner("IOS_LATEST"),
|
||||||
|
deps = [
|
||||||
|
":MPPTextClassifierObjcTestLibrary",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
swift_library(
|
||||||
|
name = "MPPTextClassifierSwiftTestLibrary",
|
||||||
|
testonly = 1,
|
||||||
|
srcs = ["TextClassifierTests.swift"],
|
||||||
|
data = [
|
||||||
|
"//mediapipe/tasks/testdata/text:bert_text_classifier_models",
|
||||||
|
"//mediapipe/tasks/testdata/text:text_classifier_models",
|
||||||
|
],
|
||||||
|
tags = TFL_DEFAULT_TAGS,
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/ios/common:MPPCommon",
|
||||||
|
"//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
ios_unit_test(
|
||||||
|
name = "MPPTextClassifierSwiftTest",
|
||||||
|
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
|
||||||
|
runner = tflite_ios_lab_runner("IOS_LATEST"),
|
||||||
|
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
|
||||||
|
deps = [
|
||||||
|
":MPPTextClassifierSwiftTestLibrary",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,275 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#import <XCTest/XCTest.h>
|
||||||
|
|
||||||
|
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
|
||||||
|
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h"
|
||||||
|
|
||||||
|
static NSString *const kBertTextClassifierModelName = @"bert_text_classifier";
|
||||||
|
static NSString *const kRegexTextClassifierModelName =
|
||||||
|
@"test_model_text_classifier_with_regex_tokenizer";
|
||||||
|
static NSString *const kNegativeText = @"unflinchingly bleak and desperate";
|
||||||
|
static NSString *const kPositiveText = @"it's a charming and often affecting journey";
|
||||||
|
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
||||||
|
|
||||||
|
#define AssertEqualErrors(error, expectedError) \
|
||||||
|
XCTAssertNotNil(error); \
|
||||||
|
XCTAssertEqualObjects(error.domain, expectedError.domain); \
|
||||||
|
XCTAssertEqual(error.code, expectedError.code); \
|
||||||
|
XCTAssertNotEqual( \
|
||||||
|
[error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
|
||||||
|
NSNotFound)
|
||||||
|
|
||||||
|
#define AssertEqualCategoryArrays(categories, expectedCategories) \
|
||||||
|
XCTAssertEqual(categories.count, expectedCategories.count); \
|
||||||
|
for (int i = 0; i < categories.count; i++) { \
|
||||||
|
XCTAssertEqual(categories[i].index, expectedCategories[i].index, @"index i = %d", i); \
|
||||||
|
XCTAssertEqualWithAccuracy(categories[i].score, expectedCategories[i].score, 1e-3, \
|
||||||
|
@"index i = %d", i); \
|
||||||
|
XCTAssertEqualObjects(categories[i].categoryName, expectedCategories[i].categoryName, \
|
||||||
|
@"index i = %d", i); \
|
||||||
|
XCTAssertEqualObjects(categories[i].displayName, expectedCategories[i].displayName, \
|
||||||
|
@"index i = %d", i); \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define AssertTextClassifierResultHasOneHead(textClassifierResult) \
|
||||||
|
XCTAssertNotNil(textClassifierResult); \
|
||||||
|
XCTAssertNotNil(textClassifierResult.classificationResult); \
|
||||||
|
XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1); \
|
||||||
|
XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0);
|
||||||
|
|
||||||
|
@interface MPPTextClassifierTests : XCTestCase
|
||||||
|
@end
|
||||||
|
|
||||||
|
@implementation MPPTextClassifierTests
|
||||||
|
|
||||||
|
+ (NSArray<MPPCategory *> *)expectedBertResultCategoriesForNegativeText {
|
||||||
|
return @[
|
||||||
|
[[MPPCategory alloc] initWithIndex:0 score:0.956187f categoryName:@"negative" displayName:nil],
|
||||||
|
[[MPPCategory alloc] initWithIndex:1 score:0.043812f categoryName:@"positive" displayName:nil]
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
+ (NSArray<MPPCategory *> *)expectedBertResultCategoriesForPositiveText {
|
||||||
|
return @[
|
||||||
|
[[MPPCategory alloc] initWithIndex:1 score:0.999945f categoryName:@"positive" displayName:nil],
|
||||||
|
[[MPPCategory alloc] initWithIndex:0 score:0.000055f categoryName:@"negative" displayName:nil]
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
+ (NSArray<MPPCategory *> *)expectedRegexResultCategoriesForNegativeText {
|
||||||
|
return @[
|
||||||
|
[[MPPCategory alloc] initWithIndex:0 score:0.6647746f categoryName:@"Negative" displayName:nil],
|
||||||
|
[[MPPCategory alloc] initWithIndex:1 score:0.33522537 categoryName:@"Positive" displayName:nil]
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
+ (NSArray<MPPCategory *> *)expectedRegexResultCategoriesForPositiveText {
|
||||||
|
return @[
|
||||||
|
[[MPPCategory alloc] initWithIndex:0 score:0.5120041f categoryName:@"Negative" displayName:nil],
|
||||||
|
[[MPPCategory alloc] initWithIndex:1 score:0.48799595 categoryName:@"Positive" displayName:nil]
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
+ (NSArray<MPPCategory *> *)expectedBertResultCategoriesForEdgeCaseTests {
|
||||||
|
return @[ [[MPPCategory alloc] initWithIndex:0
|
||||||
|
score:0.956187f
|
||||||
|
categoryName:@"negative"
|
||||||
|
displayName:nil] ];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension {
|
||||||
|
NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName
|
||||||
|
ofType:extension];
|
||||||
|
return filePath;
|
||||||
|
}
|
||||||
|
|
||||||
|
- (MPPTextClassifierOptions *)textClassifierOptionsWithModelName:(NSString *)modelName {
|
||||||
|
NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
|
||||||
|
MPPTextClassifierOptions *textClassifierOptions = [[MPPTextClassifierOptions alloc] init];
|
||||||
|
textClassifierOptions.baseOptions.modelAssetPath = modelPath;
|
||||||
|
|
||||||
|
return textClassifierOptions;
|
||||||
|
}
|
||||||
|
|
||||||
|
- (MPPTextClassifier *)textClassifierFromModelFileWithName:(NSString *)modelName {
|
||||||
|
NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
|
||||||
|
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithModelPath:modelPath
|
||||||
|
error:nil];
|
||||||
|
XCTAssertNotNil(textClassifier);
|
||||||
|
|
||||||
|
return textClassifier;
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)assertCreateTextClassifierWithOptions:(MPPTextClassifierOptions *)textClassifierOptions
|
||||||
|
failsWithExpectedError:(NSError *)expectedError {
|
||||||
|
NSError *error = nil;
|
||||||
|
MPPTextClassifier *textClassifier =
|
||||||
|
[[MPPTextClassifier alloc] initWithOptions:textClassifierOptions error:&error];
|
||||||
|
XCTAssertNil(textClassifier);
|
||||||
|
AssertEqualErrors(error, expectedError);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)assertResultsOfClassifyText:(NSString *)text
|
||||||
|
usingTextClassifier:(MPPTextClassifier *)textClassifier
|
||||||
|
equalsCategories:(NSArray<MPPCategory *> *)expectedCategories {
|
||||||
|
MPPTextClassifierResult *negativeResult = [textClassifier classifyText:text error:nil];
|
||||||
|
AssertTextClassifierResultHasOneHead(negativeResult);
|
||||||
|
AssertEqualCategoryArrays(negativeResult.classificationResult.classifications[0].categories,
|
||||||
|
expectedCategories);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testCreateTextClassifierFailsWithMissingModelPath {
|
||||||
|
NSString *modelPath = [self filePathWithName:@"" extension:@""];
|
||||||
|
|
||||||
|
NSError *error = nil;
|
||||||
|
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithModelPath:modelPath
|
||||||
|
error:&error];
|
||||||
|
XCTAssertNil(textClassifier);
|
||||||
|
|
||||||
|
NSError *expectedError = [NSError
|
||||||
|
errorWithDomain:kExpectedErrorDomain
|
||||||
|
code:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
userInfo:@{
|
||||||
|
NSLocalizedDescriptionKey :
|
||||||
|
@"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', "
|
||||||
|
@"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."
|
||||||
|
}];
|
||||||
|
AssertEqualErrors(error, expectedError);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testCreateTextClassifierFailsWithBothAllowlistAndDenylist {
|
||||||
|
MPPTextClassifierOptions *options =
|
||||||
|
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
|
||||||
|
options.categoryAllowlist = @[ @"positive" ];
|
||||||
|
options.categoryDenylist = @[ @"negative" ];
|
||||||
|
|
||||||
|
[self assertCreateTextClassifierWithOptions:options
|
||||||
|
failsWithExpectedError:
|
||||||
|
[NSError
|
||||||
|
errorWithDomain:kExpectedErrorDomain
|
||||||
|
code:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
userInfo:@{
|
||||||
|
NSLocalizedDescriptionKey :
|
||||||
|
@"INVALID_ARGUMENT: `category_allowlist` and "
|
||||||
|
@"`category_denylist` are mutually exclusive options."
|
||||||
|
}]];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testCreateTextClassifierFailsWithInvalidMaxResults {
|
||||||
|
MPPTextClassifierOptions *options =
|
||||||
|
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
|
||||||
|
options.maxResults = 0;
|
||||||
|
|
||||||
|
[self assertCreateTextClassifierWithOptions:options
|
||||||
|
failsWithExpectedError:
|
||||||
|
[NSError errorWithDomain:kExpectedErrorDomain
|
||||||
|
code:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
userInfo:@{
|
||||||
|
NSLocalizedDescriptionKey :
|
||||||
|
@"INVALID_ARGUMENT: Invalid `max_results` option: "
|
||||||
|
@"value must be != 0."
|
||||||
|
}]];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testClassifyWithBertSucceeds {
|
||||||
|
MPPTextClassifier *textClassifier =
|
||||||
|
[self textClassifierFromModelFileWithName:kBertTextClassifierModelName];
|
||||||
|
|
||||||
|
[self assertResultsOfClassifyText:kNegativeText
|
||||||
|
usingTextClassifier:textClassifier
|
||||||
|
equalsCategories:[MPPTextClassifierTests
|
||||||
|
expectedBertResultCategoriesForNegativeText]];
|
||||||
|
|
||||||
|
[self assertResultsOfClassifyText:kPositiveText
|
||||||
|
usingTextClassifier:textClassifier
|
||||||
|
equalsCategories:[MPPTextClassifierTests
|
||||||
|
expectedBertResultCategoriesForPositiveText]];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testClassifyWithRegexSucceeds {
|
||||||
|
MPPTextClassifier *textClassifier =
|
||||||
|
[self textClassifierFromModelFileWithName:kRegexTextClassifierModelName];
|
||||||
|
|
||||||
|
[self assertResultsOfClassifyText:kNegativeText
|
||||||
|
usingTextClassifier:textClassifier
|
||||||
|
equalsCategories:[MPPTextClassifierTests
|
||||||
|
expectedRegexResultCategoriesForNegativeText]];
|
||||||
|
[self assertResultsOfClassifyText:kPositiveText
|
||||||
|
usingTextClassifier:textClassifier
|
||||||
|
equalsCategories:[MPPTextClassifierTests
|
||||||
|
expectedRegexResultCategoriesForPositiveText]];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testClassifyWithMaxResultsSucceeds {
|
||||||
|
MPPTextClassifierOptions *options =
|
||||||
|
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
|
||||||
|
options.maxResults = 1;
|
||||||
|
|
||||||
|
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil];
|
||||||
|
XCTAssertNotNil(textClassifier);
|
||||||
|
|
||||||
|
[self assertResultsOfClassifyText:kNegativeText
|
||||||
|
usingTextClassifier:textClassifier
|
||||||
|
equalsCategories:[MPPTextClassifierTests
|
||||||
|
expectedBertResultCategoriesForEdgeCaseTests]];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testClassifyWithCategoryAllowlistSucceeds {
|
||||||
|
MPPTextClassifierOptions *options =
|
||||||
|
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
|
||||||
|
options.categoryAllowlist = @[ @"negative" ];
|
||||||
|
|
||||||
|
NSError *error = nil;
|
||||||
|
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options
|
||||||
|
error:&error];
|
||||||
|
XCTAssertNotNil(textClassifier);
|
||||||
|
XCTAssertNil(error);
|
||||||
|
|
||||||
|
[self assertResultsOfClassifyText:kNegativeText
|
||||||
|
usingTextClassifier:textClassifier
|
||||||
|
equalsCategories:[MPPTextClassifierTests
|
||||||
|
expectedBertResultCategoriesForEdgeCaseTests]];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testClassifyWithCategoryDenylistSucceeds {
|
||||||
|
MPPTextClassifierOptions *options =
|
||||||
|
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
|
||||||
|
options.categoryDenylist = @[ @"positive" ];
|
||||||
|
|
||||||
|
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil];
|
||||||
|
XCTAssertNotNil(textClassifier);
|
||||||
|
|
||||||
|
[self assertResultsOfClassifyText:kNegativeText
|
||||||
|
usingTextClassifier:textClassifier
|
||||||
|
equalsCategories:[MPPTextClassifierTests
|
||||||
|
expectedBertResultCategoriesForEdgeCaseTests]];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testClassifyWithScoreThresholdSucceeds {
|
||||||
|
MPPTextClassifierOptions *options =
|
||||||
|
[self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
|
||||||
|
options.scoreThreshold = 0.5f;
|
||||||
|
|
||||||
|
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil];
|
||||||
|
XCTAssertNotNil(textClassifier);
|
||||||
|
|
||||||
|
[self assertResultsOfClassifyText:kNegativeText
|
||||||
|
usingTextClassifier:textClassifier
|
||||||
|
equalsCategories:[MPPTextClassifierTests
|
||||||
|
expectedBertResultCategoriesForEdgeCaseTests]];
|
||||||
|
}
|
||||||
|
|
||||||
|
@end
|
|
@ -0,0 +1,264 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
import MPPCommon
|
||||||
|
import XCTest
|
||||||
|
|
||||||
|
@testable import MPPTextClassifier
|
||||||
|
|
||||||
|
class TextClassifierTests: XCTestCase {
|
||||||
|
|
||||||
|
static let bundle = Bundle(for: TextClassifierTests.self)
|
||||||
|
|
||||||
|
static let bertModelPath = bundle.path(
|
||||||
|
forResource: "bert_text_classifier",
|
||||||
|
ofType: "tflite")
|
||||||
|
|
||||||
|
static let positiveText = "it's a charming and often affecting journey"
|
||||||
|
|
||||||
|
static let negativeText = "unflinchingly bleak and desperate"
|
||||||
|
|
||||||
|
static let bertNegativeTextResults = [
|
||||||
|
ResultCategory(
|
||||||
|
index: 0,
|
||||||
|
score: 0.956187,
|
||||||
|
categoryName: "negative",
|
||||||
|
displayName: nil),
|
||||||
|
ResultCategory(
|
||||||
|
index: 1,
|
||||||
|
score: 0.043812,
|
||||||
|
categoryName: "positive",
|
||||||
|
displayName: nil),
|
||||||
|
]
|
||||||
|
|
||||||
|
static let bertNegativeTextResultsForEdgeTestCases = [
|
||||||
|
ResultCategory(
|
||||||
|
index: 0,
|
||||||
|
score: 0.956187,
|
||||||
|
categoryName: "negative",
|
||||||
|
displayName: nil)
|
||||||
|
]
|
||||||
|
|
||||||
|
func assertEqualErrorDescriptions(
|
||||||
|
_ error: Error, expectedLocalizedDescription: String
|
||||||
|
) {
|
||||||
|
XCTAssertEqual(
|
||||||
|
error.localizedDescription,
|
||||||
|
expectedLocalizedDescription)
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertCategoriesAreEqual(
|
||||||
|
category: ResultCategory,
|
||||||
|
expectedCategory: ResultCategory,
|
||||||
|
indexInCategoryList: Int
|
||||||
|
) {
|
||||||
|
XCTAssertEqual(
|
||||||
|
category.index,
|
||||||
|
expectedCategory.index,
|
||||||
|
String(
|
||||||
|
format: """
|
||||||
|
category[%d].index and expectedCategory[%d].index are not equal.
|
||||||
|
""", indexInCategoryList))
|
||||||
|
XCTAssertEqual(
|
||||||
|
category.score,
|
||||||
|
expectedCategory.score,
|
||||||
|
accuracy: 1e-3,
|
||||||
|
String(
|
||||||
|
format: """
|
||||||
|
category[%d].score and expectedCategory[%d].score are not equal.
|
||||||
|
""", indexInCategoryList))
|
||||||
|
XCTAssertEqual(
|
||||||
|
category.categoryName,
|
||||||
|
expectedCategory.categoryName,
|
||||||
|
String(
|
||||||
|
format: """
|
||||||
|
category[%d].categoryName and expectedCategory[%d].categoryName are \
|
||||||
|
not equal.
|
||||||
|
""", indexInCategoryList))
|
||||||
|
XCTAssertEqual(
|
||||||
|
category.displayName,
|
||||||
|
expectedCategory.displayName,
|
||||||
|
String(
|
||||||
|
format: """
|
||||||
|
category[%d].displayName and expectedCategory[%d].displayName are \
|
||||||
|
not equal.
|
||||||
|
""", indexInCategoryList))
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertEqualCategoryArrays(
|
||||||
|
categoryArray: [ResultCategory],
|
||||||
|
expectedCategoryArray: [ResultCategory]
|
||||||
|
) {
|
||||||
|
XCTAssertEqual(
|
||||||
|
categoryArray.count,
|
||||||
|
expectedCategoryArray.count)
|
||||||
|
|
||||||
|
for (index, (category, expectedCategory)) in zip(categoryArray, expectedCategoryArray)
|
||||||
|
.enumerated()
|
||||||
|
{
|
||||||
|
assertCategoriesAreEqual(
|
||||||
|
category: category,
|
||||||
|
expectedCategory: expectedCategory,
|
||||||
|
indexInCategoryList: index)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertTextClassifierResultHasOneHead(
|
||||||
|
_ textClassifierResult: TextClassifierResult
|
||||||
|
) {
|
||||||
|
XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1)
|
||||||
|
XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func textClassifierOptionsWithModelPath(
|
||||||
|
_ modelPath: String?
|
||||||
|
) throws -> TextClassifierOptions {
|
||||||
|
let modelPath = try XCTUnwrap(modelPath)
|
||||||
|
|
||||||
|
let textClassifierOptions = TextClassifierOptions()
|
||||||
|
textClassifierOptions.baseOptions.modelAssetPath = modelPath
|
||||||
|
|
||||||
|
return textClassifierOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertCreateTextClassifierThrowsError(
|
||||||
|
textClassifierOptions: TextClassifierOptions,
|
||||||
|
expectedErrorDescription: String
|
||||||
|
) {
|
||||||
|
do {
|
||||||
|
let textClassifier = try TextClassifier(options: textClassifierOptions)
|
||||||
|
XCTAssertNil(textClassifier)
|
||||||
|
} catch {
|
||||||
|
assertEqualErrorDescriptions(
|
||||||
|
error,
|
||||||
|
expectedLocalizedDescription: expectedErrorDescription)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertResultsForClassify(
|
||||||
|
text: String,
|
||||||
|
using textClassifier: TextClassifier,
|
||||||
|
equals expectedCategories: [ResultCategory]
|
||||||
|
) throws {
|
||||||
|
let textClassifierResult =
|
||||||
|
try XCTUnwrap(
|
||||||
|
textClassifier.classify(text: text))
|
||||||
|
assertTextClassifierResultHasOneHead(textClassifierResult)
|
||||||
|
assertEqualCategoryArrays(
|
||||||
|
categoryArray:
|
||||||
|
textClassifierResult.classificationResult.classifications[0].categories,
|
||||||
|
expectedCategoryArray: expectedCategories)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testCreateTextClassifierWithInvalidMaxResultsFails() throws {
|
||||||
|
let textClassifierOptions =
|
||||||
|
try XCTUnwrap(
|
||||||
|
textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath))
|
||||||
|
textClassifierOptions.maxResults = 0
|
||||||
|
|
||||||
|
assertCreateTextClassifierThrowsError(
|
||||||
|
textClassifierOptions: textClassifierOptions,
|
||||||
|
expectedErrorDescription: """
|
||||||
|
INVALID_ARGUMENT: Invalid `max_results` option: value must be != 0.
|
||||||
|
""")
|
||||||
|
}
|
||||||
|
|
||||||
|
func testCreateTextClassifierWithCategoryAllowlistAndDenylistFails() throws {
|
||||||
|
|
||||||
|
let textClassifierOptions =
|
||||||
|
try XCTUnwrap(
|
||||||
|
textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath))
|
||||||
|
textClassifierOptions.categoryAllowlist = ["positive"]
|
||||||
|
textClassifierOptions.categoryDenylist = ["positive"]
|
||||||
|
|
||||||
|
assertCreateTextClassifierThrowsError(
|
||||||
|
textClassifierOptions: textClassifierOptions,
|
||||||
|
expectedErrorDescription: """
|
||||||
|
INVALID_ARGUMENT: `category_allowlist` and `category_denylist` are \
|
||||||
|
mutually exclusive options.
|
||||||
|
""")
|
||||||
|
}
|
||||||
|
|
||||||
|
func testClassifyWithBertSucceeds() throws {
|
||||||
|
|
||||||
|
let modelPath = try XCTUnwrap(TextClassifierTests.bertModelPath)
|
||||||
|
let textClassifier = try XCTUnwrap(TextClassifier(modelPath: modelPath))
|
||||||
|
|
||||||
|
try assertResultsForClassify(
|
||||||
|
text: TextClassifierTests.negativeText,
|
||||||
|
using: textClassifier,
|
||||||
|
equals: TextClassifierTests.bertNegativeTextResults)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testClassifyWithMaxResultsSucceeds() throws {
|
||||||
|
let textClassifierOptions =
|
||||||
|
try XCTUnwrap(
|
||||||
|
textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath))
|
||||||
|
textClassifierOptions.maxResults = 1
|
||||||
|
|
||||||
|
let textClassifier =
|
||||||
|
try XCTUnwrap(TextClassifier(options: textClassifierOptions))
|
||||||
|
|
||||||
|
try assertResultsForClassify(
|
||||||
|
text: TextClassifierTests.negativeText,
|
||||||
|
using: textClassifier,
|
||||||
|
equals: TextClassifierTests.bertNegativeTextResultsForEdgeTestCases)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testClassifyWithCategoryAllowlistSucceeds() throws {
|
||||||
|
let textClassifierOptions =
|
||||||
|
try XCTUnwrap(
|
||||||
|
textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath))
|
||||||
|
textClassifierOptions.categoryAllowlist = ["negative"]
|
||||||
|
|
||||||
|
let textClassifier =
|
||||||
|
try XCTUnwrap(TextClassifier(options: textClassifierOptions))
|
||||||
|
|
||||||
|
try assertResultsForClassify(
|
||||||
|
text: TextClassifierTests.negativeText,
|
||||||
|
using: textClassifier,
|
||||||
|
equals: TextClassifierTests.bertNegativeTextResultsForEdgeTestCases)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testClassifyWithCategoryDenylistSucceeds() throws {
|
||||||
|
let textClassifierOptions =
|
||||||
|
try XCTUnwrap(
|
||||||
|
textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath))
|
||||||
|
textClassifierOptions.categoryDenylist = ["positive"]
|
||||||
|
|
||||||
|
let textClassifier =
|
||||||
|
try XCTUnwrap(TextClassifier(options: textClassifierOptions))
|
||||||
|
|
||||||
|
try assertResultsForClassify(
|
||||||
|
text: TextClassifierTests.negativeText,
|
||||||
|
using: textClassifier,
|
||||||
|
equals: TextClassifierTests.bertNegativeTextResultsForEdgeTestCases)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testClassifyWithScoreThresholdSucceeds() throws {
|
||||||
|
let textClassifierOptions =
|
||||||
|
try XCTUnwrap(
|
||||||
|
textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath))
|
||||||
|
textClassifierOptions.scoreThreshold = 0.5
|
||||||
|
|
||||||
|
let textClassifier =
|
||||||
|
try XCTUnwrap(TextClassifier(options: textClassifierOptions))
|
||||||
|
|
||||||
|
try assertResultsForClassify(
|
||||||
|
text: TextClassifierTests.negativeText,
|
||||||
|
using: textClassifier,
|
||||||
|
equals: TextClassifierTests.bertNegativeTextResultsForEdgeTestCases)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -24,8 +24,5 @@ objc_library(
|
||||||
"-ObjC++",
|
"-ObjC++",
|
||||||
"-std=c++17",
|
"-std=c++17",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = ["//mediapipe/tasks/ios/core:MPPTaskRunner"],
|
||||||
"//mediapipe/tasks/ios/core:MPPTaskRunner",
|
|
||||||
"//third_party/apple_frameworks:Foundation",
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -20,10 +20,7 @@ objc_library(
|
||||||
name = "MPPTextClassifierOptions",
|
name = "MPPTextClassifierOptions",
|
||||||
srcs = ["sources/MPPTextClassifierOptions.m"],
|
srcs = ["sources/MPPTextClassifierOptions.m"],
|
||||||
hdrs = ["sources/MPPTextClassifierOptions.h"],
|
hdrs = ["sources/MPPTextClassifierOptions.h"],
|
||||||
deps = [
|
deps = ["//mediapipe/tasks/ios/core:MPPTaskOptions"],
|
||||||
"//mediapipe/tasks/ios/core:MPPTaskOptions",
|
|
||||||
"//third_party/apple_frameworks:Foundation",
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
objc_library(
|
objc_library(
|
||||||
|
@ -33,7 +30,6 @@ objc_library(
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/ios/components/containers:MPPClassificationResult",
|
"//mediapipe/tasks/ios/components/containers:MPPClassificationResult",
|
||||||
"//mediapipe/tasks/ios/core:MPPTaskResult",
|
"//mediapipe/tasks/ios/core:MPPTaskResult",
|
||||||
"//third_party/apple_frameworks:Foundation",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -46,6 +42,7 @@ objc_library(
|
||||||
"-std=c++17",
|
"-std=c++17",
|
||||||
"-x objective-c++",
|
"-x objective-c++",
|
||||||
],
|
],
|
||||||
|
module_name = "MPPTextClassifier",
|
||||||
deps = [
|
deps = [
|
||||||
":MPPTextClassifierOptions",
|
":MPPTextClassifierOptions",
|
||||||
":MPPTextClassifierResult",
|
":MPPTextClassifierResult",
|
||||||
|
@ -59,7 +56,6 @@ objc_library(
|
||||||
"//mediapipe/tasks/ios/text/core:MPPTextTaskRunner",
|
"//mediapipe/tasks/ios/text/core:MPPTextTaskRunner",
|
||||||
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers",
|
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers",
|
||||||
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers",
|
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers",
|
||||||
"//third_party/apple_frameworks:Foundation",
|
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -65,7 +65,7 @@ NS_SWIFT_NAME(TextClassifier)
|
||||||
* @return A new instance of `MPPTextClassifier` with the given model path. `nil` if there is an
|
* @return A new instance of `MPPTextClassifier` with the given model path. `nil` if there is an
|
||||||
* error in initializing the text classifier.
|
* error in initializing the text classifier.
|
||||||
*/
|
*/
|
||||||
- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error;
|
- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a new instance of `MPPTextClassifier` from the given `MPPTextClassifierOptions`.
|
* Creates a new instance of `MPPTextClassifier` from the given `MPPTextClassifierOptions`.
|
||||||
|
@ -78,8 +78,8 @@ NS_SWIFT_NAME(TextClassifier)
|
||||||
* @return A new instance of `MPPTextClassifier` with the given options. `nil` if there is an
|
* @return A new instance of `MPPTextClassifier` with the given options. `nil` if there is an
|
||||||
* error in initializing the text classifier.
|
* error in initializing the text classifier.
|
||||||
*/
|
*/
|
||||||
- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options
|
- (nullable instancetype)initWithOptions:(MPPTextClassifierOptions *)options
|
||||||
error:(NSError **)error NS_DESIGNATED_INITIALIZER;
|
error:(NSError **)error NS_DESIGNATED_INITIALIZER;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Performs classification on the input text.
|
* Performs classification on the input text.
|
||||||
|
@ -90,7 +90,8 @@ NS_SWIFT_NAME(TextClassifier)
|
||||||
*
|
*
|
||||||
* @return A `MPPTextClassifierResult` object that contains a list of text classifications.
|
* @return A `MPPTextClassifierResult` object that contains a list of text classifications.
|
||||||
*/
|
*/
|
||||||
- (nullable MPPTextClassifierResult *)classifyText:(NSString *)text error:(NSError **)error;
|
- (nullable MPPTextClassifierResult *)classifyText:(NSString *)text
|
||||||
|
error:(NSError **)error NS_SWIFT_NAME(classify(text:));
|
||||||
|
|
||||||
- (instancetype)init NS_UNAVAILABLE;
|
- (instancetype)init NS_UNAVAILABLE;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user