From d7b0f660e66cb1500e06431ba2bc48ed391e68ad Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 1 Dec 2022 02:44:52 +0530 Subject: [PATCH 01/18] Added Files for Text Classifier iOS Task --- mediapipe/tasks/ios/common/BUILD | 26 +++ .../tasks/ios/common/sources/MPPCommon.h | 179 ++++++++++++++++++ mediapipe/tasks/ios/common/utils/BUILD | 38 ++++ .../ios/common/utils/sources/MPPCommonUtils.h | 78 ++++++++ .../common/utils/sources/MPPCommonUtils.mm | 129 +++++++++++++ .../common/utils/sources/NSString+Helpers.h | 27 +++ .../common/utils/sources/NSString+Helpers.mm | 23 +++ .../tasks/ios/components/containers/BUILD | 32 ++++ .../containers/sources/MPPCategory.h | 62 ++++++ .../containers/sources/MPPCategory.m | 33 ++++ .../sources/MPPClassificationResult.h | 94 +++++++++ .../sources/MPPClassificationResult.m | 50 +++++ .../tasks/ios/components/processors/BUILD | 24 +++ .../processors/sources/MPPClassifierOptions.h | 42 ++++ .../processors/sources/MPPClassifierOptions.m | 40 ++++ .../ios/components/processors/utils/BUILD | 29 +++ .../sources/MPPClassifierOptions+Helpers.h | 25 +++ .../sources/MPPClassifierOptions+Helpers.mm | 38 ++++ mediapipe/tasks/ios/core/BUILD | 67 +++++++ .../tasks/ios/core/sources/MPPBaseOptions.h | 51 +++++ .../tasks/ios/core/sources/MPPBaseOptions.m | 36 ++++ .../tasks/ios/core/sources/MPPExternalFile.h | 28 +++ .../tasks/ios/core/sources/MPPExternalFile.m | 27 +++ .../tasks/ios/core/sources/MPPTaskInfo.h | 64 +++++++ .../tasks/ios/core/sources/MPPTaskInfo.mm | 135 +++++++++++++ .../tasks/ios/core/sources/MPPTaskOptions.h | 58 ++++++ .../tasks/ios/core/sources/MPPTaskOptions.m | 36 ++++ mediapipe/tasks/ios/core/utils/BUILD | 27 +++ .../utils/sources/MPPBaseOptions+Helpers.h | 26 +++ .../utils/sources/MPPBaseOptions+Helpers.mm | 40 ++++ mediapipe/tasks/ios/text/core/BUILD | 33 ++++ .../text/core/sources/MPPBaseTextTaskApi.h | 44 +++++ .../text/core/sources/MPPBaseTextTaskApi.mm | 52 +++++ .../tasks/ios/text/text_classifier/BUILD | 50 
+++++ .../sources/MPPTextClassifier.h | 61 ++++++ .../sources/MPPTextClassifier.mm | 65 +++++++ .../sources/MPPTextClassifierOptions.h | 51 +++++ .../sources/MPPTextClassifierOptions.mm | 27 +++ .../ios/text/text_classifier/utils/BUILD | 30 +++ .../MPPTextClassifierOptions+Helpers.h | 26 +++ .../MPPTextClassifierOptions+Helpers.mm | 36 ++++ 41 files changed, 2039 insertions(+) create mode 100644 mediapipe/tasks/ios/common/BUILD create mode 100644 mediapipe/tasks/ios/common/sources/MPPCommon.h create mode 100644 mediapipe/tasks/ios/common/utils/BUILD create mode 100644 mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h create mode 100644 mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm create mode 100644 mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h create mode 100644 mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm create mode 100644 mediapipe/tasks/ios/components/containers/BUILD create mode 100644 mediapipe/tasks/ios/components/containers/sources/MPPCategory.h create mode 100644 mediapipe/tasks/ios/components/containers/sources/MPPCategory.m create mode 100644 mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h create mode 100644 mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m create mode 100644 mediapipe/tasks/ios/components/processors/BUILD create mode 100644 mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h create mode 100644 mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m create mode 100644 mediapipe/tasks/ios/components/processors/utils/BUILD create mode 100644 mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h create mode 100644 mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm create mode 100644 mediapipe/tasks/ios/core/BUILD create mode 100644 mediapipe/tasks/ios/core/sources/MPPBaseOptions.h create mode 100644 
mediapipe/tasks/ios/core/sources/MPPBaseOptions.m create mode 100644 mediapipe/tasks/ios/core/sources/MPPExternalFile.h create mode 100644 mediapipe/tasks/ios/core/sources/MPPExternalFile.m create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskInfo.h create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskOptions.h create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskOptions.m create mode 100644 mediapipe/tasks/ios/core/utils/BUILD create mode 100644 mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h create mode 100644 mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm create mode 100644 mediapipe/tasks/ios/text/core/BUILD create mode 100644 mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h create mode 100644 mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.mm create mode 100644 mediapipe/tasks/ios/text/text_classifier/BUILD create mode 100644 mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h create mode 100644 mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm create mode 100644 mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h create mode 100644 mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.mm create mode 100644 mediapipe/tasks/ios/text/text_classifier/utils/BUILD create mode 100644 mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h create mode 100644 mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm diff --git a/mediapipe/tasks/ios/common/BUILD b/mediapipe/tasks/ios/common/BUILD new file mode 100644 index 000000000..0d00c423f --- /dev/null +++ b/mediapipe/tasks/ios/common/BUILD @@ -0,0 +1,26 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPCommon", + hdrs = [ + "sources/MPPCommon.h", + ], + module_name = "MPPCommon", +) + diff --git a/mediapipe/tasks/ios/common/sources/MPPCommon.h b/mediapipe/tasks/ios/common/sources/MPPCommon.h new file mode 100644 index 000000000..1f450370e --- /dev/null +++ b/mediapipe/tasks/ios/common/sources/MPPCommon.h @@ -0,0 +1,179 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import + +NS_ASSUME_NONNULL_BEGIN + +/** + * @enum TFLSupportErrorCode + * This enum specifies error codes for TensorFlow Lite Task Library. + * It maintains a 1:1 mapping to TfLiteSupportErrorCode of C libray. + */ +typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) { + + // Generic error codes. 
+ + // Unspecified error. + MPPTasksErrorCodeError = 1, + // Invalid argument specified. + MPPTasksErrorCodeInvalidArgumentError = 2, + // Invalid FlatBuffer file or buffer specified. + MPPTasksErrorCodeInvalidFlatBufferError = 3, + // Model contains a builtin op that isn't supported by the OpResolver or + // delegates. + MPPTasksErrorCodeUnsupportedBuiltinOp = 4, + // Model contains a custom op that isn't supported by the OpResolver or + // delegates. + MPPTasksErrorCodeUnsupportedCustomOp = 5, + + // File I/O error codes. + + // No such file. + MPPTasksErrorCodeFileNotFoundError = 100, + // Permission issue. + MPPTasksErrorCodeFilePermissionDeniedError, + // I/O error when reading file. + MPPTasksErrorCodeFileReadError, + // I/O error when mmap-ing file. + MPPTasksErrorCodeFileMmapError, + // ZIP I/O error when unpacMPPTasksErrorCodeing the zip file. + MPPTasksErrorCodeFileZipError, + + // TensorFlow Lite metadata error codes. + + // Unexpected schema version (aMPPTasksErrorCodea file_identifier) in the Metadata FlatBuffer. + MPPTasksErrorCodeMetadataInvalidSchemaVersionError = 200, + // No such associated file within metadata, or file has not been pacMPPTasksErrorCodeed. + MPPTasksErrorCodeMetadataAssociatedFileNotFoundError, + // ZIP I/O error when unpacMPPTasksErrorCodeing an associated file. + MPPTasksErrorCodeMetadataAssociatedFileZipError, + // Inconsistency error between the metadata and actual TF Lite model. + // E.g.: number of labels and output tensor values differ. + MPPTasksErrorCodeMetadataInconsistencyError, + // Invalid process units specified. + // E.g.: multiple ProcessUnits with the same type for a given tensor. + MPPTasksErrorCodeMetadataInvalidProcessUnitsError, + // Inconsistency error with the number of labels. + // E.g.: label files for different locales have a different number of labels. + MPPTasksErrorCodeMetadataNumLabelsMismatchError, + // Score calibration parameters parsing error. 
+ // E.g.: too many parameters provided in the corresponding associated file. + MPPTasksErrorCodeMetadataMalformedScoreCalibrationError, + // Unexpected number of subgraphs for the current task. + // E.g.: image classification expects a single subgraph. + MPPTasksErrorCodeMetadataInvalidNumSubgraphsError, + // A given tensor requires NormalizationOptions but none were found. + // E.g.: float input tensor requires normalization to preprocess input images. + MPPTasksErrorCodeMetadataMissingNormalizationOptionsError, + // Invalid ContentProperties specified. + // E.g. expected ImageProperties, got BoundingBoxProperties. + MPPTasksErrorCodeMetadataInvalidContentPropertiesError, + // Metadata is mandatory but was not found. + // E.g. current task requires TFLite Model Metadata but none was found. + MPPTasksErrorCodeMetadataNotFoundError, + // Associated TENSOR_AXIS_LABELS or TENSOR_VALUE_LABELS file is mandatory but + // none was found or it was empty. + // E.g. current task requires labels but none were found. + MPPTasksErrorCodeMetadataMissingLabelsError, + // The ProcessingUnit for tokenizer is not correctly configured. + // E.g BertTokenizer doesn't have a valid vocab file associated. + MPPTasksErrorCodeMetadataInvalidTokenizerError, + + // Input tensor(s) error codes. + + // Unexpected number of input tensors for the current task. + // E.g. current task expects a single input tensor. + MPPTasksErrorCodeInvalidNumInputTensorsError = 300, + // Unexpected input tensor dimensions for the current task. + // E.g.: only 4D input tensors supported. + MPPTasksErrorCodeInvalidInputTensorDimensionsError, + // Unexpected input tensor type for the current task. + // E.g.: current task expects a uint8 pixel image as input. + MPPTasksErrorCodeInvalidInputTensorTypeError, + // Unexpected input tensor bytes size. + // E.g.: size in bytes does not correspond to the expected number of pixels. 
+ MPPTasksErrorCodeInvalidInputTensorSizeError, + // No correct input tensor found for the model. + // E.g.: input tensor name is not part of the text model's input tensors. + MPPTasksErrorCodeInputTensorNotFoundError, + + // Output tensor(s) error codes. + + // Unexpected output tensor dimensions for the current task. + // E.g.: only a batch size of 1 is supported. + MPPTasksErrorCodeInvalidOutputTensorDimensionsError = 400, + // Unexpected input tensor type for the current task. + // E.g.: multi-head model with different output tensor types. + MPPTasksErrorCodeInvalidOutputTensorTypeError, + // No correct output tensor found for the model. + // E.g.: output tensor name is not part of the text model's output tensors. + MPPTasksErrorCodeOutputTensorNotFoundError, + // Unexpected number of output tensors for the current task. + // E.g.: current task expects a single output tensor. + MPPTasksErrorCodeInvalidNumOutputTensorsError, + + // Image processing error codes. + + // Unspecified image processing failures. + MPPTasksErrorCodeImageProcessingError = 500, + // Unexpected input or output buffer metadata. + // E.g.: rotate RGBA buffer to Grayscale buffer by 90 degrees. + MPPTasksErrorCodeImageProcessingInvalidArgumentError, + // Image processing operation failures. + // E.g. libyuv rotation failed for an unknown reason. + MPPTasksErrorCodeImageProcessingBackendError, + + // Task runner error codes. + MPPTasksErrorCodeRunnerError = 600, + // Task runner is not initialized. + MPPTasksErrorCodeRunnerInitializationError, + // Task runner is not started successfully. + MPPTasksErrorCodeRunnerFailsToStartError, + // Task runner is not started. + MPPTasksErrorCodeRunnerNotStartedError, + // Task runner API is called in the wrong processing mode. + MPPTasksErrorCodeRunnerApiCalledInWrongModeError, + // Task runner receives/produces invalid MediaPipe packet timestamp. 
+ MPPTasksErrorCodeRunnerInvalidTimestampError, + // Task runner receives unexpected MediaPipe graph input packet. + // E.g. The packet type doesn't match the graph input stream's data type. + MPPTasksErrorCodeRunnerUnexpectedInputError, + // Task runner produces unexpected MediaPipe graph output packet. + // E.g. The number of output packets is not equal to the number of graph + // output streams. + MPPTasksErrorCodeRunnerUnexpectedOutputError, + // Task runner is not closed successfully. + MPPTasksErrorCodeRunnerFailsToCloseError, + // Task runner's model resources cache service is unavailable or the + // targeting model resources bundle is not found. + MPPTasksErrorCodeRunnerModelResourcesCacheServiceError, + + // Task graph error codes. + MPPTasksErrorCodeGraphError = 700, + // Task graph is not implemented. + MPPTasksErrorCodeTaskGraphNotImplementedError, + // Task graph config is invalid. + MPPTasksErrorCodeInvalidTaskGraphConfigError, + + MPPTasksErrorCodeFirst = MPPTasksErrorCodeError, + + /** + * The last error code in TFLSupportErrorCode (for internal use only). + */ + MPPTasksErrorCodeLast = MPPTasksErrorCodeInvalidTaskGraphConfigError, + +} NS_SWIFT_NAME(TasksErrorCode); + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/common/utils/BUILD b/mediapipe/tasks/ios/common/utils/BUILD new file mode 100644 index 000000000..9c8c75586 --- /dev/null +++ b/mediapipe/tasks/ios/common/utils/BUILD @@ -0,0 +1,38 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPCommonUtils", + srcs = ["sources/MPPCommonUtils.mm"], + hdrs = ["sources/MPPCommonUtils.h"], + deps = [ + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/ios/common:MPPCommon", + ], +) + +objc_library( + name = "NSStringHelpers", + srcs = ["sources/NSString+Helpers.mm"], + hdrs = ["sources/NSString+Helpers.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], +) + diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h new file mode 100644 index 000000000..6b4f40bc6 --- /dev/null +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h @@ -0,0 +1,78 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#import +#include "mediapipe/tasks/cc/common.h" + +NS_ASSUME_NONNULL_BEGIN + +/** Error domain of TensorFlow Lite Support related errors. */ +extern NSString *const MPPTasksErrorDomain; + +/** Helper utility for the all tasks which encapsulates common functionality. */ +@interface MPPCommonUtils : NSObject + +/** + * Creates and saves an NSError in the Tensorflow Lite Task Library domain, with the given code and + * description. 
+ * + * @param code Error code. + * @param description Error description. + * @param error Pointer to the memory location where the created error should be saved. If `nil`, + * no error will be saved. + */ ++ (void)createCustomError:(NSError **)error + withCode:(NSUInteger)code + description:(NSString *)description; + +/** + * Creates and saves an NSError with the given domain, code and description. + * + * @param error Pointer to the memory location where the created error should be saved. If `nil`, + * no error will be saved. + * @param domain Error domain. + * @param code Error code. + * @param description Error description. + */ ++ (void)createCustomError:(NSError **)error + withDomain:(NSString *)domain + code:(NSUInteger)code + description:(NSString *)description; + +/** + * Converts an absl status to an NSError. + * + * @param status absl status. + * @param error Pointer to the memory location where the created error should be saved. If `nil`, + * no error will be saved. + */ ++ (BOOL)checkCppError:(const absl::Status &)status toError:(NSError **)error; + +/** + * Allocates a block of memory with the specified size and returns a pointer to it. If memory + * cannot be allocated because of an invalid memSize, it saves an error. In other cases, it + * terminates program execution. + * + * @param memSize size of memory to be allocated + * @param error Pointer to the memory location where errors if any should be saved. If `nil`, no + * error will be saved. + * + * @return Pointer to the allocated block of memory on successfull allocation. nil in case as + * error is encountered because of invalid memSize. If failure is due to any other reason, method + * terminates program execution. 
+ */ ++ (void *)mallocWithSize:(size_t)memSize error:(NSError **)error; +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm new file mode 100644 index 000000000..5d8cc6887 --- /dev/null +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm @@ -0,0 +1,129 @@ +// Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#import "mediapipe/tasks/ios/common/sources/MPPCommon.h" + +/** Error domain of TensorFlow Lite Support related errors. 
*/ +NSString *const TFLSupportTaskErrorDomain = @"org.tensorflow.lite.tasks"; + +@implementation TFLCommonUtils + ++ (void)createCustomError:(NSError **)error + withCode:(NSUInteger)code + description:(NSString *)description { + [TFLCommonUtils createCustomError:error + withDomain:TFLSupportTaskErrorDomain + code:code + description:description]; +} + ++ (void)createCustomError:(NSError **)error + withDomain:(NSString *)domain + code:(NSUInteger)code + description:(NSString *)description { + if (error) { + *error = [NSError errorWithDomain:domain + code:code + userInfo:@{NSLocalizedDescriptionKey : description}]; + } +} + ++ (void *)mallocWithSize:(size_t)memSize error:(NSError **)error { + if (!memSize) { + [TFLCommonUtils createCustomError:error + withCode:MPPTasksErrorCodeInvalidArgumentError + description:@"memSize cannot be zero."]; + return NULL; + } + + void *allocedMemory = malloc(memSize); + if (!allocedMemory) { + exit(-1); + } + + return allocedMemory; +} + ++ (BOOL)checkCppError:(const absl::Status &)status toError:(NSError *_Nullable *)error { + if (status.ok()) { + return YES; + } + // Payload of absl::Status created by the tflite task library stores an appropriate value of the + // enum TfLiteSupportStatus. The integer value corresponding to the TfLiteSupportStatus enum + // stored in the payload is extracted here to later map to the appropriate error code to be + // returned. In cases where the enum is not stored in (payload is NULL or the payload string + // cannot be converted to an integer), we set the error code value to be 1 + // (TFLSupportErrorCodeUnspecifiedError of TFLSupportErrorCode used in the iOS library to signify + // any errors not falling into other categories.) Since payload is of type absl::Cord that can be + // type cast into an absl::optional, we use the std::stoi function to convert it into + // an integer code if possible. 
+ NSUInteger genericErrorCode = MPPTasksErrorCodeError; + NSUInteger errorCode; + try { + // Try converting payload to integer if payload is not empty. Otherwise convert a string + // signifying generic error code TFLSupportErrorCodeUnspecifiedError to integer. + errorCode = + (NSUInteger)std::stoi(static_cast>( + status.GetPayload(mediapipe::tasks::kMediaPipeTasksPayload)) + .value_or(std::to_string(genericErrorCode))); + } catch (std::invalid_argument &e) { + // If non empty payload string cannot be converted to an integer. Set error code to 1(kError). + errorCode = MPPTasksErrorCodeError; + } + + // If errorCode is outside the range of enum values possible or is + // TFLSupportErrorCodeUnspecifiedError, we try to map the absl::Status::code() to assign + // appropriate TFLSupportErrorCode or TFLSupportErrorCodeUnspecifiedError in default cases. Note: + // The mapping to absl::Status::code() is done to generate a more specific error code than + // TFLSupportErrorCodeUnspecifiedError in cases when the payload can't be mapped to + // TfLiteSupportStatus. This can happen when absl::Status returned by TfLite are in turn returned + // without moodification by TfLite Support Methods. + if (errorCode > MPPTasksErrorCodeLast || errorCode <= MPPTasksErrorCodeFirst) { + switch (status.code()) { + case absl::StatusCode::kInternal: + errorCode = MPPTasksErrorCodeError; + break; + case absl::StatusCode::kInvalidArgument: + errorCode = MPPTasksErrorCodeInvalidArgumentError; + break; + case absl::StatusCode::kNotFound: + errorCode = MPPTasksErrorCodeError; + break; + default: + errorCode = MPPTasksErrorCodeError; + break; + } + } + + // Creates the NSEror with the appropriate error + // TFLSupportErrorCode and message. 
TFLSupportErrorCode has a one to one + // mapping with TfLiteSupportStatus starting from the value 1(TFLSupportErrorCodeUnspecifiedError) + // and hence will be correctly initialized if directly cast from the integer code derived from + // TfLiteSupportStatus stored in its payload. TFLSupportErrorCode omits kOk = 0 of + // TfLiteSupportStatus. + // + // Stores a string including absl status code and message(if non empty) as the + // error message See + // https://github.com/abseil/abseil-cpp/blob/master/absl/status/status.h#L514 + // for explanation. absl::Status::message() can also be used but not always + // guaranteed to be non empty. + NSString *description = [NSString + stringWithCString:status.ToString(absl::StatusToStringMode::kWithNoExtraData).c_str() + encoding:NSUTF8StringEncoding]; + [TFLCommonUtils createCustomError:error withCode:errorCode description:description]; + return NO; +} + +@end diff --git a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h new file mode 100644 index 000000000..31828a367 --- /dev/null +++ b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h @@ -0,0 +1,27 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ==============================================================================*/ + +#import +#include + +NS_ASSUME_NONNULL_BEGIN + +@interface NSString (Helpers) + +@property(readonly) std::string cppString; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm new file mode 100644 index 000000000..540903c6c --- /dev/null +++ b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm @@ -0,0 +1,23 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" + +@implementation NSString (Helpers) + +- (std::string)cppString { + return std::string(self.UTF8String, [self lengthOfBytesUsingEncoding:NSUTF8StringEncoding]); +} + +@end diff --git a/mediapipe/tasks/ios/components/containers/BUILD b/mediapipe/tasks/ios/components/containers/BUILD new file mode 100644 index 000000000..ce80571e9 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/BUILD @@ -0,0 +1,32 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPCategory", + srcs = ["sources/MPPCategory.m"], + hdrs = ["sources/MPPCategory.h"], +) + +objc_library( + name = "MPPClassificationResult", + srcs = ["sources/MPPClassificationResult.m"], + hdrs = ["sources/MPPClassificationResult.h"], + deps = [ + ":MPPCategory", + ], +) diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h new file mode 100644 index 000000000..e2f8a0729 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h @@ -0,0 +1,62 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import + +NS_ASSUME_NONNULL_BEGIN + +/** Encapsulates information about a class in the classification results. 
*/ +NS_SWIFT_NAME(ClassificationCategory) +@interface TFLCategory : NSObject + +/** Index of the class in the corresponding label map, usually packed in the TFLite Model + * Metadata. */ +@property(nonatomic, readonly) NSInteger index; + +/** Confidence score for this class . */ +@property(nonatomic, readonly) float score; + +/** Class name of the class. */ +@property(nonatomic, readonly, nullable) NSString *label; + +/** Display name of the class. */ +@property(nonatomic, readonly, nullable) NSString *displayName; + +/** + * Initializes a new `TFLCategory` with the given index, score, label and display name. + * + * @param index Index of the class in the corresponding label map, usually packed in the TFLite + * Model Metadata. + * + * @param score Confidence score for this class. + * + * @param label Class name of the class. + * + * @param displayName Display name of the class. + * + * @return An instance of `TFLCategory` initialized with the given index, score, label and display + * name. + */ +- (instancetype)initWithIndex:(NSInteger)index + score:(float)score + label:(nullable NSString *)label + displayName:(nullable NSString *)displayName; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m new file mode 100644 index 000000000..30efa239a --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m @@ -0,0 +1,33 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "mediapipe/tasks/ios/components/containers/sources/TFLCategory.h" + +@implementation TFLCategory + +- (instancetype)initWithIndex:(NSInteger)index + score:(float)score + label:(nullable NSString *)label + displayName:(nullable NSString *)displayName { + self = [super init]; + if (self) { + _index = index; + _score = score; + _label = label; + _displayName = displayName; + } + return self; +} + +@end diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h new file mode 100644 index 000000000..120780a7b --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h @@ -0,0 +1,94 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#import <Foundation/Foundation.h> +#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h" + +NS_ASSUME_NONNULL_BEGIN + +/** Encapsulates list of predicted classes (aka labels) for a given image classifier head. */ +NS_SWIFT_NAME(Classifications) +@interface MPPClassifications : NSObject + +/** + * The index of the classifier head these classes refer to. This is useful for multi-head + * models. + */ +@property(nonatomic, readonly) NSInteger headIndex; + +/** The name of the classifier head, which is the corresponding tensor metadata + * name. + */ +@property(nonatomic, readonly) NSString *headName; + +/** The array of predicted classes, usually sorted by descending scores (e.g. from high to low + * probability). */ +@property(nonatomic, readonly) NSArray<MPPCategory *> *categories; + +/** + * Initializes a new `MPPClassifications` with the given head index and array of categories. + * head name is initialized to `nil`. + * + * @param headIndex The index of the image classifier head these classes refer to. + * @param categories An array of `MPPCategory` objects encapsulating a list of + * predictions usually sorted by descending scores (e.g. from high to low probability). + * + * @return An instance of `MPPClassifications` initialized with the given head index and + * array of categories. + */ +- (instancetype)initWithHeadIndex:(NSInteger)headIndex + categories:(NSArray<MPPCategory *> *)categories; + +/** + * Initializes a new `MPPClassifications` with the given head index, head name and array of + * categories. + * + * @param headIndex The index of the classifier head these classes refer to. + * @param headName The name of the classifier head, which is the corresponding tensor metadata + * name. + * @param categories An array of `MPPCategory` objects encapsulating a list of + * predictions usually sorted by descending scores (e.g. from high to low probability). 
+ * + * @return An object of `MPPClassifications` initialized with the given head index, head name and + * array of categories. + */ +- (instancetype)initWithHeadIndex:(NSInteger)headIndex + headName:(nullable NSString *)headName + categories:(NSArray<MPPCategory *> *)categories; + +@end + +/** Encapsulates results of any classification task. */ +NS_SWIFT_NAME(ClassificationResult) +@interface MPPClassificationResult : NSObject + +/** Array of MPPClassifications objects containing classifier predictions per image classifier + * head. + */ +@property(nonatomic, readonly) NSArray<MPPClassifications *> *classifications; + +/** + * Initializes a new `MPPClassificationResult` with the given array of classifications. + * + * @param classifications An array of `MPPClassifications` objects containing classifier + * predictions per classifier head. + * + * @return An instance of MPPClassificationResult initialized with the given array of + * classifications. + */ +- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m new file mode 100644 index 000000000..266b401e0 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m @@ -0,0 +1,50 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h" + +@implementation MPPClassifications + +- (instancetype)initWithHeadIndex:(NSInteger)headIndex + headName:(nullable NSString *)headName + categories:(NSArray *)categories { + self = [super init]; + if (self) { + _headIndex = headIndex; + _headName = headName; + _categories = categories; + } + return self; +} + +- (instancetype)initWithHeadIndex:(NSInteger)headIndex + categories:(NSArray *)categories { + return [self initWithHeadIndex:headIndex headName:nil categories:categories]; +} + +@end + +@implementation MPPClassificationResult { + NSArray *_classifications; +} + +- (instancetype)initWithClassifications:(NSArray *)classifications { + self = [super init]; + if (self) { + _classifications = classifications; + } + return self; +} + +@end diff --git a/mediapipe/tasks/ios/components/processors/BUILD b/mediapipe/tasks/ios/components/processors/BUILD new file mode 100644 index 000000000..6d1cfdf59 --- /dev/null +++ b/mediapipe/tasks/ios/components/processors/BUILD @@ -0,0 +1,24 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPClassifierOptions", + srcs = ["sources/MPPClassifierOptions.m"], + hdrs = ["sources/MPPClassifierOptions.h"], +) + diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h new file mode 100644 index 000000000..7538a625d --- /dev/null +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h @@ -0,0 +1,42 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import <Foundation/Foundation.h> + +NS_ASSUME_NONNULL_BEGIN + +/** + * Holds settings for any single iOS MediaPipe classification task. + */ +NS_SWIFT_NAME(ClassifierOptions) +@interface MPPClassifierOptions : NSObject + +/** If set, all classes in this list will be filtered out from the results. */ +@property(nonatomic, copy) NSArray<NSString *> *labelDenyList; + +/** If set, all classes not in this list will be filtered out from the results. */ +@property(nonatomic, copy) NSArray<NSString *> *labelAllowList; + +/** The locale to use for display names, if any. */ +@property(nonatomic, copy) NSString *displayNamesLocale; + +/** Results with score greater than this value are returned. 
*/ +@property(nonatomic) float scoreThreshold; + +/** Limit to the number of classes that can be returned in results. */ +@property(nonatomic) NSInteger maxResults; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m new file mode 100644 index 000000000..5a5baf79f --- /dev/null +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m @@ -0,0 +1,40 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ==============================================================================*/ +#import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h" + +@implementation MPPClassifierOptions + +- (instancetype)init { + self = [super init]; + if (self) { + self.maxResults = -1; + self.scoreThreshold = 0; + } + return self; +} + +- (id)copyWithZone:(NSZone *)zone { + MPPClassifierOptions *classifierOptions = [[MPPClassifierOptions alloc] init]; + + classifierOptions.scoreThreshold = self.scoreThreshold; + classifierOptions.maxResults = self.maxResults; + classifierOptions.labelDenyList = self.labelDenyList; + classifierOptions.labelAllowList = self.labelAllowList; + classifierOptions.displayNamesLocale = self.displayNamesLocale; + + return classifierOptions; +} + +@end diff --git a/mediapipe/tasks/ios/components/processors/utils/BUILD b/mediapipe/tasks/ios/components/processors/utils/BUILD new file mode 100644 index 000000000..820c6bb56 --- /dev/null +++ b/mediapipe/tasks/ios/components/processors/utils/BUILD @@ -0,0 +1,29 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPClassifierOptionsHelpers", + srcs = ["sources/MPPClassifierOptions+Helpers.mm"], + hdrs = ["sources/MPPClassifierOptions+Helpers.h"], + deps = [ + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/ios/components/processors:MPPClassifierOptions", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + ] +) + diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h new file mode 100644 index 000000000..4defb3ee7 --- /dev/null +++ b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h @@ -0,0 +1,25 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ==============================================================================*/ +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" +#import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPClassifierOptions (Helpers) +- (void)copyToProto: + (mediapipe::tasks::components::processors::proto::ClassifierOptions *)classifierOptionsProto; +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm new file mode 100644 index 000000000..efb220147 --- /dev/null +++ b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm @@ -0,0 +1,38 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ==============================================================================*/ +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h" + +namespace { +using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto::ClassifierOptions; +} + +@implementation MPPClassifierOptions (Helpers) +- (void)copyToProto:(ClassifierOptionsProto *)classifierOptionsProto { + if (self.displayNamesLocale) { + classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString); + } + classifierOptionsProto->set_max_results((int)self.maxResults); + classifierOptionsProto->set_score_threshold(self.scoreThreshold); + for (NSString *category in self.labelAllowList) { + classifierOptionsProto->add_category_allowlist(category.cppString); + } + + for (NSString *category in self.labelDenyList) { + classifierOptionsProto->add_category_denylist(category.cppString); + } +} + +@end diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD new file mode 100644 index 000000000..91257b4de --- /dev/null +++ b/mediapipe/tasks/ios/core/BUILD @@ -0,0 +1,67 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPExternalFile", + srcs = ["sources/MPPExternalFile.m"], + hdrs = ["sources/MPPExternalFile.h"], +) + +objc_library( + name = "MPPBaseOptions", + srcs = ["sources/MPPBaseOptions.m"], + hdrs = ["sources/MPPBaseOptions.h"], + deps = [ + ":MPPExternalFile", + + ], +) + +objc_library( + name = "MPPTaskOptions", + srcs = ["sources/MPPTaskOptions.m"], + hdrs = ["sources/MPPTaskOptions.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], + deps = [ + "//mediapipe/framework:calculator_options_cc_proto", + ":MPPBaseOptions", + ], +) + +objc_library( + name = "MPPTaskInfo", + srcs = ["sources/MPPTaskInfo.mm"], + hdrs = ["sources/MPPTaskInfo.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], + deps = [ + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", + ":MPPTaskOptions", + "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/common:MPPCommon", + ], +) + diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h new file mode 100644 index 000000000..686e50add --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h @@ -0,0 +1,51 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import +#import "mediapipe/tasks/ios/core/sources/MPPExternalFile.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * MediaPipe Tasks delegate. + */ +typedef NS_ENUM(NSUInteger, MPPDelegate) { + + /** CPU. */ + MPPDelegateCPU, + + /** GPU. */ + MPPDelegateGPU +} NS_SWIFT_NAME(Delegate); + +/** + * Holds the base options that is used for creation of any type of task. It has fields with + * important information acceleration configuration, tflite model source etc. + */ +NS_SWIFT_NAME(BaseOptions) +@interface MPPBaseOptions : NSObject + +/** + * The external model file, as a single standalone TFLite file. It could be packed with TFLite Model + * Metadata[1] and associated files if exist. Fail to provide the necessary metadata and associated + * files might result in errors. + */ +@property(nonatomic, copy) MPPExternalFile *modelAssetFile; + +/** + * device delegate to run the MediaPipe pipeline. If the delegate is not set, the default + * delegate CPU is used. + */ +@property(nonatomic) MPPDelegate delegate; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m new file mode 100644 index 000000000..4c25b80e8 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m @@ -0,0 +1,36 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h" + +@implementation MPPBaseOptions + +- (instancetype)init { + self = [super init]; + if (self) { + self.modelAssetFile = [[MPPExternalFile alloc] init]; + } + return self; +} + +- (id)copyWithZone:(NSZone *)zone { + MPPBaseOptions *baseOptions = [[MPPBaseOptions alloc] init]; + + baseOptions.modelAssetFile = self.modelAssetFile; + baseOptions.delegate = self.delegate; + + return baseOptions; +} + +@end diff --git a/mediapipe/tasks/ios/core/sources/MPPExternalFile.h b/mediapipe/tasks/ios/core/sources/MPPExternalFile.h new file mode 100644 index 000000000..a97802002 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPExternalFile.h @@ -0,0 +1,28 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import + +NS_ASSUME_NONNULL_BEGIN + +/** + * Holds information about an external file. + */ +NS_SWIFT_NAME(ExternalFile) +@interface MPPExternalFile : NSObject + +/** Path to the file in bundle. */ +@property(nonatomic, copy) NSString *filePath; +/// Add provision for other sources in future. 
+ +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPExternalFile.m b/mediapipe/tasks/ios/core/sources/MPPExternalFile.m new file mode 100644 index 000000000..70d85657c --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPExternalFile.m @@ -0,0 +1,27 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "mediapipe/tasks/ios/core/sources/MPPExternalFile.h" + +@implementation MPPExternalFile + +- (id)copyWithZone:(NSZone *)zone { + MPPExternalFile *externalFile = [[MPPExternalFile alloc] init]; + + externalFile.filePath = self.filePath; + + return externalFile; +} + +@end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h new file mode 100644 index 000000000..620184518 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h @@ -0,0 +1,64 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import <Foundation/Foundation.h> +#include "mediapipe/framework/calculator.pb.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * Holds all needed information to initialize a MediaPipe Task. + */ +@interface MPPTaskInfo : NSObject <NSCopying> + +@property(nonatomic, copy, nonnull) NSString *taskGraphName; + +/** + * A task-specific options that is derived from MPPTaskOptions and conforms to + * MPPTaskOptionsProtocol. + */ +@property(nonatomic, copy) id<MPPTaskOptionsProtocol> taskOptions; + +/** + * List of task graph input stream info strings in the form TAG:name. + */ +@property(nonatomic, copy) NSArray<NSString *> *inputStreams; + +/** + * List of task graph output stream info in the form TAG:name. + */ +@property(nonatomic, copy) NSArray<NSString *> *outputStreams; + +/** + * If the task requires a flow limiter. + */ +@property(nonatomic) BOOL enableFlowLimiting; + ++ (instancetype)new NS_UNAVAILABLE; + +- (instancetype)initWithTaskGraphName:(NSString *)taskGraphName + inputStreams:(NSArray<NSString *> *)inputStreams + outputStreams:(NSArray<NSString *> *)outputStreams + taskOptions:(id<MPPTaskOptionsProtocol>)taskOptions + enableFlowLimiting:(BOOL)enableFlowLimiting + error:(NSError **)error; + +/** + * Creates a MediaPipe Task protobuf message from the MPPTaskInfo instance. + */ +- (mediapipe::CalculatorGraphConfig)generateGraphConfig; + +- (instancetype)init NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm new file mode 100644 index 000000000..7e42d6eae --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm @@ -0,0 +1,135 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" +#import "mediapipe/tasks/ios/common/sources/MPPCommon.h" +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" + +#include "mediapipe/calculators/core/flow_limiter_calculator.pb.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_options.pb.h" + +namespace { +using CalculatorGraphConfig = ::mediapipe::CalculatorGraphConfig; +using Node = ::mediapipe::CalculatorGraphConfig::Node; +using InputStreamInfo = ::mediapipe::InputStreamInfo; +using CalculatorOptions = ::mediapipe::CalculatorOptions; +using FlowLimiterCalculatorOptions = ::mediapipe::FlowLimiterCalculatorOptions; +} // namespace + +@implementation MPPTaskInfo + +- (instancetype)initWithTaskGraphName:(NSString *)taskGraphName + inputStreams:(NSArray *)inputStreams + outputStreams:(NSArray *)outputStreams + taskOptions:(id)taskOptions + enableFlowLimiting:(BOOL)enableFlowLimiting + error:(NSError **)error { + self = [super init]; + if (!taskGraphName || !inputStreams.count || !outputStreams.count) { + [MPPCommonUtils + createCustomError:error + withCode:MPPTasksErrorCodeInvalidArgumentError + description: + @"Task graph's name, input streams, and output streams should be non-empty."]; + } + + if (self) { + _taskGraphName = taskGraphName; + _inputStreams = inputStreams; + _outputStreams = outputStreams; + _taskOptions = taskOptions; + 
_enableFlowLimiting = enableFlowLimiting; + } + return self; +} + +- (id)copyWithZone:(NSZone *)zone { + MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] init]; + + taskInfo.taskGraphName = self.taskGraphName; + taskInfo.inputStreams = self.inputStreams; + taskInfo.outputStreams = self.outputStreams; + taskInfo.taskOptions = self.taskOptions; + taskInfo.enableFlowLimiting = self.enableFlowLimiting; + + return taskInfo; +} + +- (CalculatorGraphConfig)generateGraphConfig { + CalculatorGraphConfig graph_config; + + Node *task_subgraph_node = graph_config.add_node(); + task_subgraph_node->set_calculator(self.taskGraphName.cppString); + [self.taskOptions copyToProto:task_subgraph_node->mutable_options()]; + + for (NSString *outputStream in self.outputStreams) { + auto cpp_output_stream = std::string(outputStream.cppString); + task_subgraph_node->add_output_stream(cpp_output_stream); + graph_config.add_output_stream(cpp_output_stream); + } + + if (self.enableFlowLimiting) { + Node *flow_limit_calculator_node = graph_config.add_node(); + + flow_limit_calculator_node->set_calculator("FlowLimiterCalculator"); + + InputStreamInfo *input_stream_info = flow_limit_calculator_node->add_input_stream_info(); + input_stream_info->set_tag_index("FINISHED"); + input_stream_info->set_back_edge(true); + + FlowLimiterCalculatorOptions *flow_limit_calculator_options = + flow_limit_calculator_node->mutable_options()->MutableExtension( + FlowLimiterCalculatorOptions::ext); + flow_limit_calculator_options->set_max_in_flight(1); + flow_limit_calculator_options->set_max_in_queue(1); + + for (NSString *inputStream in self.inputStreams) { + graph_config.add_input_stream(inputStream.cppString); + + NSString *strippedInputStream = [MPPTaskInfo stripTagIndex:inputStream]; + flow_limit_calculator_node->add_input_stream(strippedInputStream.cppString); + + NSString *taskInputStream = [MPPTaskInfo addStreamNamePrefix:inputStream]; + task_subgraph_node->add_input_stream(taskInputStream.cppString); + + 
NSString *strippedTaskInputStream = [MPPTaskInfo stripTagIndex:taskInputStream]; + flow_limit_calculator_node->add_output_stream(strippedTaskInputStream.cppString); + } + + NSString *firstOutputStream = self.outputStreams[0]; + auto finished_output_stream = "FINISHED:" + firstOutputStream.cppString; + flow_limit_calculator_node->add_input_stream(finished_output_stream); + } else { + for (NSString *inputStream in self.inputStreams) { + auto cpp_input_stream = inputStream.cppString; + task_subgraph_node->add_input_stream(cpp_input_stream); + graph_config.add_input_stream(cpp_input_stream); + } + } + + return graph_config; +} + ++ (NSString *)stripTagIndex:(NSString *)tagIndexName { + return [tagIndexName componentsSeparatedByString:@":"][1]; +} + ++ (NSString *)addStreamNamePrefix:(NSString *)tagIndexName { + NSArray *splits = [tagIndexName componentsSeparatedByString:@":"]; + return [NSString stringWithFormat:@"%@:throttled_%@", splits[0], splits[1]]; +} + +@end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h new file mode 100644 index 000000000..e40e92657 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h @@ -0,0 +1,58 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ==============================================================================*/ +#import <Foundation/Foundation.h> +#include "mediapipe/framework/calculator_options.pb.h" +#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend + * this class. + */ +NS_SWIFT_NAME(TaskOptions) +@interface MPPTaskOptions : NSObject +/** + * Base options for configuring the MediaPipe task. + */ +@property(nonatomic, copy) MPPBaseOptions *baseOptions; + +/** + * Initializes a new `MPPTaskOptions` with the absolute path to the model file + * stored locally on the device, set to the given model path. + * + * @discussion The external model file must be a single standalone TFLite file. It could be packed + * with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the + * necessary metadata and associated files might result in errors. Check the [documentation] + * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement. + * + * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. + * + * @return An instance of `MPPTaskOptions` initialized to the given model path. + */ +- (instancetype)initWithModelPath:(NSString *)modelPath; + +@end + +/** + * Any MediaPipe task options should conform to this protocol. + */ +@protocol MPPTaskOptionsProtocol + +/** + * Copies the iOS MediaPipe task options to an object of mediapipe::CalculatorOptions proto. + */ +- (void)copyToProto:(mediapipe::CalculatorOptions *)optionsProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m new file mode 100644 index 000000000..ec1adbaf1 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m @@ -0,0 +1,36 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
+ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" +#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h" + +@implementation MPPTaskOptions + +- (instancetype)init { + self = [super init]; + if (self) { + _baseOptions = [[MPPBaseOptions alloc] init]; + } + return self; +} + +- (instancetype)initWithModelPath:(NSString *)modelPath { + self = [self init]; + if (self) { + _baseOptions.modelAssetFile.filePath = modelPath; + } + return self; +} + +@end diff --git a/mediapipe/tasks/ios/core/utils/BUILD b/mediapipe/tasks/ios/core/utils/BUILD new file mode 100644 index 000000000..d9fbaf375 --- /dev/null +++ b/mediapipe/tasks/ios/core/utils/BUILD @@ -0,0 +1,27 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPBaseOptionsHelpers", + srcs = ["sources/MPPBaseOptions+Helpers.mm"], + hdrs = ["sources/MPPBaseOptions+Helpers.h"], + deps = [ + "//mediapipe/tasks/ios/core:MPPBaseOptions", + "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", + ], +) diff --git a/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h new file mode 100644 index 000000000..f57844a8f --- /dev/null +++ b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h @@ -0,0 +1,26 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPBaseOptions (Helpers) + +- (void)copyToProto:(mediapipe::tasks::core::proto::BaseOptions *)baseOptionsProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm new file mode 100644 index 000000000..f20f8602a --- /dev/null +++ b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm @@ -0,0 +1,40 @@ +/* Copyright 2022 The TensorFlow Authors. 
All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h" + +namespace { +using BaseOptionsProto = ::mediapipe::tasks::core::proto::BaseOptions; +} + +@implementation MPPBaseOptions (Helpers) + +- (void)copyToProto:(BaseOptionsProto *)baseOptionsProto { + if (self.modelAssetFile.filePath) { + baseOptionsProto->mutable_model_asset()->set_file_name(self.modelAssetFile.filePath.UTF8String); + } + + switch (self.delegate) { + case MPPDelegateCPU: { + baseOptionsProto->mutable_acceleration()->mutable_tflite(); + break; + } + case MPPDelegateGPU: + break; + default: + break; + } +} + +@end diff --git a/mediapipe/tasks/ios/text/core/BUILD b/mediapipe/tasks/ios/text/core/BUILD new file mode 100644 index 000000000..abb8edc71 --- /dev/null +++ b/mediapipe/tasks/ios/text/core/BUILD @@ -0,0 +1,33 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPBaseTextTaskApi", + srcs = ["sources/MPPBaseTextTaskApi.mm"], + hdrs = ["sources/MPPBaseTextTaskApi.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], + deps = [ + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + ], +) + diff --git a/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h b/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h new file mode 100644 index 000000000..3d2fd4f43 --- /dev/null +++ b/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h @@ -0,0 +1,44 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import + +#include "mediapipe/framework/calculator.pb.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * The base class of the user-facing iOS mediapipe text task api classes. + */ +NS_SWIFT_NAME(BaseTextTaskApi) +@interface MPPBaseTextTaskApi : NSObject + +/** + * Initializes a new `MPPBaseTextTaskApi` with the mediapipe text task graph config proto. + * + * @param graphConfig A mediapipe text task graph config proto. 
+ * + * @return An instance of `MPPBaseTextTaskApi` initialized to the given graph config proto. + */ +- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig + error:(NSError **)error; +- (void)close; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.mm b/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.mm new file mode 100644 index 000000000..9d7142fe5 --- /dev/null +++ b/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.mm @@ -0,0 +1,52 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ==============================================================================*/ +#import "mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h" + +#include "mediapipe/tasks/cc/core/task_runner.h" +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" + +namespace { +using ::mediapipe::CalculatorGraphConfig; +using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; +} // namespace + +@interface MPPBaseTextTaskApi () { + /** TextSearcher backed by C++ API */ + std::unique_ptr _taskRunner; +} +@end + +@implementation MPPBaseTextTaskApi + +- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig + error:(NSError **)error { + self = [super init]; + if (self) { + auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig)); + + if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) { + return nil; + } + + _taskRunner = std::move(taskRunnerResult.value()); + } + return self; +} + +- (void)close { + _taskRunner->Close(); +} + +@end diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD new file mode 100644 index 000000000..9ed25a852 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/BUILD @@ -0,0 +1,50 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPTextClassifier", + srcs = ["sources/MPPTextClassifier.mm"], + hdrs = ["sources/MPPTextClassifier.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], + deps = [ + "//mediapipe/tasks/ios/core:MPPTaskOptions", + "//mediapipe/tasks/ios/core:MPPTaskInfo", + "//mediapipe/tasks/ios/text/core:MPPBaseTextTaskApi", + "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers", + "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + ":MPPTextClassifierOptions", + ], +) + +objc_library( + name = "MPPTextClassifierOptions", + srcs = ["sources/MPPTextClassifierOptions.mm"], + hdrs = ["sources/MPPTextClassifierOptions.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], + deps = [ + "//mediapipe/tasks/ios/core:MPPTaskOptions", + "//mediapipe/tasks/ios/components/processors:MPPClassifierOptions", + ], +) + diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h new file mode 100644 index 000000000..d6b2f770f --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h @@ -0,0 +1,61 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ==============================================================================*/ +#import + +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" +#import "mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h" +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * A Mediapipe iOS Text Classifier. + */ +NS_SWIFT_NAME(TextClassifier) +@interface MPPTextClassifier : MPPBaseTextTaskApi + +/** + * Creates a new instance of `MPPTextClassifier` from an absolute path to a TensorFlow Lite model + * file stored locally on the device. + * + * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. + * + * @param error An optional error parameter populated when there is an error in initializing + * the text classifier. + * + * @return A new instance of `MPPTextClassifier` with the given model path. `nil` if there is an + * error in initializing the text classifier. + */ +- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error; + +/** + * Creates a new instance of `MPPTextClassifier` from the given text classifier options. + * + * @param options The options to use for configuring the `MPPTextClassifier`. + * @param error An optional error parameter populated when there is an error in initializing + * the text classifier. + * + * @return A new instance of `MPPTextClassifier` with the given options. `nil` if there is an error + * in initializing the text classifier. 
+ */ +- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm new file mode 100644 index 000000000..56bc4930f --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm @@ -0,0 +1,65 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ==============================================================================*/ +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" + +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" +#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h" + +NSString *kClassificationsStreamName = @"classifications_out"; +NSString *kClassificationsTag = @"classifications"; +NSString *kTextInStreamName = @"text_in"; +NSString *kTextTag = @"TEXT"; +NSString *kTaskGraphName = @"mediapipe.tasks.text.text_classifier.TextClassifierGraph"; + +@implementation MPPTextClassifierOptions + +- (instancetype)initWithModelPath:(NSString *)modelPath { + self = [super initWithModelPath:modelPath]; + if (self) { + _classifierOptions = [[MPPClassifierOptions alloc] init]; + } + return self; +} + +@end + +@implementation MPPTextClassifier + +- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error { + MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] + initWithTaskGraphName:kTaskGraphName + inputStreams:@[ [NSString stringWithFormat:@"@:@", kTextTag, kTextInStreamName] ] + outputStreams:@[ [NSString stringWithFormat:@"@:@", kClassificationsTag, + kClassificationsStreamName] ] + taskOptions:options + enableFlowLimiting:NO + error:error]; + + if (!taskInfo) { + return nil; + } + + return [super initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error]; +} + +- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error { + MPPTextClassifierOptions *options = + [[MPPTextClassifierOptions alloc] initWithModelPath:modelPath]; + + return [self initWithOptions:options error:error]; +} + +@end diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h new file mode 100644 index 
000000000..47c44dd0d --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h @@ -0,0 +1,51 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import + +#import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * Options to configure MPPTextClassifierOptions. + */ +NS_SWIFT_NAME(TextClassifierOptions) +@interface MPPTextClassifierOptions : MPPTaskOptions + +/** + * Options controlling the behavior of the embedding model specified in the + * base options. + */ +@property(nonatomic, copy) MPPClassifierOptions *classifierOptions; + +/** + * Initializes a new `MPPTextClassifierOptions` with the absolute path to the model file + * stored locally on the device, set to the given the model path. + * + * @discussion The external model file must be a single standalone TFLite file. It could be packed + * with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the + * necessary metadata and associated files might result in errors. Check the [documentation] + * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement. + * + * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. 
+ * + * @return An instance of `MPPTextClassifierOptions` initialized to the given model path. + */ +- (instancetype)initWithModelPath:(NSString *)modelPath; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.mm b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.mm new file mode 100644 index 000000000..8cab693cd --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.mm @@ -0,0 +1,27 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" + +@implementation MPPTextClassifierOptions + +- (instancetype)initWithModelPath:(NSString *)modelPath { + self = [super initWithModelPath:modelPath]; + if (self) { + _classifierOptions = [[MPPClassifierOptions alloc] init]; + } + return self; +} + +@end \ No newline at end of file diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/BUILD b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD new file mode 100644 index 000000000..453a30f54 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD @@ -0,0 +1,30 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPTextClassifierOptionsHelpers", + srcs = ["sources/MPPTextClassifierOptions+Helpers.mm"], + hdrs = ["sources/MPPTextClassifierOptions+Helpers.h"], + deps = [ + "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierOptions", + "//mediapipe/tasks/ios/core/utils:MPPBaseOptionsHelpers", + "//mediapipe/tasks/ios/components/processors/utils:MPPClassifierOptionsHelpers", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", + ], +) diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h new file mode 100644 index 000000000..8df471d05 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h @@ -0,0 +1,26 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#include "mediapipe/framework/calculator_options.pb.h" +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPTextClassifierOptions (Helpers) + +- (void)copyToProto:(mediapipe::CalculatorOptions *)optionsProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm new file mode 100644 index 000000000..3576cb8d2 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm @@ -0,0 +1,36 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ==============================================================================*/ +#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h" +#import "mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h" +#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h" +#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h" + +namespace { +using CalculatorOptionsProto = ::mediapipe::CalculatorOptions; +using TextClassifierGraphOptionsProto = + ::mediapipe::tasks::text::text_classifier::proto::TextClassifierGraphOptions; + +} // namespace + +@implementation MPPTextClassifierOptions (Helpers) + +- (void)copyToProto:(CalculatorOptionsProto *)optionsProto { + TextClassifierGraphOptionsProto *graph_options = + optionsProto->MutableExtension(TextClassifierGraphOptionsProto::ext); + [self.baseOptions copyToProto:graph_options->mutable_base_options()]; + [self.classifierOptions copyToProto:graph_options->mutable_classifier_options()]; +} + +@end From 1112e558a518b44f9df8a734cd5bef565dc90cbc Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 1 Dec 2022 03:23:58 +0530 Subject: [PATCH 02/18] Added constants for text classifier --- .../text/text_classifier/sources/MPPTextClassifier.mm | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm index 56bc4930f..07277874e 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm @@ -18,11 +18,11 @@ #import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" #import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h" -NSString *kClassificationsStreamName = @"classifications_out"; -NSString 
*kClassificationsTag = @"classifications"; -NSString *kTextInStreamName = @"text_in"; -NSString *kTextTag = @"TEXT"; -NSString *kTaskGraphName = @"mediapipe.tasks.text.text_classifier.TextClassifierGraph"; +static NSString *const kClassificationsStreamName = @"classifications_out"; +static NSString *const kClassificationsTag = @"classifications"; +static NSString *const kTextInStreamName = @"text_in"; +static NSString *const kTextTag = @"TEXT"; +static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.TextClassifierGraph"; @implementation MPPTextClassifierOptions From bd1fb717d38f5cc8385bf5772058776f601c9123 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 1 Dec 2022 04:12:14 +0530 Subject: [PATCH 03/18] Removed duplicate implementation --- .../text_classifier/sources/MPPTextClassifier.mm | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm index 07277874e..08557d94e 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm @@ -24,18 +24,6 @@ static NSString *const kTextInStreamName = @"text_in"; static NSString *const kTextTag = @"TEXT"; static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.TextClassifierGraph"; -@implementation MPPTextClassifierOptions - -- (instancetype)initWithModelPath:(NSString *)modelPath { - self = [super initWithModelPath:modelPath]; - if (self) { - _classifierOptions = [[MPPClassifierOptions alloc] init]; - } - return self; -} - -@end - @implementation MPPTextClassifier - (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error { From 066ffd36d5d2c635fbabd7e2affe68f1049def09 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 1 Dec 2022 04:12:26 +0530 Subject: [PATCH 04/18] Fixed errors in 
common utils --- .../common/utils/sources/MPPCommonUtils.mm | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm index 5d8cc6887..6b277ae0c 100644 --- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm @@ -16,15 +16,15 @@ #import "mediapipe/tasks/ios/common/sources/MPPCommon.h" /** Error domain of TensorFlow Lite Support related errors. */ -NSString *const TFLSupportTaskErrorDomain = @"org.tensorflow.lite.tasks"; +NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks"; -@implementation TFLCommonUtils +@implementation MPPCommonUtils + (void)createCustomError:(NSError **)error withCode:(NSUInteger)code description:(NSString *)description { - [TFLCommonUtils createCustomError:error - withDomain:TFLSupportTaskErrorDomain + [MPPCommonUtils createCustomError:error + withDomain:MPPTasksErrorDomain code:code description:description]; } @@ -42,7 +42,7 @@ NSString *const TFLSupportTaskErrorDomain = @"org.tensorflow.lite.tasks"; + (void *)mallocWithSize:(size_t)memSize error:(NSError **)error { if (!memSize) { - [TFLCommonUtils createCustomError:error + [MPPCommonUtils createCustomError:error withCode:MPPTasksErrorCodeInvalidArgumentError description:@"memSize cannot be zero."]; return NULL; @@ -60,12 +60,12 @@ NSString *const TFLSupportTaskErrorDomain = @"org.tensorflow.lite.tasks"; if (status.ok()) { return YES; } - // Payload of absl::Status created by the tflite task library stores an appropriate value of the - // enum TfLiteSupportStatus. The integer value corresponding to the TfLiteSupportStatus enum + // Payload of absl::Status created by the Media Pipe task library stores an appropriate value of the + // enum MPPiteSupportStatus. 
The integer value corresponding to the MPPiteSupportStatus enum // stored in the payload is extracted here to later map to the appropriate error code to be // returned. In cases where the enum is not stored in (payload is NULL or the payload string // cannot be converted to an integer), we set the error code value to be 1 - // (TFLSupportErrorCodeUnspecifiedError of TFLSupportErrorCode used in the iOS library to signify + // (MPPSupportErrorCodeUnspecifiedError of MPPSupportErrorCode used in the iOS library to signify // any errors not falling into other categories.) Since payload is of type absl::Cord that can be // type cast into an absl::optional, we use the std::stoi function to convert it into // an integer code if possible. @@ -73,7 +73,7 @@ NSString *const TFLSupportTaskErrorDomain = @"org.tensorflow.lite.tasks"; NSUInteger errorCode; try { // Try converting payload to integer if payload is not empty. Otherwise convert a string - // signifying generic error code TFLSupportErrorCodeUnspecifiedError to integer. + // signifying generic error code MPPSupportErrorCodeUnspecifiedError to integer. 
errorCode = (NSUInteger)std::stoi(static_cast>( status.GetPayload(mediapipe::tasks::kMediaPipeTasksPayload)) @@ -122,7 +122,7 @@ NSString *const TFLSupportTaskErrorDomain = @"org.tensorflow.lite.tasks"; NSString *description = [NSString stringWithCString:status.ToString(absl::StatusToStringMode::kWithNoExtraData).c_str() encoding:NSUTF8StringEncoding]; - [TFLCommonUtils createCustomError:error withCode:errorCode description:description]; + [MPPCommonUtils createCustomError:error withCode:errorCode description:description]; return NO; } From 683c2b1f0910543bc6f4d1f83266759487d9cef8 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 1 Dec 2022 04:13:19 +0530 Subject: [PATCH 05/18] Renamed .mm file --- mediapipe/tasks/ios/text/text_classifier/BUILD | 2 +- .../{MPPTextClassifierOptions.mm => MPPTextClassifierOptions.m} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename mediapipe/tasks/ios/text/text_classifier/sources/{MPPTextClassifierOptions.mm => MPPTextClassifierOptions.m} (100%) diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD index 9ed25a852..e1f6eaab8 100644 --- a/mediapipe/tasks/ios/text/text_classifier/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/BUILD @@ -36,7 +36,7 @@ objc_library( objc_library( name = "MPPTextClassifierOptions", - srcs = ["sources/MPPTextClassifierOptions.mm"], + srcs = ["sources/MPPTextClassifierOptions.m"], hdrs = ["sources/MPPTextClassifierOptions.h"], copts = [ "-ObjC++", diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.mm b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m similarity index 100% rename from mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.mm rename to mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m From 8d9c1b8a0f12a29c04b3d0001a3999a340f28327 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 1 Dec 2022 09:13:05 
+0530 Subject: [PATCH 06/18] Added APIs for text classification --- mediapipe/tasks/ios/common/utils/BUILD | 3 + .../common/utils/sources/MPPCommonUtils.mm | 12 +++- .../common/utils/sources/NSString+Helpers.h | 2 + .../common/utils/sources/NSString+Helpers.mm | 4 ++ .../tasks/ios/components/containers/BUILD | 1 + .../containers/sources/MPPCategory.h | 2 +- .../containers/sources/MPPCategory.m | 4 +- .../sources/MPPClassificationResult.h | 8 ++- .../sources/MPPClassificationResult.m | 5 +- .../ios/components/containers/utils/BUILD | 40 +++++++++++ .../utils/sources/MPPCategory+Helpers.h | 26 ++++++++ .../utils/sources/MPPCategory+Helpers.mm | 42 ++++++++++++ .../sources/MPPClassificationResult+Helpers.h | 35 ++++++++++ .../MPPClassificationResult+Helpers.mm | 66 +++++++++++++++++++ mediapipe/tasks/ios/core/BUILD | 20 ++++++ .../tasks/ios/core/sources/MPPPacketCreator.h | 29 ++++++++ .../ios/core/sources/MPPPacketCreator.mm | 29 ++++++++ .../tasks/ios/core/sources/MPPTaskResult.h | 31 +++++++++ .../tasks/ios/core/sources/MPPTaskResult.m | 27 ++++++++ .../text/core/sources/MPPBaseTextTaskApi.h | 6 +- .../text/core/sources/MPPBaseTextTaskApi.mm | 10 +-- mediapipe/tasks/ios/text/core/utils/BUILD | 33 ++++++++++ .../tasks/ios/text/text_classifier/BUILD | 3 + .../sources/MPPTextClassifier.h | 3 + .../sources/MPPTextClassifier.mm | 25 +++++++ 25 files changed, 450 insertions(+), 16 deletions(-) create mode 100644 mediapipe/tasks/ios/components/containers/utils/BUILD create mode 100644 mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h create mode 100644 mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm create mode 100644 mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h create mode 100644 mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm create mode 100644 mediapipe/tasks/ios/core/sources/MPPPacketCreator.h create mode 100644 
mediapipe/tasks/ios/core/sources/MPPPacketCreator.mm create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskResult.h create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskResult.m create mode 100644 mediapipe/tasks/ios/text/core/utils/BUILD diff --git a/mediapipe/tasks/ios/common/utils/BUILD b/mediapipe/tasks/ios/common/utils/BUILD index 9c8c75586..f2ffda39e 100644 --- a/mediapipe/tasks/ios/common/utils/BUILD +++ b/mediapipe/tasks/ios/common/utils/BUILD @@ -23,6 +23,9 @@ objc_library( deps = [ "//mediapipe/tasks/cc:common", "//mediapipe/tasks/ios/common:MPPCommon", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", ], ) diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm index 6b277ae0c..141270b6d 100644 --- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm @@ -13,8 +13,16 @@ // limitations under the License. #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" + #import "mediapipe/tasks/ios/common/sources/MPPCommon.h" +#include + +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/cord.h" // from @com_google_absl + +#include "mediapipe/tasks/cc/common.h" + /** Error domain of TensorFlow Lite Support related errors. */ NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks"; @@ -60,8 +68,8 @@ NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks"; if (status.ok()) { return YES; } - // Payload of absl::Status created by the Media Pipe task library stores an appropriate value of the - // enum MPPiteSupportStatus. The integer value corresponding to the MPPiteSupportStatus enum + // Payload of absl::Status created by the Media Pipe task library stores an appropriate value of + // the enum MPPiteSupportStatus. 
The integer value corresponding to the MPPiteSupportStatus enum // stored in the payload is extracted here to later map to the appropriate error code to be // returned. In cases where the enum is not stored in (payload is NULL or the payload string // cannot be converted to an integer), we set the error code value to be 1 diff --git a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h index 31828a367..4c1e3af01 100644 --- a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h +++ b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h @@ -22,6 +22,8 @@ NS_ASSUME_NONNULL_BEGIN @property(readonly) std::string cppString; ++ (NSString *)stringWithCppString:(std::string)text; + @end NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm index 540903c6c..0720223dc 100644 --- a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm +++ b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm @@ -20,4 +20,8 @@ return std::string(self.UTF8String, [self lengthOfBytesUsingEncoding:NSUTF8StringEncoding]); } ++ (NSString *)stringWithCppString:(std::string)text { + return [NSString stringWithCString:text.c_str() encoding:[NSString defaultCStringEncoding]]; +} + @end diff --git a/mediapipe/tasks/ios/components/containers/BUILD b/mediapipe/tasks/ios/components/containers/BUILD index ce80571e9..5d6bae220 100644 --- a/mediapipe/tasks/ios/components/containers/BUILD +++ b/mediapipe/tasks/ios/components/containers/BUILD @@ -28,5 +28,6 @@ objc_library( hdrs = ["sources/MPPClassificationResult.h"], deps = [ ":MPPCategory", + "//mediapipe/tasks/ios/core:MPPTaskResult", ], ) diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h index e2f8a0729..b74a5edb5 100644 --- 
a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h +++ b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h @@ -18,7 +18,7 @@ NS_ASSUME_NONNULL_BEGIN /** Encapsulates information about a class in the classification results. */ NS_SWIFT_NAME(ClassificationCategory) -@interface TFLCategory : NSObject +@interface MPPCategory : NSObject /** Index of the class in the corresponding label map, usually packed in the TFLite Model * Metadata. */ diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m index 30efa239a..a14c38b2c 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m +++ b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m @@ -12,9 +12,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#import "mediapipe/tasks/ios/components/containers/sources/TFLCategory.h" +#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h" -@implementation TFLCategory +@implementation MPPCategory - (instancetype)initWithIndex:(NSInteger)index score:(float)score diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h index 120780a7b..986a8cc1f 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #import -#import "mediapipe/tasks/ios/task/components/containers/sources/MPPCategory.h" +#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h" NS_ASSUME_NONNULL_BEGIN @@ -71,7 +72,7 @@ NS_SWIFT_NAME(Classifications) /** Encapsulates results of any classification task. */ NS_SWIFT_NAME(ClassificationResult) -@interface MPPClassificationResult : NSObject +@interface MPPClassificationResult : MPPTaskResult /** Array of MPPClassifications objects containing classifier predictions per image classifier * head. @@ -87,7 +88,8 @@ NS_SWIFT_NAME(ClassificationResult) * @return An instance of MPPClassificationResult initialized with the given array of * classifications. */ -- (instancetype)initWithClassifications:(NSArray *)classifications; +- (instancetype)initWithClassifications:(NSArray *)classifications + timeStamp:(long)timeStamp; @end diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m index 266b401e0..df3e6a52d 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m @@ -39,8 +39,9 @@ limitations under the License. 
NSArray *_classifications; } -- (instancetype)initWithClassifications:(NSArray *)classifications { - self = [super init]; +- (instancetype)initWithClassifications:(NSArray *)classifications + timeStamp:(long)timeStamp { + self = [super initWithTimeStamp:timeStamp]; if (self) { _classifications = classifications; } diff --git a/mediapipe/tasks/ios/components/containers/utils/BUILD b/mediapipe/tasks/ios/components/containers/utils/BUILD new file mode 100644 index 000000000..a61dd6ca0 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/BUILD @@ -0,0 +1,40 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPCategoryHelpers", + srcs = ["sources/MPPCategory+Helpers.mm"], + hdrs = ["sources/MPPCategory+Helpers.h"], + deps = [ + "//mediapipe/tasks/ios/components/containers:MPPCategory", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + ], +) + +objc_library( + name = "MPPClassificationResultHelpers", + srcs = ["sources/MPPClassificationResult+Helpers.mm"], + hdrs = ["sources/MPPClassificationResult+Helpers.h"], + deps = [ + "//mediapipe/tasks/ios/components/containers:MPPClassificationResult", + ":MPPCategoryHelpers", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + ], +) diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h new file mode 100644 index 000000000..874c751ac --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h @@ -0,0 +1,26 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ==============================================================================*/ +#include "mediapipe/framework/formats/classification.pb.h" +#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPCategory (Helpers) + ++ (MPPCategory *)categoryWithProto:(const mediapipe::Classification &)classificationProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm new file mode 100644 index 000000000..24d250795 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm @@ -0,0 +1,42 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ==============================================================================*/ +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h" + +namespace { +using ClassificationProto = ::mediapipe::Classification; +} + +@implementation MPPCategory (Helpers) + ++ (MPPCategory *)categoryWithProto:(const ClassificationProto &)clasificationProto { + NSString *label; + NSString *displayName; + + if (clasificationProto.has_label()) { + label = [NSString stringWithCppString:clasificationProto.label()]; + } + + if (clasificationProto.has_display_name()) { + displayName = [NSString stringWithCppString:clasificationProto.display_name()]; + } + + return [[MPPCategory alloc] initWithIndex:clasificationProto.index() + score:clasificationProto.score() + label:label + displayName:displayName]; +} + +@end diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h new file mode 100644 index 000000000..5b19447ac --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h @@ -0,0 +1,35 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ==============================================================================*/ +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPClassifications (Helpers) + ++ (MPPClassifications *)classificationsWithProto: + (const mediapipe::tasks::components::containers::proto::Classifications &)classificationsProto; + +@end + +@interface MPPClassificationResult (Helpers) + ++ (MPPClassificationResult *)classificationResultWithProto: + (const mediapipe::tasks::components::containers::proto::ClassificationResult &) + classificationResultProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm new file mode 100644 index 000000000..0e9e599d7 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm @@ -0,0 +1,66 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ==============================================================================*/ +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h" +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h" + +namespace { +using ClassificationsProto = ::mediapipe::tasks::components::containers::proto::Classifications; +using ClassificationResultProto = + ::mediapipe::tasks::components::containers::proto::ClassificationResult; +} // namespace + +@implementation MPPClassifications (Helpers) + ++ (MPPClassifications *)classificationsWithProto: + (const ClassificationsProto &)classificationsProto { + NSMutableArray *categories = [[NSMutableArray alloc] init]; + for (const auto &classification : classificationsProto.classification_list().classification()) { + [categories addObject:[MPPCategory categoryWithProto:classification]]; + } + + NSString *headName; + + if (classificationsProto.has_head_name()) { + headName = [NSString stringWithCppString:classificationsProto.head_name()]; + } + + return [[MPPClassifications alloc] initWithHeadIndex:(NSInteger)classificationsProto.head_index() + headName:headName + categories:categories]; +} + +@end + +@implementation MPPClassificationResult (Helpers) + ++ (MPPClassificationResult *)classificationResultWithProto: + (const ClassificationResultProto &)classificationResultProto { + NSMutableArray *classifications = [[NSMutableArray alloc] init]; + for (const auto &classifications_proto : classificationResultProto.classifications()) { + [classifications addObject:[MPPClassifications classificationsWithProto:classifications_proto]]; + } + + long timeStamp; + + if (classificationResultProto.has_timestamp_ms()) { + timeStamp = classificationResultProto.timestamp_ms(); + } + + return [[MPPClassificationResult alloc] initWithClassifications:classifications + timeStamp:timeStamp]; +} + +@end diff --git 
a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD index 91257b4de..c6def1685 100644 --- a/mediapipe/tasks/ios/core/BUILD +++ b/mediapipe/tasks/ios/core/BUILD @@ -65,3 +65,23 @@ objc_library( ], ) +objc_library( + name = "MPPPacketCreator", + srcs = ["sources/MPPPacketCreator.mm"], + hdrs = ["sources/MPPPacketCreator.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], + deps = [ + "//mediapipe/framework:packet", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + ], +) + +objc_library( + name = "MPPTaskResult", + srcs = ["sources/MPPTaskResult.m"], + hdrs = ["sources/MPPTaskResult.h"], +) + diff --git a/mediapipe/tasks/ios/core/sources/MPPPacketCreator.h b/mediapipe/tasks/ios/core/sources/MPPPacketCreator.h new file mode 100644 index 000000000..ecd0c5bfd --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPPacketCreator.h @@ -0,0 +1,29 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#ifndef __cplusplus +#error This header can only be included by an Objective-C++ file. +#endif + +#include "mediapipe/framework/packet.h" + +/// This class is an Objective-C wrapper around a MediaPipe graph object, and +/// helps interface it with iOS technologies such as AVFoundation. 
+@interface MPPPacketCreator : NSObject + ++ (mediapipe::Packet)createWithText:(NSString *)text; + +@end diff --git a/mediapipe/tasks/ios/core/sources/MPPPacketCreator.mm b/mediapipe/tasks/ios/core/sources/MPPPacketCreator.mm new file mode 100644 index 000000000..6ce5a5139 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPPacketCreator.mm @@ -0,0 +1,29 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h" +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" + +namespace { +using ::mediapipe::MakePacket; +using ::mediapipe::Packet; +} // namespace + +@implementation MPPPacketCreator + ++ (Packet)createWithText:(NSString *)text { + return MakePacket(text.cppString); +} + +@end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h new file mode 100644 index 000000000..e4845c26d --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h @@ -0,0 +1,31 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import + +NS_ASSUME_NONNULL_BEGIN + +/** + * MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend + * this class. + */ +NS_SWIFT_NAME(TaskResult) +@interface MPPTaskResult : NSObject +/** + * Base options for configuring the Mediapipe task. + */ +@property(nonatomic, assign, readonly) long timeStamp; + +- (instancetype)initWithTimeStamp:(long)timeStamp; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.m b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m new file mode 100644 index 000000000..6a79ea7a9 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m @@ -0,0 +1,27 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ==============================================================================*/ +#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h" + +@implementation MPPTaskResult + +- (instancetype)initWithTimeStamp:(long)timeStamp { + self = [self init]; + if (self) { + _timeStamp = timeStamp; + } + return self; +} + +@end diff --git a/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h b/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h index 3d2fd4f43..405d25a81 100644 --- a/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h +++ b/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h @@ -15,6 +15,7 @@ #import #include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" NS_ASSUME_NONNULL_BEGIN @@ -22,7 +23,10 @@ NS_ASSUME_NONNULL_BEGIN * The base class of the user-facing iOS mediapipe text task api classes. */ NS_SWIFT_NAME(BaseTextTaskApi) -@interface MPPBaseTextTaskApi : NSObject +@interface MPPBaseTextTaskApi : NSObject { + @protected + std::unique_ptr cppTaskRunner; +} /** * Initializes a new `MPPBaseTextTaskApi` with the mediapipe text task graph config proto. diff --git a/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.mm b/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.mm index 9d7142fe5..5c05797da 100644 --- a/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.mm +++ b/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.mm @@ -13,18 +13,18 @@ limitations under the License. 
==============================================================================*/ #import "mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h" - -#include "mediapipe/tasks/cc/core/task_runner.h" #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" namespace { using ::mediapipe::CalculatorGraphConfig; +using ::mediapipe::Packet; +using ::mediapipe::tasks::core::PacketMap; using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; } // namespace @interface MPPBaseTextTaskApi () { /** TextSearcher backed by C++ API */ - std::unique_ptr _taskRunner; + std::unique_ptr _cppTaskRunner; } @end @@ -40,13 +40,13 @@ using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; return nil; } - _taskRunner = std::move(taskRunnerResult.value()); + _cppTaskRunner = std::move(taskRunnerResult.value()); } return self; } - (void)close { - _taskRunner->Close(); + _cppTaskRunner->Close(); } @end diff --git a/mediapipe/tasks/ios/text/core/utils/BUILD b/mediapipe/tasks/ios/text/core/utils/BUILD new file mode 100644 index 000000000..abb8edc71 --- /dev/null +++ b/mediapipe/tasks/ios/text/core/utils/BUILD @@ -0,0 +1,33 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPBaseTextTaskApi", + srcs = ["sources/MPPBaseTextTaskApi.mm"], + hdrs = ["sources/MPPBaseTextTaskApi.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], + deps = [ + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + ], +) + diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD index e1f6eaab8..eb0800fcd 100644 --- a/mediapipe/tasks/ios/text/text_classifier/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/BUILD @@ -28,8 +28,11 @@ objc_library( "//mediapipe/tasks/ios/core:MPPTaskOptions", "//mediapipe/tasks/ios/core:MPPTaskInfo", "//mediapipe/tasks/ios/text/core:MPPBaseTextTaskApi", + "//mediapipe/tasks/ios/core:MPPPacketCreator", "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers", + "//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", ":MPPTextClassifierOptions", ], ) diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h index d6b2f770f..96d5887ff 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h @@ -14,6 +14,7 @@ ==============================================================================*/ #import +#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" #import "mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h" #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" @@ -52,6 +53,8 @@ 
NS_SWIFT_NAME(TextClassifier) */ - (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error; +- (MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error; + - (instancetype)init NS_UNAVAILABLE; + (instancetype)new NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm index 08557d94e..e61e6998d 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm @@ -15,9 +15,20 @@ #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h" +#import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" #import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h" +#include "absl/status/statusor.h" + +namespace { +using ::mediapipe::Packet; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::mediapipe::tasks::core::PacketMap; +} // namespace + static NSString *const kClassificationsStreamName = @"classifications_out"; static NSString *const kClassificationsTag = @"classifications"; static NSString *const kTextInStreamName = @"text_in"; @@ -50,4 +61,18 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T return [self initWithOptions:options error:error]; } +- (MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error { + Packet packet = [MPPPacketCreator createWithText:text]; + absl::StatusOr output_packet_map = + cppTaskRunner->Process({{kTextInStreamName.cppString, 
packet}}); + + if (![MPPCommonUtils checkCppError:output_packet_map.status() toError:error]) { + return nil; + } + + return [MPPClassificationResult + classificationResultWithProto:output_packet_map.value()[kClassificationsStreamName.cppString] + .Get()]; +} + @end From eaf6edc3a6c95fe270790bb59e308d53887aa1a8 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 1 Dec 2022 09:13:45 +0530 Subject: [PATCH 07/18] Added Objc tests for iOS Text Classifier --- .../tasks/ios/test/text/text_classifier/BUILD | 36 ++++++++++++ .../text_classifier/MPPTextClassifierTests.m | 57 +++++++++++++++++++ 2 files changed, 93 insertions(+) create mode 100644 mediapipe/tasks/ios/test/text/text_classifier/BUILD create mode 100644 mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m diff --git a/mediapipe/tasks/ios/test/text/text_classifier/BUILD b/mediapipe/tasks/ios/test/text/text_classifier/BUILD new file mode 100644 index 000000000..cbcc3d106 --- /dev/null +++ b/mediapipe/tasks/ios/test/text/text_classifier/BUILD @@ -0,0 +1,36 @@ +load("@org_tensorflow//tensorflow/lite/ios:ios.bzl", "TFL_DEFAULT_TAGS", "TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION") +load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test") +load("@org_tensorflow//tensorflow/lite:special_rules.bzl", "tflite_ios_lab_runner") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPTextClassifierObjcTestLibrary", + testonly = 1, + srcs = ["MPPTextClassifierTests.m"], + data = [ + "//mediapipe/tasks/testdata/text:bert_text_classifier_models", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + tags = TFL_DEFAULT_TAGS, + copts = [ + "-ObjC++", + "-std=c++17", + ], + deps = [ + "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier", + ], + +) + +ios_unit_test( + name = "MPPTextClassifierObjcTest", + minimum_os_version = TFL_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags = 
TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS, + deps = [ + ":MPPTextClassifierObjcTestLibrary", + ], +) diff --git a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m new file mode 100644 index 000000000..3808009f3 --- /dev/null +++ b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m @@ -0,0 +1,57 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ==============================================================================*/ +#import + +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" + +NS_ASSUME_NONNULL_BEGIN + +static NSString *const kBertTextClassifierModelName = @"bert_text_classifier"; + +@interface MPPTextClassifierTests : XCTestCase +@end + +@implementation MPPTextClassifierTests + +- (void)setUp { + [super setUp]; + +} + +- (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension { + NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName + ofType:extension]; + XCTAssertNotNil(filePath); + + return filePath; +} + +- (MPPTextClassifierOptions *)textClassifierOptionsWithModelName:(NSString *)modelName { + NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"]; + MPPTextClassifierOptions *textClassifierOptions = + [[MPPTextClassifierOptions alloc] initWithModelPath:modelPath]; + + return textClassifierOptions; +} + +- (void)testCreateTextClassifierOptionsSucceeds { + MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil]; + XCTAssertNotNil(textClassifier); +} + +@end + +NS_ASSUME_NONNULL_END From 03105c2a62c8ce7717eaaf8da16a8668e5a34f00 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Mon, 5 Dec 2022 17:02:20 +0530 Subject: [PATCH 08/18] Added tensorflow constants --- .../tasks/ios/test/text/text_classifier/BUILD | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/mediapipe/tasks/ios/test/text/text_classifier/BUILD b/mediapipe/tasks/ios/test/text/text_classifier/BUILD index cbcc3d106..81479e68b 100644 --- a/mediapipe/tasks/ios/test/text/text_classifier/BUILD +++ b/mediapipe/tasks/ios/test/text/text_classifier/BUILD @@ -1,4 +1,3 @@ -load("@org_tensorflow//tensorflow/lite/ios:ios.bzl", "TFL_DEFAULT_TAGS", 
"TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION") load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test") load("@org_tensorflow//tensorflow/lite:special_rules.bzl", "tflite_ios_lab_runner") @@ -6,6 +5,27 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +TFL_MINIMUM_OS_VERSION = "11.0" +# LINT.ThenChange( +# TensorFlowLiteC.podspec.template, +# TensorFlowLiteSelectTfOps.podspec.template, +# ../objc/TensorFlowLiteObjC.podspec.template, +# ../swift/TensorFlowLiteSwift.podspec.template +# ) + +# Default tags for filtering iOS targets. Targets are restricted to Apple platforms. +TFL_DEFAULT_TAGS = [ + "apple", +] + +# Following sanitizer tests are not supported by iOS test targets. +TFL_DISABLED_SANITIZER_TAGS = [ + "noasan", + "nomsan", + "notsan", +] + + objc_library( name = "MPPTextClassifierObjcTestLibrary", testonly = 1, From 37d81082fe28e356e3c8cd1a7f2fa62756ac7976 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Mon, 5 Dec 2022 18:24:05 +0530 Subject: [PATCH 09/18] Created new ios.bzl with TFL constants --- mediapipe/tasks/ios/ios.bzl | 15 +++++++ .../tasks/ios/test/text/text_classifier/BUILD | 43 ++++++++----------- 2 files changed, 32 insertions(+), 26 deletions(-) create mode 100644 mediapipe/tasks/ios/ios.bzl diff --git a/mediapipe/tasks/ios/ios.bzl b/mediapipe/tasks/ios/ios.bzl new file mode 100644 index 000000000..ad0f865ad --- /dev/null +++ b/mediapipe/tasks/ios/ios.bzl @@ -0,0 +1,15 @@ +"""Mediapipe Task Library Helper Rules for iOS""" + +MPP_TASK_MINIMUM_OS_VERSION = "11.0" + +# Default tags for filtering iOS targets. Targets are restricted to Apple platforms. +MPP_TASK_DEFAULT_TAGS = [ + "apple", +] + +# Following sanitizer tests are not supported by iOS test targets. 
+MPP_TASK_DISABLED_SANITIZER_TAGS = [ + "noasan", + "nomsan", + "notsan", +] diff --git a/mediapipe/tasks/ios/test/text/text_classifier/BUILD b/mediapipe/tasks/ios/test/text/text_classifier/BUILD index 81479e68b..2202ff1a6 100644 --- a/mediapipe/tasks/ios/test/text/text_classifier/BUILD +++ b/mediapipe/tasks/ios/test/text/text_classifier/BUILD @@ -1,31 +1,22 @@ -load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test") -load("@org_tensorflow//tensorflow/lite:special_rules.bzl", "tflite_ios_lab_runner") +load( + "//mediapipe/tasks:ios/ios.bzl", + "MPP_TASK_MINIMUM_OS_VERSION", + "MPP_TASK_DEFAULT_TAGS", + "MPP_TASK_DISABLED_SANITIZER_TAGS", +) +load( + "@build_bazel_rules_apple//apple:ios.bzl", + "ios_unit_test", +) +load( + "@org_tensorflow//tensorflow/lite:special_rules.bzl", + "tflite_ios_lab_runner" +) package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -TFL_MINIMUM_OS_VERSION = "11.0" -# LINT.ThenChange( -# TensorFlowLiteC.podspec.template, -# TensorFlowLiteSelectTfOps.podspec.template, -# ../objc/TensorFlowLiteObjC.podspec.template, -# ../swift/TensorFlowLiteSwift.podspec.template -# ) - -# Default tags for filtering iOS targets. Targets are restricted to Apple platforms. -TFL_DEFAULT_TAGS = [ - "apple", -] - -# Following sanitizer tests are not supported by iOS test targets. 
-TFL_DISABLED_SANITIZER_TAGS = [ - "noasan", - "nomsan", - "notsan", -] - - objc_library( name = "MPPTextClassifierObjcTestLibrary", testonly = 1, @@ -34,7 +25,7 @@ objc_library( "//mediapipe/tasks/testdata/text:bert_text_classifier_models", "//mediapipe/tasks/testdata/text:text_classifier_models", ], - tags = TFL_DEFAULT_TAGS, + tags = MPP_TASK_DEFAULT_TAGS, copts = [ "-ObjC++", "-std=c++17", @@ -47,9 +38,9 @@ objc_library( ios_unit_test( name = "MPPTextClassifierObjcTest", - minimum_os_version = TFL_MINIMUM_OS_VERSION, + minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION, runner = tflite_ios_lab_runner("IOS_LATEST"), - tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS, + tags = MPP_TASK_DEFAULT_TAGS + MPP_TASK_DISABLED_SANITIZER_TAGS, deps = [ ":MPPTextClassifierObjcTestLibrary", ], From dea0e21aeca8095cf8915f71b188976c48b84126 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 14 Dec 2022 11:28:21 +0530 Subject: [PATCH 10/18] Separated task options protocol --- mediapipe/tasks/ios/core/BUILD | 9 +++++- .../tasks/ios/core/sources/MPPTaskOptions.h | 13 --------- .../ios/core/sources/MPPTaskOptionsProtocol.h | 29 +++++++++++++++++++ .../ios/text/text_classifier/utils/BUILD | 2 +- .../MPPTextClassifierOptions+Helpers.h | 2 +- 5 files changed, 39 insertions(+), 16 deletions(-) create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD index e8ce47818..73fcacc37 100644 --- a/mediapipe/tasks/ios/core/BUILD +++ b/mediapipe/tasks/ios/core/BUILD @@ -25,7 +25,6 @@ objc_library( "-std=c++17", ], deps = [ - "//mediapipe/framework:calculator_options_cc_proto", ":MPPBaseOptions", ], ) @@ -74,3 +73,11 @@ objc_library( srcs = ["sources/MPPBaseOptions.m"], hdrs = ["sources/MPPBaseOptions.h"], ) + +objc_library( + name = "MPPTaskOptionsProtocol", + hdrs = ["sources/MPPTaskOptionsProtocol.h"], + deps = [ + "//mediapipe/framework:calculator_options_cc_proto", + ], +) diff --git 
a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h index e40e92657..fa11cd38e 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h @@ -10,7 +10,6 @@ limitations under the License. ==============================================================================*/ #import -#include "mediapipe/framework/calculator_options.pb.h" #import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h" NS_ASSUME_NONNULL_BEGIN @@ -43,16 +42,4 @@ NS_SWIFT_NAME(TaskOptions) @end -/** - * Any mediapipe task options should confirm to this protocol. - */ -@protocol MPPTaskOptionsProtocol - -/** - * Copies the iOS Mediapipe task options to an object of mediapipe::CalculatorOptions proto. - */ -- (void)copyToProto:(mediapipe::CalculatorOptions *)optionsProto; - -@end - NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h b/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h new file mode 100644 index 000000000..18543e9ef --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h @@ -0,0 +1,29 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ==============================================================================*/ +#import +#include "mediapipe/framework/calculator_options.pb.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * Any mediapipe task options should conform to this protocol. + */ +@protocol MPPTaskOptionsProtocol + +/** + * Copies the iOS Mediapipe task options to an object of mediapipe::CalculatorOptions proto. + */ +- (void)copyToProto:(mediapipe::CalculatorOptions *)optionsProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/BUILD b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD index 453a30f54..662e76c2a 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD @@ -23,8 +23,8 @@ objc_library( deps = [ "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierOptions", "//mediapipe/tasks/ios/core/utils:MPPBaseOptionsHelpers", + "//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol", "//mediapipe/tasks/ios/components/processors/utils:MPPClassifierOptionsHelpers", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", ], ) diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h index 8df471d05..71076da26 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h @@ -12,8 +12,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "mediapipe/framework/calculator_options.pb.h" #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" NS_ASSUME_NONNULL_BEGIN From 96247ccce484afad99d01a0434f818313ac7102d Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 14 Dec 2022 18:59:26 +0530 Subject: [PATCH 11/18] Added iOS task manager --- mediapipe/tasks/ios/core/BUILD | 12 ++++ .../tasks/ios/core/sources/MPPTaskInfo.h | 2 + .../tasks/ios/core/sources/MPPTaskInfo.mm | 3 +- .../tasks/ios/core/sources/MPPTaskManager.h | 47 ++++++++++++++++ .../tasks/ios/core/sources/MPPTaskManager.mm | 56 +++++++++++++++++++ .../tasks/ios/core/sources/MPPTaskOptions.m | 2 +- .../utils/sources/MPPBaseOptions+Helpers.mm | 4 +- .../tasks/ios/text/text_classifier/BUILD | 2 +- .../sources/MPPTextClassifier.h | 5 +- 9 files changed, 125 insertions(+), 8 deletions(-) create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskManager.h create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskManager.mm diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD index 73fcacc37..666b0e6e1 100644 --- a/mediapipe/tasks/ios/core/BUILD +++ b/mediapipe/tasks/ios/core/BUILD @@ -42,6 +42,7 @@ objc_library( "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", ":MPPTaskOptions", + ":MPPTaskOptionsProtocol", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", "//mediapipe/tasks/ios/common/utils:NSStringHelpers", "//mediapipe/tasks/ios/common:MPPCommon", @@ -81,3 +82,14 @@ objc_library( "//mediapipe/framework:calculator_options_cc_proto", ], ) + +objc_library( + name = "MPPTaskManager", + srcs = ["sources/MPPTaskManager.mm"], + hdrs = ["sources/MPPTaskManager.h"], + deps = [ + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/framework:calculator_cc_proto", 
+ "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + ], +) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h index 620184518..a6ba4c4bd 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h @@ -12,6 +12,8 @@ #import #include "mediapipe/framework/calculator.pb.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" + NS_ASSUME_NONNULL_BEGIN diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm index 7e42d6eae..ed8e814d2 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm @@ -37,7 +37,6 @@ using FlowLimiterCalculatorOptions = ::mediapipe::FlowLimiterCalculatorOptions; taskOptions:(id)taskOptions enableFlowLimiting:(BOOL)enableFlowLimiting error:(NSError **)error { - self = [super init]; if (!taskGraphName || !inputStreams.count || !outputStreams.count) { [MPPCommonUtils createCustomError:error @@ -46,6 +45,8 @@ using FlowLimiterCalculatorOptions = ::mediapipe::FlowLimiterCalculatorOptions; @"Task graph's name, input streams, and output streams should be non-empty."]; } + self = [super init]; + if (self) { _taskGraphName = taskGraphName; _inputStreams = inputStreams; diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskManager.h b/mediapipe/tasks/ios/core/sources/MPPTaskManager.h new file mode 100644 index 000000000..b4ba02edd --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskManager.h @@ -0,0 +1,47 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import + +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" + + +NS_ASSUME_NONNULL_BEGIN + +/** + * The base class of the user-facing iOS mediapipe text task api classes. + */ +@interface MPPTaskManager : NSObject +/** + * Initializes a new `MPPBaseTextTaskApi` with the mediapipe text task graph config proto. + * + * @param graphConfig A mediapipe text task graph config proto. + * + * @return An instance of `MPPBaseTextTaskApi` initialized to the given graph config proto. + */ +- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig + error:(NSError **)error; + +- (absl::StatusOr)process:(const mediapipe::tasks::core::PacketMap&)packetMap error:(NSError **)error; + +- (void)close; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskManager.mm b/mediapipe/tasks/ios/core/sources/MPPTaskManager.mm new file mode 100644 index 000000000..2bf23d428 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskManager.mm @@ -0,0 +1,56 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "mediapipe/tasks/ios/core/sources/MPPTaskManager.h" +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" + +namespace { +using ::mediapipe::CalculatorGraphConfig; +using ::mediapipe::Packet; +using ::mediapipe::tasks::core::PacketMap; +using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; +} // namespace + +@interface MPPTaskManager () { + /** TextSearcher backed by C++ API */ + std::unique_ptr _cppTaskRunner; +} +@end + +@implementation MPPTaskManager + +- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig + error:(NSError **)error { + self = [super init]; + if (self) { + auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig)); + + if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) { + return nil; + } + + _cppTaskRunner = std::move(taskRunnerResult.value()); + } + return self; +} + +- (absl::StatusOr)process:(const PacketMap&)packetMap { + return _cppTaskRunner->Process(packetMap); +} + +- (void)close { + _cppTaskRunner->Close(); +} + +@end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m index ec1adbaf1..f71d275be 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m @@ -28,7 +28,7 @@ - (instancetype)initWithModelPath:(NSString *)modelPath { self = [self init]; if (self) { - _baseOptions.modelAssetFile.filePath = modelPath; + _baseOptions.modelAssetPath = 
modelPath; } return self; } diff --git a/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm index f20f8602a..9fce15dfa 100644 --- a/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm +++ b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm @@ -21,8 +21,8 @@ using BaseOptionsProto = ::mediapipe::tasks::core::proto::BaseOptions; @implementation MPPBaseOptions (Helpers) - (void)copyToProto:(BaseOptionsProto *)baseOptionsProto { - if (self.modelAssetFile.filePath) { - baseOptionsProto->mutable_model_asset()->set_file_name(self.modelAssetFile.filePath.UTF8String); + if (self.modelAssetPath) { + baseOptionsProto->mutable_model_asset()->set_file_name(self.modelAssetPath.UTF8String); } switch (self.delegate) { diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD index eb0800fcd..3427e3a6f 100644 --- a/mediapipe/tasks/ios/text/text_classifier/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/BUILD @@ -27,7 +27,7 @@ objc_library( deps = [ "//mediapipe/tasks/ios/core:MPPTaskOptions", "//mediapipe/tasks/ios/core:MPPTaskInfo", - "//mediapipe/tasks/ios/text/core:MPPBaseTextTaskApi", + "//mediapipe/tasks/ios/core:MPPTaskManager", "//mediapipe/tasks/ios/core:MPPPacketCreator", "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers", "//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers", diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h index 96d5887ff..0c33a5288 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h @@ -16,7 +16,6 @@ #import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h" #import 
"mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" -#import "mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h" #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" NS_ASSUME_NONNULL_BEGIN @@ -25,7 +24,7 @@ NS_ASSUME_NONNULL_BEGIN * A Mediapipe iOS Text Classifier. */ NS_SWIFT_NAME(TextClassifier) -@interface MPPTextClassifier : MPPBaseTextTaskApi +@interface MPPTextClassifier : NSObject /** * Creates a new instance of `MPPTextClassifier` from an absolute path to a TensorFlow Lite model @@ -53,7 +52,7 @@ NS_SWIFT_NAME(TextClassifier) */ - (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error; -- (MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error; +- (nullable MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error; - (instancetype)init NS_UNAVAILABLE; From 781f7adf26d21df1b40beb6bb7518f886e7c9645 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 14 Dec 2022 18:59:56 +0530 Subject: [PATCH 12/18] Updated text classifier to use task manager --- .../sources/MPPTextClassifier.mm | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm index e61e6998d..b4cd66f70 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm @@ -13,11 +13,11 @@ limitations under the License. 
==============================================================================*/ #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" - #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" #import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h" #import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskManager.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" #import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h" @@ -35,9 +35,16 @@ static NSString *const kTextInStreamName = @"text_in"; static NSString *const kTextTag = @"TEXT"; static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.TextClassifierGraph"; +@interface MPPTextClassifier () { + /** TextSearcher backed by C++ API */ + MPPTaskManager *_taskManager; +} +@end + @implementation MPPTextClassifier - (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error { + MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] initWithTaskGraphName:kTaskGraphName inputStreams:@[ [NSString stringWithFormat:@"@:@", kTextTag, kTextInStreamName] ] @@ -51,7 +58,11 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T return nil; } - return [super initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error]; + _taskManager = [[MPPTaskManager alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error]; + + self = [super init]; + + return self; } - (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error { @@ -61,11 +72,10 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T return [self initWithOptions:options error:error]; } -- (MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error 
{ +- (nullable MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error { Packet packet = [MPPPacketCreator createWithText:text]; - absl::StatusOr output_packet_map = - cppTaskRunner->Process({{kTextInStreamName.cppString, packet}}); + absl::StatusOr output_packet_map = [_taskManager process:{{kTextInStreamName.cppString, packet}} error:error]; if (![MPPCommonUtils checkCppError:output_packet_map.status() toError:error]) { return nil; } From 4a7a3b342b3a606b73ede421176d26f25b7045a2 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 21 Dec 2022 01:07:16 +0530 Subject: [PATCH 13/18] Updated copyright notice --- .../ios/common/utils/sources/MPPCommonUtils.h | 30 ++++++++-------- .../common/utils/sources/MPPCommonUtils.mm | 26 +++++++------- .../common/utils/sources/NSString+Helpers.h | 27 +++++++-------- .../common/utils/sources/NSString+Helpers.mm | 26 +++++++------- .../containers/sources/MPPCategory.h | 26 +++++++------- .../containers/sources/MPPCategory.m | 26 +++++++------- .../sources/MPPClassificationResult.h | 26 +++++++------- .../sources/MPPClassificationResult.m | 26 +++++++------- .../processors/sources/MPPClassifierOptions.h | 26 +++++++------- .../processors/sources/MPPClassifierOptions.m | 26 +++++++------- .../sources/MPPClassifierOptions+Helpers.h | 26 +++++++------- .../sources/MPPClassifierOptions+Helpers.mm | 26 +++++++------- .../tasks/ios/core/sources/MPPTaskInfo.h | 25 ++++++++------ .../tasks/ios/core/sources/MPPTaskInfo.mm | 32 ++++++++--------- .../tasks/ios/core/sources/MPPTaskManager.h | 34 +++++++++---------- .../tasks/ios/core/sources/MPPTaskManager.mm | 26 +++++++------- .../utils/sources/MPPBaseOptions+Helpers.h | 26 +++++++------- .../utils/sources/MPPBaseOptions+Helpers.mm | 26 +++++++------- 18 files changed, 244 insertions(+), 242 deletions(-) diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h index 
6b4f40bc6..8a90856c7 100644 --- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h @@ -1,30 +1,30 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ #import #include "mediapipe/tasks/cc/common.h" NS_ASSUME_NONNULL_BEGIN -/** Error domain of TensorFlow Lite Support related errors. */ +/** Error domain of Mediapipe Task related errors. */ extern NSString *const MPPTasksErrorDomain; /** Helper utility for the all tasks which encapsulates common functionality. 
*/ @interface MPPCommonUtils : NSObject /** - * Creates and saves an NSError in the Tensorflow Lite Task Library domain, with the given code and + * Creates and saves an NSError in the Mediapipe task library domain, with the given code and * description. * * @param code Error code. diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm index 141270b6d..574f2ef9a 100644 --- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm @@ -23,7 +23,7 @@ #include "mediapipe/tasks/cc/common.h" -/** Error domain of TensorFlow Lite Support related errors. */ +/** Error domain of MediaPipe task library errors. */ NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks"; @implementation MPPCommonUtils @@ -69,11 +69,11 @@ NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks"; return YES; } // Payload of absl::Status created by the Media Pipe task library stores an appropriate value of - // the enum MPPiteSupportStatus. The integer value corresponding to the MPPiteSupportStatus enum + // the enum MediaPipeTasksStatus. The integer value corresponding to the MediaPipeTasksStatus enum // stored in the payload is extracted here to later map to the appropriate error code to be // returned. In cases where the enum is not stored in (payload is NULL or the payload string // cannot be converted to an integer), we set the error code value to be 1 - // (MPPSupportErrorCodeUnspecifiedError of MPPSupportErrorCode used in the iOS library to signify + // (MPPTasksErrorCodeError of MPPTasksErrorCode used in the iOS library to signify // any errors not falling into other categories.) Since payload is of type absl::Cord that can be // type cast into an absl::optional, we use the std::stoi function to convert it into // an integer code if possible. 
@@ -81,7 +81,7 @@ NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks"; NSUInteger errorCode; try { // Try converting payload to integer if payload is not empty. Otherwise convert a string - // signifying generic error code MPPSupportErrorCodeUnspecifiedError to integer. + // signifying generic error code MPPTasksErrorCodeError to integer. errorCode = (NSUInteger)std::stoi(static_cast>( status.GetPayload(mediapipe::tasks::kMediaPipeTasksPayload)) @@ -92,12 +92,12 @@ NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks"; } // If errorCode is outside the range of enum values possible or is - // TFLSupportErrorCodeUnspecifiedError, we try to map the absl::Status::code() to assign - // appropriate TFLSupportErrorCode or TFLSupportErrorCodeUnspecifiedError in default cases. Note: + // MPPTasksErrorCodeError, we try to map the absl::Status::code() to assign + // appropriate MPPTasksErrorCode in default cases. Note: // The mapping to absl::Status::code() is done to generate a more specific error code than - // TFLSupportErrorCodeUnspecifiedError in cases when the payload can't be mapped to - // TfLiteSupportStatus. This can happen when absl::Status returned by TfLite are in turn returned - // without moodification by TfLite Support Methods. + // MPPTasksErrorCodeError in cases when the payload can't be mapped to + // MPPTasksErrorCode. This can happen when absl::Status returned by TFLite library are in turn returned + // without modification by Mediapipe cc library methods. if (errorCode > MPPTasksErrorCodeLast || errorCode <= MPPTasksErrorCodeFirst) { switch (status.code()) { case absl::StatusCode::kInternal: @@ -116,11 +116,11 @@ NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks"; } // Creates the NSEror with the appropriate error - // TFLSupportErrorCode and message. TFLSupportErrorCode has a one to one - // mapping with TfLiteSupportStatus starting from the value 1(TFLSupportErrorCodeUnspecifiedError) + // MPPTasksErrorCode and message. 
MPPTasksErrorCode has a one to one + // mapping with MediaPipeTasksStatus starting from the value 1(MPPTasksErrorCodeError) // and hence will be correctly initialized if directly cast from the integer code derived from - // TfLiteSupportStatus stored in its payload. TFLSupportErrorCode omits kOk = 0 of - // TfLiteSupportStatus. + // MediaPipeTasksStatus stored in its payload. MPPTasksErrorCode omits kOk = 0 of + // MediaPipeTasksStatusx. // // Stores a string including absl status code and message(if non empty) as the // error message See diff --git a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h index 4c1e3af01..aac7485da 100644 --- a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h +++ b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h @@ -1,17 +1,16 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #import #include diff --git a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm index 0720223dc..183ed4365 100644 --- a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm +++ b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- ==============================================================================*/ #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" @implementation NSString (Helpers) diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h index b74a5edb5..431b8a705 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h +++ b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- ==============================================================================*/ #import NS_ASSUME_NONNULL_BEGIN diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m index a14c38b2c..20f745582 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m +++ b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- ==============================================================================*/ #import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h" @implementation MPPCategory diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h index 986a8cc1f..b0e0c4073 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ #import #import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h" diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m index df3e6a52d..e4e5eaac5 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ #import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h" @implementation MPPClassifications diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h index 7538a625d..8c4981642 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- ==============================================================================*/ #import NS_ASSUME_NONNULL_BEGIN diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m index 5a5baf79f..52dce23e4 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- ==============================================================================*/ #import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h" @implementation MPPClassifierOptions diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h index 4defb3ee7..6644a6255 100644 --- a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h +++ b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- ==============================================================================*/ #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h" diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm index efb220147..25e657599 100644 --- a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm +++ b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- ==============================================================================*/ #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" #import "mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h" diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h index a6ba4c4bd..fca660fae 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h @@ -1,14 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ #import #include "mediapipe/framework/calculator.pb.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm index ed8e814d2..7d2fd6f28 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- ==============================================================================*/ #import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" #import "mediapipe/tasks/ios/common/sources/MPPCommon.h" #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" @@ -24,9 +24,9 @@ namespace { using CalculatorGraphConfig = ::mediapipe::CalculatorGraphConfig; using Node = ::mediapipe::CalculatorGraphConfig::Node; -using InputStreamInfo = ::mediapipe::InputStreamInfo; -using CalculatorOptions = ::mediapipe::CalculatorOptions; -using FlowLimiterCalculatorOptions = ::mediapipe::FlowLimiterCalculatorOptions; +using ::mediapipe::InputStreamInfo; +using ::mediapipe::CalculatorOptions; +using ::mediapipe::FlowLimiterCalculatorOptions; } // namespace @implementation MPPTaskInfo diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskManager.h b/mediapipe/tasks/ios/core/sources/MPPTaskManager.h index b4ba02edd..f6dea201a 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskManager.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskManager.h @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ #import #include "mediapipe/framework/calculator.pb.h" @@ -21,15 +21,15 @@ NS_ASSUME_NONNULL_BEGIN /** - * The base class of the user-facing iOS mediapipe text task api classes. + * The base class of the user-facing iOS mediapipe task api classes. */ @interface MPPTaskManager : NSObject /** - * Initializes a new `MPPBaseTextTaskApi` with the mediapipe text task graph config proto. + * Initializes a new `MPPTaskManager` with the mediapipe task graph config proto. * - * @param graphConfig A mediapipe text task graph config proto. + * @param graphConfig A mediapipe task graph config proto. * - * @return An instance of `MPPBaseTextTaskApi` initialized to the given graph config proto. + * @return An instance of `MPPTaskManager` initialized to the given graph config proto. */ - (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig error:(NSError **)error; diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskManager.mm b/mediapipe/tasks/ios/core/sources/MPPTaskManager.mm index 2bf23d428..492ed8cf6 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskManager.mm +++ b/mediapipe/tasks/ios/core/sources/MPPTaskManager.mm @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ #import "mediapipe/tasks/ios/core/sources/MPPTaskManager.h" #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" diff --git a/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h index f57844a8f..89d4a8237 100644 --- a/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h +++ b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h" diff --git a/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm index 9fce15dfa..c3f1b679f 100644 --- a/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm +++ b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ #import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h" namespace { From febca359a61a8e4b8e9e87236d884e1238d00bee Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 21 Dec 2022 01:07:31 +0530 Subject: [PATCH 14/18] Updated packet creator class name --- mediapipe/tasks/ios/core/BUILD | 6 +++--- .../{MPPPacketCreator.h => MPPTextPacketCreator.h} | 11 ++++------- .../{MPPPacketCreator.mm => MPPTextPacketCreator.mm} | 4 ++-- 3 files changed, 9 insertions(+), 12 deletions(-) rename mediapipe/tasks/ios/core/sources/{MPPPacketCreator.h => MPPTextPacketCreator.h} (72%) rename mediapipe/tasks/ios/core/sources/{MPPPacketCreator.mm => MPPTextPacketCreator.mm} (89%) diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD index 666b0e6e1..58f7389ac 100644 --- a/mediapipe/tasks/ios/core/BUILD +++ b/mediapipe/tasks/ios/core/BUILD @@ -50,9 +50,9 @@ objc_library( ) objc_library( - name = 
"MPPPacketCreator", - srcs = ["sources/MPPPacketCreator.mm"], - hdrs = ["sources/MPPPacketCreator.h"], + name = "MPPTextPacketCreator", + srcs = ["sources/MPPTextPacketCreator.mm"], + hdrs = ["sources/MPPTextPacketCreator.h"], copts = [ "-ObjC++", "-std=c++17", diff --git a/mediapipe/tasks/ios/core/sources/MPPPacketCreator.h b/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h similarity index 72% rename from mediapipe/tasks/ios/core/sources/MPPPacketCreator.h rename to mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h index ecd0c5bfd..03f946dd0 100644 --- a/mediapipe/tasks/ios/core/sources/MPPPacketCreator.h +++ b/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h @@ -14,15 +14,12 @@ #import -#ifndef __cplusplus -#error This header can only be included by an Objective-C++ file. -#endif - #include "mediapipe/framework/packet.h" -/// This class is an Objective-C wrapper around a MediaPipe graph object, and -/// helps interface it with iOS technologies such as AVFoundation. -@interface MPPPacketCreator : NSObject +/* This class is an Objective-C wrapper around a MediaPipe graph object, and + * helps interface it with iOS technologies such as AVFoundation. + */ +@interface MPPTextPacketCreator : NSObject + (mediapipe::Packet)createWithText:(NSString *)text; diff --git a/mediapipe/tasks/ios/core/sources/MPPPacketCreator.mm b/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm similarity index 89% rename from mediapipe/tasks/ios/core/sources/MPPPacketCreator.mm rename to mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm index 6ce5a5139..ca86e7a0b 100644 --- a/mediapipe/tasks/ios/core/sources/MPPPacketCreator.mm +++ b/mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.mm @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h" +#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h" #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" namespace { @@ -20,7 +20,7 @@ using ::mediapipe::MakePacket; using ::mediapipe::Packet; } // namespace -@implementation MPPPacketCreator +@implementation MPPTextPacketCreator + (Packet)createWithText:(NSString *)text { return MakePacket(text.cppString); From 7e0fec7c28eb25eb69793c5a33194b96ef8d1734 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 23 Dec 2022 17:52:00 +0530 Subject: [PATCH 15/18] Updated implementation of text classifier --- .../tasks/ios/components/containers/BUILD | 1 - .../sources/MPPClassificationResult.h | 6 +- .../sources/MPPClassificationResult.m | 6 +- .../MPPClassificationResult+Helpers.mm | 9 +-- .../tasks/ios/core/sources/MPPTaskManager.h | 47 ---------------- .../tasks/ios/core/sources/MPPTaskManager.mm | 56 ------------------- .../text_classifier/MPPTextClassifierTests.m | 38 ++++++++++++- .../tasks/ios/text/text_classifier/BUILD | 15 ++++- .../sources/MPPTextClassifier.h | 4 +- .../sources/MPPTextClassifier.mm | 31 +++++----- .../sources/MPPTextClassifierOptions.h | 28 +++++----- .../sources/MPPTextClassifierOptions.m | 14 ++--- .../sources/MPPTextClassifierResult.h | 41 ++++++++++++++ .../sources/MPPTextClassifierResult.m | 28 ++++++++++ .../ios/text/text_classifier/utils/BUILD | 10 ++++ .../MPPTextClassifierOptions+Helpers.h | 26 ++++----- .../MPPTextClassifierOptions+Helpers.mm | 26 ++++----- .../sources/MPPTextClassifierResult+Helpers.h | 28 ++++++++++ .../MPPTextClassifierResult+Helpers.mm | 39 +++++++++++++ 19 files changed, 264 insertions(+), 189 deletions(-) delete mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskManager.h delete mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskManager.mm create mode 100644 mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h create mode 100644 
mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m create mode 100644 mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h create mode 100644 mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm diff --git a/mediapipe/tasks/ios/components/containers/BUILD b/mediapipe/tasks/ios/components/containers/BUILD index 5d6bae220..ce80571e9 100644 --- a/mediapipe/tasks/ios/components/containers/BUILD +++ b/mediapipe/tasks/ios/components/containers/BUILD @@ -28,6 +28,5 @@ objc_library( hdrs = ["sources/MPPClassificationResult.h"], deps = [ ":MPPCategory", - "//mediapipe/tasks/ios/core:MPPTaskResult", ], ) diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h index b0e0c4073..24f99bfde 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h @@ -14,7 +14,6 @@ #import #import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h" -#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h" NS_ASSUME_NONNULL_BEGIN @@ -72,7 +71,7 @@ NS_SWIFT_NAME(Classifications) /** Encapsulates results of any classification task. */ NS_SWIFT_NAME(ClassificationResult) -@interface MPPClassificationResult : MPPTaskResult +@interface MPPClassificationResult : NSObject /** Array of MPPClassifications objects containing classifier predictions per image classifier * head. @@ -88,8 +87,7 @@ NS_SWIFT_NAME(ClassificationResult) * @return An instance of MPPClassificationResult initialized with the given array of * classifications. 
*/ -- (instancetype)initWithClassifications:(NSArray *)classifications - timeStamp:(long)timeStamp; +- (instancetype)initWithClassifications:(NSArray *)classifications; @end diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m index e4e5eaac5..dd9c4e024 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m @@ -39,9 +39,9 @@ NSArray *_classifications; } -- (instancetype)initWithClassifications:(NSArray *)classifications - timeStamp:(long)timeStamp { - self = [super initWithTimeStamp:timeStamp]; +- (instancetype)initWithClassifications:(NSArray *)classifications { + + self = [super init]; if (self) { _classifications = classifications; } diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm index 0e9e599d7..84d5872d7 100644 --- a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm @@ -53,14 +53,7 @@ using ClassificationResultProto = [classifications addObject:[MPPClassifications classificationsWithProto:classifications_proto]]; } - long timeStamp; - - if (classificationResultProto.has_timestamp_ms()) { - timeStamp = classificationResultProto.timestamp_ms(); - } - - return [[MPPClassificationResult alloc] initWithClassifications:classifications - timeStamp:timeStamp]; + return [[MPPClassificationResult alloc] initWithClassifications:classifications]; } @end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskManager.h b/mediapipe/tasks/ios/core/sources/MPPTaskManager.h deleted file mode 100644 index f6dea201a..000000000 --- 
a/mediapipe/tasks/ios/core/sources/MPPTaskManager.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2022 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import - -#include "mediapipe/framework/calculator.pb.h" -#include "mediapipe/tasks/cc/core/task_runner.h" - - -NS_ASSUME_NONNULL_BEGIN - -/** - * The base class of the user-facing iOS mediapipe task api classes. - */ -@interface MPPTaskManager : NSObject -/** - * Initializes a new `MPPTaskManager` with the mediapipe task graph config proto. - * - * @param graphConfig A mediapipe task graph config proto. - * - * @return An instance of `MPPTaskManager` initialized to the given graph config proto. - */ -- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig - error:(NSError **)error; - -- (absl::StatusOr)process:(const mediapipe::tasks::core::PacketMap&)packetMap error:(NSError **)error; - -- (void)close; - -- (instancetype)init NS_UNAVAILABLE; - -+ (instancetype)new NS_UNAVAILABLE; - -@end - -NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskManager.mm b/mediapipe/tasks/ios/core/sources/MPPTaskManager.mm deleted file mode 100644 index 492ed8cf6..000000000 --- a/mediapipe/tasks/ios/core/sources/MPPTaskManager.mm +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2022 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import "mediapipe/tasks/ios/core/sources/MPPTaskManager.h" -#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" - -namespace { -using ::mediapipe::CalculatorGraphConfig; -using ::mediapipe::Packet; -using ::mediapipe::tasks::core::PacketMap; -using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; -} // namespace - -@interface MPPTaskManager () { - /** TextSearcher backed by C++ API */ - std::unique_ptr _cppTaskRunner; -} -@end - -@implementation MPPTaskManager - -- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig - error:(NSError **)error { - self = [super init]; - if (self) { - auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig)); - - if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) { - return nil; - } - - _cppTaskRunner = std::move(taskRunnerResult.value()); - } - return self; -} - -- (absl::StatusOr)process:(const PacketMap&)packetMap { - return _cppTaskRunner->Process(packetMap); -} - -- (void)close { - _cppTaskRunner->Close(); -} - -@end diff --git a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m index 3808009f3..fa04c3e65 100644 --- a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m +++ b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m @@ -19,6 +19,28 @@ NS_ASSUME_NONNULL_BEGIN static NSString *const kBertTextClassifierModelName = @"bert_text_classifier"; +static NSString *const kNegativeText = 
@"unflinchingly bleak and desperate"; + +#define VerifyCategory(category, expectedIndex, expectedScore, expectedLabel, expectedDisplayName) \ + XCTAssertEqual(category.index, expectedIndex); \ + XCTAssertEqualWithAccuracy(category.score, expectedScore, 1e-6); \ + XCTAssertEqualObjects(category.label, expectedLabel); \ + XCTAssertEqualObjects(category.displayName, expectedDisplayName); + +#define VerifyClassifications(classifications, expectedHeadIndex, expectedCategoryCount) \ + XCTAssertEqual(classifications.categories.count, expectedCategoryCount); + +#define VerifyClassificationResult(classificationResult, expectedClassificationsCount) \ + XCTAssertNotNil(classificationResult); \ + XCTAssertEqual(classificationResult.classifications.count, expectedClassificationsCount) + +#define AssertClassificationResultHasOneHead(classificationResult) \ + XCTAssertNotNil(classificationResult); \ + XCTAssertEqual(classificationResult.classifications.count, 1); + XCTAssertEqual(classificationResult.classifications[0].headIndex, 1); + +#define AssertTextClassifierResultIsNotNil(textClassifierResult) \ + XCTAssertNotNil(textClassifierResult); @interface MPPTextClassifierTests : XCTestCase @end @@ -41,15 +63,25 @@ static NSString *const kBertTextClassifierModelName = @"bert_text_classifier"; - (MPPTextClassifierOptions *)textClassifierOptionsWithModelName:(NSString *)modelName { NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"]; MPPTextClassifierOptions *textClassifierOptions = - [[MPPTextClassifierOptions alloc] initWithModelPath:modelPath]; + [[MPPTextClassifierOptions alloc] init]; + textClassifierOptions.baseOptions.modelAssetPath = modelPath; return textClassifierOptions; } -- (void)testCreateTextClassifierOptionsSucceeds { - MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; +kBertTextClassifierModelName + +- (MPPTextClassifier *)createTextClassifierFromOptionsWithModelName:(NSString 
*)modelName { + MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:modelName]; MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil]; XCTAssertNotNil(textClassifier); + + return textClassifier +} + +- (void)classifyWithBertSucceeds { + MPPTextClassifier *textClassifier = [self createTextClassifierWithModelName:kBertTextClassifierModelName]; + MPPTextClassifierResult *textClassifierResult = [textClassifier classifyWithText:kNegativeText]; } @end diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD index 3427e3a6f..61eecb9cd 100644 --- a/mediapipe/tasks/ios/text/text_classifier/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/BUILD @@ -27,10 +27,10 @@ objc_library( deps = [ "//mediapipe/tasks/ios/core:MPPTaskOptions", "//mediapipe/tasks/ios/core:MPPTaskInfo", - "//mediapipe/tasks/ios/core:MPPTaskManager", - "//mediapipe/tasks/ios/core:MPPPacketCreator", + "//mediapipe/tasks/ios/core:MPPTaskRunner", + "//mediapipe/tasks/ios/core:MPPTextPacketCreator", "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers", - "//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers", + "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", "//mediapipe/tasks/ios/common/utils:NSStringHelpers", ":MPPTextClassifierOptions", @@ -51,3 +51,12 @@ objc_library( ], ) +objc_library( + name = "MPPTextClassifierResult", + srcs = ["sources/MPPTextClassifierResult.m"], + hdrs = ["sources/MPPTextClassifierResult.h"], + deps = [ + "//mediapipe/tasks/ios/core:MPPTaskResult", + "//mediapipe/tasks/ios/components/containers:MPPClassificationResult", + ], +) diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h index 0c33a5288..19e10e35f 100644 --- 
a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h @@ -14,7 +14,7 @@ ==============================================================================*/ #import -#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h" +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" @@ -52,7 +52,7 @@ NS_SWIFT_NAME(TextClassifier) */ - (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error; -- (nullable MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error; +- (nullable MPPTextClassifierResult *)classifyWithText:(NSString *)text error:(NSError **)error; - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm index b4cd66f70..b9e76fc69 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm @@ -15,9 +15,9 @@ #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" -#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h" -#import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h" -#import "mediapipe/tasks/ios/core/sources/MPPTaskManager.h" +#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h" +#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h" #import 
"mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" #import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h" @@ -30,14 +30,14 @@ using ::mediapipe::tasks::core::PacketMap; } // namespace static NSString *const kClassificationsStreamName = @"classifications_out"; -static NSString *const kClassificationsTag = @"classifications"; +static NSString *const kClassificationsTag = @"CLASSIFICATIONS"; static NSString *const kTextInStreamName = @"text_in"; static NSString *const kTextTag = @"TEXT"; static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.TextClassifierGraph"; @interface MPPTextClassifier () { /** TextSearcher backed by C++ API */ - MPPTaskManager *_taskManager; + MPPTaskRunner *_taskRunner; } @end @@ -47,8 +47,8 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] initWithTaskGraphName:kTaskGraphName - inputStreams:@[ [NSString stringWithFormat:@"@:@", kTextTag, kTextInStreamName] ] - outputStreams:@[ [NSString stringWithFormat:@"@:@", kClassificationsTag, + inputStreams:@[ [NSString stringWithFormat:@"%@:%@", kTextTag, kTextInStreamName] ] + outputStreams:@[ [NSString stringWithFormat:@"%@:%@", kClassificationsTag, kClassificationsStreamName] ] taskOptions:options enableFlowLimiting:NO @@ -58,7 +58,7 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T return nil; } - _taskManager = [[MPPTaskManager alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error]; + _taskRunner = [[MPPTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error]; self = [super init]; @@ -66,22 +66,23 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T } - (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error { - MPPTextClassifierOptions *options = - [[MPPTextClassifierOptions alloc] 
initWithModelPath:modelPath]; + MPPTextClassifierOptions *options = [[MPPTextClassifierOptions alloc] init]; + + options.baseOptions.modelAssetPath = modelPath; return [self initWithOptions:options error:error]; } -- (nullable MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error { - Packet packet = [MPPPacketCreator createWithText:text]; +- (nullable MPPTextClassifierResult *)classifyWithText:(NSString *)text error:(NSError **)error { + Packet packet = [MPPTextPacketCreator createWithText:text]; - absl::StatusOr output_packet_map = [_taskManager process:{{kTextInStreamName.cppString, packet}} error:error]; + absl::StatusOr output_packet_map = [_taskRunner process:{{kTextInStreamName.cppString, packet}} error:error]; if (![MPPCommonUtils checkCppError:output_packet_map.status() toError:error]) { return nil; } - return [MPPClassificationResult - classificationResultWithProto:output_packet_map.value()[kClassificationsStreamName.cppString] + return [MPPTextClassifierResult + textClassifierResultWithProto:output_packet_map.value()[kClassificationsStreamName.cppString] .Get()]; } diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h index 47c44dd0d..374226998 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h @@ -31,20 +31,20 @@ NS_SWIFT_NAME(TextClassifierOptions) */ @property(nonatomic, copy) MPPClassifierOptions *classifierOptions; -/** - * Initializes a new `MPPTextClassifierOptions` with the absolute path to the model file - * stored locally on the device, set to the given the model path. - * - * @discussion The external model file must be a single standalone TFLite file. It could be packed - * with TFLite Model Metadata[1] and associated files if they exist. 
Failure to provide the - * necessary metadata and associated files might result in errors. Check the [documentation] - * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement. - * - * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. - * - * @return An instance of `MPPTextClassifierOptions` initialized to the given model path. - */ -- (instancetype)initWithModelPath:(NSString *)modelPath; +// /** +// * Initializes a new `MPPTextClassifierOptions` with the absolute path to the model file +// * stored locally on the device, set to the given the model path. +// * +// * @discussion The external model file must be a single standalone TFLite file. It could be packed +// * with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the +// * necessary metadata and associated files might result in errors. Check the [documentation] +// * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement. +// * +// * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. +// * +// * @return An instance of `MPPTextClassifierOptions` initialized to the given model path. 
+// */ +// - (instancetype)initWithModelPath:(NSString *)modelPath; @end diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m index 8cab693cd..82e9bed64 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m @@ -16,12 +16,12 @@ @implementation MPPTextClassifierOptions -- (instancetype)initWithModelPath:(NSString *)modelPath { - self = [super initWithModelPath:modelPath]; - if (self) { - _classifierOptions = [[MPPClassifierOptions alloc] init]; - } - return self; -} +// - (instancetype)initWithModelPath:(NSString *)modelPath { +// self = [super initWithModelPath:modelPath]; +// if (self) { +// _classifierOptions = [[MPPClassifierOptions alloc] init]; +// } +// return self; +// } @end \ No newline at end of file diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h new file mode 100644 index 000000000..414e6d9c6 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h @@ -0,0 +1,41 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import +#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h" + +NS_ASSUME_NONNULL_BEGIN + +/** Encapsulates list of predicted classes (aka labels) for a given image classifier head. */ +NS_SWIFT_NAME(TextClassifierResult) +@interface MPPTextClassifierResult : MPPTaskResult + +@property(nonatomic, readonly) MPPClassificationResult *classificationResult; + +/** + * Initializes a new `MPPClassificationResult` with the given array of classifications. + * + * @param classifications An Aaray of `MPPClassifications` objects containing classifier + * predictions per classifier head. + * + * @return An instance of MPPClassificationResult initialized with the given array of + * classifications. + */ +- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult + timeStamp:(long)timeStamp; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m new file mode 100644 index 000000000..b99ee3b19 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m @@ -0,0 +1,28 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h" + +@implementation MPPTextClassifierResult + +- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult + timeStamp:(long)timeStamp { + self = [super initWithTimestamp:timeStamp]; + if (self) { + _classificationResult = classificationResult; + } + return self; +} + +@end diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/BUILD b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD index 662e76c2a..d6a371137 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD @@ -28,3 +28,13 @@ objc_library( "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", ], ) + +objc_library( + name = "MPPTextClassifierResultHelpers", + srcs = ["sources/MPPTextClassifierResult+Helpers.mm"], + hdrs = ["sources/MPPTextClassifierResult+Helpers.h"], + deps = [ + "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierResult", + "//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers", + ], +) diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h index 71076da26..0771eafce 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm index 3576cb8d2..aa11384d2 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ #include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h" #import "mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h" #import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h" diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h new file mode 100644 index 000000000..d3fb04d69 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h @@ -0,0 +1,28 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPTextClassifierResult (Helpers) + ++ (MPPTextClassifierResult *)textClassifierResultWithProto: + (const mediapipe::tasks::components::containers::proto::ClassificationResult &) + classificationResultProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm new file mode 100644 index 000000000..2fc2d751d --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm @@ -0,0 +1,39 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h" +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h" + +namespace { +using ClassificationResultProto = + ::mediapipe::tasks::components::containers::proto::ClassificationResult; +} // namespace + +@implementation MPPTextClassifierResult (Helpers) + ++ (MPPTextClassifierResult *)textClassifierResultWithProto: + (const ClassificationResultProto &)classificationResultProto { + long timeStamp; + + if (classificationResultProto.has_timestamp_ms()) { + timeStamp = classificationResultProto.timestamp_ms(); + } + + MPPClassificationResult *classificationResult = [MPPClassificationResult classificationResultWithProto:classificationResultProto]; + + return [[MPPTextClassifierResult alloc] initWithClassificationResult:classificationResult + timeStamp:timeStamp]; +} + +@end From 7ce21038bb124a8bab79b47c716458089292b405 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 4 Jan 2023 13:40:17 +0530 Subject: [PATCH 16/18] Merge branch 'master' into ios-task --- .github/bot_config.yml | 3 +- ...low_session_from_saved_model_calculator.cc | 7 +- ..._session_from_saved_model_calculator.proto | 4 +- ...flow_session_from_saved_model_generator.cc | 7 +- ...w_session_from_saved_model_generator.proto | 4 +- mediapipe/examples/desktop/autoflip/BUILD | 4 + mediapipe/framework/api2/builder.h | 2 +- mediapipe/framework/api2/packet.h | 2 +- mediapipe/framework/formats/BUILD | 2 +- mediapipe/framework/formats/tensor.h | 7 +- mediapipe/framework/formats/tensor_ahwb.cc | 10 +- .../framework/formats/tensor_ahwb_gpu_test.cc | 16 +- .../framework/formats/tensor_ahwb_test.cc | 94 ++++++++--- .../framework/profiler/graph_profiler.cc | 1 + mediapipe/framework/profiler/graph_profiler.h | 9 ++ .../framework/profiler/graph_profiler_test.cc | 26 +++ .../gpu/gpu_buffer_storage_cv_pixel_buffer.cc | 75 +++++---- 
.../text_classifier/text_classifier_test.py | 7 +- .../image_classifier/image_classifier_test.py | 5 +- mediapipe/objc/MPPGraph.mm | 9 +- .../gesture_recognizer/gesture_recognizer.cc | 4 +- mediapipe/tasks/ios/common/utils/BUILD | 1 - .../ios/common/utils/sources/MPPCommonUtils.h | 2 + .../common/utils/sources/MPPCommonUtils.mm | 5 +- .../common/utils/sources/NSString+Helpers.h | 3 +- .../tasks/ios/components/processors/BUILD | 1 - .../processors/sources/MPPClassifierOptions.h | 21 ++- .../processors/sources/MPPClassifierOptions.m | 4 +- .../ios/components/processors/utils/BUILD | 9 +- .../sources/MPPClassifierOptions+Helpers.h | 1 + .../sources/MPPClassifierOptions+Helpers.mm | 3 +- mediapipe/tasks/ios/core/BUILD | 18 ++- .../tasks/ios/core/sources/MPPTaskInfo.h | 4 +- .../tasks/ios/core/sources/MPPTaskInfo.mm | 1 - .../ios/core/sources/MPPTaskOptionsProtocol.h | 3 +- .../tasks/ios/core/sources/MPPTaskRunner.h | 8 +- .../tasks/ios/core/sources/MPPTaskRunner.mm | 5 +- .../audio_classifier/audio_classifier.ts | 7 +- .../audio_classifier/audio_classifier_test.ts | 3 +- .../audio/audio_embedder/audio_embedder.ts | 7 +- .../audio_embedder/audio_embedder_test.ts | 3 +- .../tasks/web/components/processors/BUILD | 26 --- .../processors/base_options.test.ts | 127 --------------- .../web/components/processors/base_options.ts | 80 ---------- mediapipe/tasks/web/core/BUILD | 5 +- mediapipe/tasks/web/core/task_runner.ts | 75 ++++++++- mediapipe/tasks/web/core/task_runner_test.ts | 148 +++++++++++++++++- .../tasks/web/core/task_runner_test_utils.ts | 4 +- .../text/text_classifier/text_classifier.ts | 7 +- .../text_classifier/text_classifier_test.ts | 3 +- .../web/text/text_embedder/text_embedder.ts | 7 +- .../text/text_embedder/text_embedder_test.ts | 3 +- mediapipe/tasks/web/vision/core/BUILD | 1 + .../vision/core/vision_task_runner.test.ts | 32 ++-- .../web/vision/core/vision_task_runner.ts | 4 +- .../gesture_recognizer/gesture_recognizer.ts | 43 +++-- 
.../gesture_recognizer_result.d.ts | 8 +- .../gesture_recognizer_test.ts | 26 ++- .../vision/hand_landmarker/hand_landmarker.ts | 8 +- .../hand_landmarker_result.d.ts | 2 + .../hand_landmarker/hand_landmarker_test.ts | 3 +- .../image_classifier/image_classifier.ts | 7 +- .../image_classifier/image_classifier_test.ts | 3 +- .../vision/image_embedder/image_embedder.ts | 7 +- .../image_embedder/image_embedder_test.ts | 3 +- .../vision/object_detector/object_detector.ts | 8 +- .../object_detector/object_detector_test.ts | 3 +- 67 files changed, 593 insertions(+), 457 deletions(-) delete mode 100644 mediapipe/tasks/web/components/processors/base_options.test.ts delete mode 100644 mediapipe/tasks/web/components/processors/base_options.ts diff --git a/.github/bot_config.yml b/.github/bot_config.yml index 8ad724168..74a60e4b9 100644 --- a/.github/bot_config.yml +++ b/.github/bot_config.yml @@ -15,4 +15,5 @@ # A list of assignees assignees: - - sureshdagooglecom + - kuaashish + - ayushgdev diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc index 922eb9d50..18bddbbe3 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc @@ -55,7 +55,7 @@ absl::Status GetLatestDirectory(std::string* path) { } // If options.convert_signature_to_tags() is set, will convert letters to -// uppercase and replace /'s and -'s with _'s. This enables the standard +// uppercase and replace /, -, . and :'s with _'s. This enables the standard // SavedModel classification, regression, and prediction signatures to be used // as uppercase INPUTS and OUTPUTS tags for streams and supports other common // patterns. 
@@ -67,9 +67,8 @@ const std::string MaybeConvertSignatureToTag( output.resize(name.length()); std::transform(name.begin(), name.end(), output.begin(), [](unsigned char c) { return std::toupper(c); }); - output = absl::StrReplaceAll(output, {{"/", "_"}}); - output = absl::StrReplaceAll(output, {{"-", "_"}}); - output = absl::StrReplaceAll(output, {{".", "_"}}); + output = absl::StrReplaceAll( + output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}}); LOG(INFO) << "Renamed TAG from: " << name << " to " << output; return output; } else { diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto index 927d3b51f..515b46fa9 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto @@ -33,8 +33,8 @@ message TensorFlowSessionFromSavedModelCalculatorOptions { // The name of the generic signature to load into the mapping from tags to // tensor names. optional string signature_name = 2 [default = "serving_default"]; - // Whether to convert the signature keys to uppercase as well as switch /'s - // and -'s to _'s, which enables common signatures to be used as Tags. + // Whether to convert the signature keys to uppercase as well as switch + // /, -, . and :'s to _'s, which enables common signatures to be used as Tags. 
optional bool convert_signature_to_tags = 3 [default = true]; // If true, saved_model_path can have multiple exported models in // subdirectories saved_model_path/%08d and the alphabetically last (i.e., diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc index d5236f1cc..ee69ec56a 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc @@ -61,7 +61,7 @@ absl::Status GetLatestDirectory(std::string* path) { } // If options.convert_signature_to_tags() is set, will convert letters to -// uppercase and replace /'s and -'s with _'s. This enables the standard +// uppercase and replace /, -, . and :'s with _'s. This enables the standard // SavedModel classification, regression, and prediction signatures to be used // as uppercase INPUTS and OUTPUTS tags for streams and supports other common // patterns. 
@@ -73,9 +73,8 @@ const std::string MaybeConvertSignatureToTag( output.resize(name.length()); std::transform(name.begin(), name.end(), output.begin(), [](unsigned char c) { return std::toupper(c); }); - output = absl::StrReplaceAll(output, {{"/", "_"}}); - output = absl::StrReplaceAll(output, {{"-", "_"}}); - output = absl::StrReplaceAll(output, {{".", "_"}}); + output = absl::StrReplaceAll( + output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}}); LOG(INFO) << "Renamed TAG from: " << name << " to " << output; return output; } else { diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto index d24a1cd73..d45fcb662 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto @@ -33,8 +33,8 @@ message TensorFlowSessionFromSavedModelGeneratorOptions { // The name of the generic signature to load into the mapping from tags to // tensor names. optional string signature_name = 2 [default = "serving_default"]; - // Whether to convert the signature keys to uppercase as well as switch /'s - // and -'s to _'s, which enables common signatures to be used as Tags. + // Whether to convert the signature keys to uppercase, as well as switch /'s + // -'s, .'s, and :'s to _'s, enabling common signatures to be used as Tags. 
optional bool convert_signature_to_tags = 3 [default = true]; // If true, saved_model_path can have multiple exported models in // subdirectories saved_model_path/%08d and the alphabetically last (i.e., diff --git a/mediapipe/examples/desktop/autoflip/BUILD b/mediapipe/examples/desktop/autoflip/BUILD index 562f11c49..0e28746dc 100644 --- a/mediapipe/examples/desktop/autoflip/BUILD +++ b/mediapipe/examples/desktop/autoflip/BUILD @@ -30,6 +30,10 @@ proto_library( java_lite_proto_library( name = "autoflip_messages_java_proto_lite", + visibility = [ + "//java/com/google/android/apps/photos:__subpackages__", + "//javatests/com/google/android/apps/photos:__subpackages__", + ], deps = [ ":autoflip_messages_proto", ], diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index 19273bf44..2a98c4166 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -398,7 +398,7 @@ template class Node; #if __cplusplus >= 201703L // Deduction guide to silence -Wctad-maybe-unsupported. -explicit Node()->Node; +explicit Node() -> Node; #endif // C++17 template <> diff --git a/mediapipe/framework/api2/packet.h b/mediapipe/framework/api2/packet.h index 7933575d3..b1ebb0410 100644 --- a/mediapipe/framework/api2/packet.h +++ b/mediapipe/framework/api2/packet.h @@ -181,7 +181,7 @@ template class Packet; #if __cplusplus >= 201703L // Deduction guide to silence -Wctad-maybe-unsupported. 
-explicit Packet()->Packet; +explicit Packet() -> Packet; #endif  // C++17 template <> diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index f5a043f10..cce7e5bd0 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -455,7 +455,7 @@ cc_library( ], }), deps = [ - "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", "//mediapipe/framework:port", diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index 8a6f02e9d..0f19bb5ee 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -24,7 +24,7 @@ #include #include -#include "absl/container/flat_hash_set.h" +#include "absl/container/flat_hash_map.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/formats/tensor_internal.h" #include "mediapipe/framework/port.h" @@ -434,8 +434,9 @@ class Tensor { mutable bool use_ahwb_ = false; mutable uint64_t ahwb_tracking_key_ = 0; // TODO: Tracks all unique tensors. Can grow to a large number. LRU - // can be more predicted. - static inline absl::flat_hash_set ahwb_usage_track_; + // (Least Recently Used) can be more predictable. + // The value contains the size alignment parameter. + static inline absl::flat_hash_map ahwb_usage_track_; // Expects the target SSBO to be already bound. bool AllocateAhwbMapToSsbo() const; bool InsertAhwbToSsboFence() const; diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index 466811be7..525f05f31 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -266,7 +266,12 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView( bool Tensor::AllocateAHardwareBuffer(int size_alignment) const { // Mark current tracking key as Ahwb-use. 
- ahwb_usage_track_.insert(ahwb_tracking_key_); + if (auto it = ahwb_usage_track_.find(ahwb_tracking_key_); + it != ahwb_usage_track_.end()) { + size_alignment = it->second; + } else if (ahwb_tracking_key_ != 0) { + ahwb_usage_track_.insert({ahwb_tracking_key_, size_alignment}); + } use_ahwb_ = true; if (__builtin_available(android 26, *)) { @@ -458,7 +463,8 @@ void Tensor::TrackAhwbUsage(uint64_t source_location_hash) const { ahwb_tracking_key_ = tensor_internal::FnvHash64(ahwb_tracking_key_, dim); } } - use_ahwb_ = ahwb_usage_track_.contains(ahwb_tracking_key_); + // Keep flag value if it was set previously. + use_ahwb_ = use_ahwb_ || ahwb_usage_track_.contains(ahwb_tracking_key_); } #else // MEDIAPIPE_TENSOR_USE_AHWB diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc index a6ca00949..e2ad869f9 100644 --- a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc +++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc @@ -92,9 +92,14 @@ class TensorAhwbGpuTest : public mediapipe::GpuTestBase { }; TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) { - Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb); constexpr size_t num_elements = 20; Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; + { + // Request Ahwb first to get Ahwb storage allocated internally. 
+ auto view = tensor.GetAHardwareBufferWriteView(); + EXPECT_NE(view.handle(), nullptr); + view.SetWritingFinishedFD(-1, [](bool) { return true; }); + } RunInGlContext([&tensor] { auto ssbo_view = tensor.GetOpenGlBufferWriteView(); auto ssbo_name = ssbo_view.name(); @@ -114,9 +119,14 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) { } TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) { - Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb); constexpr size_t num_elements = 20; Tensor tensor{Tensor::ElementType::kFloat16, Tensor::Shape({num_elements})}; + { + // Request Ahwb first to get Ahwb storage allocated internally. + auto view = tensor.GetAHardwareBufferWriteView(); + EXPECT_NE(view.handle(), nullptr); + view.SetReadingFinishedFunc([](bool) { return true; }); + } RunInGlContext([&tensor] { auto ssbo_view = tensor.GetOpenGlBufferWriteView(); auto ssbo_name = ssbo_view.name(); @@ -139,7 +149,6 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) { TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) { // Request the CPU view to get the memory to be allocated. // Request Ahwb view then to transform the storage into Ahwb. - Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault); constexpr size_t num_elements = 20; Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; { @@ -168,7 +177,6 @@ TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) { TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) { // Request the GPU view to get the ssbo allocated internally. // Request Ahwb view then to transform the storage into Ahwb. 
- Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault); constexpr size_t num_elements = 20; Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; RunInGlContext([&tensor] { diff --git a/mediapipe/framework/formats/tensor_ahwb_test.cc b/mediapipe/framework/formats/tensor_ahwb_test.cc index 7ab5a4925..3da6ca8d3 100644 --- a/mediapipe/framework/formats/tensor_ahwb_test.cc +++ b/mediapipe/framework/formats/tensor_ahwb_test.cc @@ -1,34 +1,28 @@ #include "mediapipe/framework/formats/tensor.h" -#include "mediapipe/gpu/gpu_test_base.h" #include "testing/base/public/gmock.h" #include "testing/base/public/gunit.h" -#ifdef MEDIAPIPE_TENSOR_USE_AHWB -#if !MEDIAPIPE_DISABLE_GPU - namespace mediapipe { -class TensorAhwbTest : public mediapipe::GpuTestBase { - public: -}; - -TEST_F(TensorAhwbTest, TestCpuThenAHWB) { +TEST(TensorAhwbTest, TestCpuThenAHWB) { Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1}); { auto ptr = tensor.GetCpuWriteView().buffer(); EXPECT_NE(ptr, nullptr); } { - auto ahwb = tensor.GetAHardwareBufferReadView().handle(); - EXPECT_NE(ahwb, nullptr); + auto view = tensor.GetAHardwareBufferReadView(); + EXPECT_NE(view.handle(), nullptr); + view.SetReadingFinishedFunc([](bool) { return true; }); } } -TEST_F(TensorAhwbTest, TestAHWBThenCpu) { +TEST(TensorAhwbTest, TestAHWBThenCpu) { Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1}); { - auto ahwb = tensor.GetAHardwareBufferWriteView().handle(); - EXPECT_NE(ahwb, nullptr); + auto view = tensor.GetAHardwareBufferWriteView(); + EXPECT_NE(view.handle(), nullptr); + view.SetWritingFinishedFD(-1, [](bool) { return true; }); } { auto ptr = tensor.GetCpuReadView().buffer(); @@ -36,21 +30,71 @@ TEST_F(TensorAhwbTest, TestAHWBThenCpu) { } } -TEST_F(TensorAhwbTest, TestCpuThenGl) { - RunInGlContext([] { - Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1}); +TEST(TensorAhwbTest, TestAhwbAlignment) { + Tensor tensor(Tensor::ElementType::kFloat32, 
Tensor::Shape{5}); + { + auto view = tensor.GetAHardwareBufferWriteView(16); + EXPECT_NE(view.handle(), nullptr); + if (__builtin_available(android 26, *)) { + AHardwareBuffer_Desc desc; + AHardwareBuffer_describe(view.handle(), &desc); + // sizeof(float) * 5 = 20, the closest aligned to 16 size is 32. + EXPECT_EQ(desc.width, 32); + } + view.SetWritingFinishedFD(-1, [](bool) { return true; }); + } +} + +// Tensor::GetCpuView uses source location mechanism that gives source file name +// and line from where the method is called. The function is intended just to +// have two calls providing the same source file name and line. +auto GetCpuView(const Tensor &tensor) { return tensor.GetCpuWriteView(); } + +// The test checks the tracking mechanism: when a tensor's Cpu view is retrieved +// for the first time then the source location is attached to the tensor. If the +// Ahwb view is requested then from the tensor then the previously recorded Cpu +// view request source location is marked for using Ahwb storage. +// When a Cpu view with the same source location (but for the newly allocated +// tensor) is requested and the location is marked to use Ahwb storage then the +// Ahwb storage is allocated for the CpuView. +TEST(TensorAhwbTest, TestTrackingAhwb) { + // Create first tensor and request Cpu and then Ahwb view to mark the source + // location for Ahwb storage. + { + Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{9}); { - auto ptr = tensor.GetCpuWriteView().buffer(); - EXPECT_NE(ptr, nullptr); + auto view = GetCpuView(tensor); + EXPECT_NE(view.buffer(), nullptr); } { - auto ssbo = tensor.GetOpenGlBufferReadView().name(); - EXPECT_GT(ssbo, 0); + // Align size of the Ahwb by multiple of 16. 
+ auto view = tensor.GetAHardwareBufferWriteView(16); + EXPECT_NE(view.handle(), nullptr); + view.SetReadingFinishedFunc([](bool) { return true; }); } - }); + } + { + Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{9}); + { + // The second tensor uses the same Cpu view source location so Ahwb + // storage is allocated internally. + auto view = GetCpuView(tensor); + EXPECT_NE(view.buffer(), nullptr); + } + { + // Check the Ahwb size to be aligned to multiple of 16. The alignment is + // stored by previous requesting of the Ahwb view. + auto view = tensor.GetAHardwareBufferReadView(); + EXPECT_NE(view.handle(), nullptr); + if (__builtin_available(android 26, *)) { + AHardwareBuffer_Desc desc; + AHardwareBuffer_describe(view.handle(), &desc); + // sizeof(float) * 9 = 36. The closest aligned size is 48. + EXPECT_EQ(desc.width, 48); + } + view.SetReadingFinishedFunc([](bool) { return true; }); + } + } } } // namespace mediapipe - -#endif // !MEDIAPIPE_DISABLE_GPU -#endif // MEDIAPIPE_TENSOR_USE_AHWB diff --git a/mediapipe/framework/profiler/graph_profiler.cc b/mediapipe/framework/profiler/graph_profiler.cc index f14acfc78..6aead5250 100644 --- a/mediapipe/framework/profiler/graph_profiler.cc +++ b/mediapipe/framework/profiler/graph_profiler.cc @@ -194,6 +194,7 @@ void GraphProfiler::Initialize( "Calculator \"$0\" has already been added.", node_name); } profile_builder_ = std::make_unique(this); + graph_id_ = ++next_instance_id_; is_initialized_ = true; } diff --git a/mediapipe/framework/profiler/graph_profiler.h b/mediapipe/framework/profiler/graph_profiler.h index 23caed4ec..6358cb057 100644 --- a/mediapipe/framework/profiler/graph_profiler.h +++ b/mediapipe/framework/profiler/graph_profiler.h @@ -237,6 +237,9 @@ class GraphProfiler : public std::enable_shared_from_this { return validated_graph_; } + // Gets a numerical identifier for this GraphProfiler object. 
+ uint64_t GetGraphId() { return graph_id_; } + private: // This can be used to add packet info for the input streams to the graph. // It treats the stream defined by |stream_name| as a stream produced by a @@ -357,6 +360,12 @@ class GraphProfiler : public std::enable_shared_from_this { class GraphProfileBuilder; std::unique_ptr profile_builder_; + // The globally incrementing identifier for all graphs in a process. + static inline std::atomic_int next_instance_id_ = 0; + + // A unique identifier for this object. Only unique within a process. + uint64_t graph_id_; + // For testing. friend GraphProfilerTestPeer; }; diff --git a/mediapipe/framework/profiler/graph_profiler_test.cc b/mediapipe/framework/profiler/graph_profiler_test.cc index 81ba90cda..75d1c7ebd 100644 --- a/mediapipe/framework/profiler/graph_profiler_test.cc +++ b/mediapipe/framework/profiler/graph_profiler_test.cc @@ -442,6 +442,32 @@ TEST_F(GraphProfilerTestPeer, InitializeMultipleTimes) { "Cannot initialize .* multiple times."); } +// Tests that graph identifiers are not reused, even after destruction. +TEST_F(GraphProfilerTestPeer, InitializeMultipleProfilers) { + auto raw_graph_config = R"( + profiler_config { + enable_profiler: true + } + input_stream: "input_stream" + node { + calculator: "DummyTestCalculator" + input_stream: "input_stream" + })"; + const int n_iterations = 100; + absl::flat_hash_set seen_ids; + for (int i = 0; i < n_iterations; ++i) { + std::shared_ptr profiler = + std::make_shared(); + auto graph_config = CreateGraphConfig(raw_graph_config); + mediapipe::ValidatedGraphConfig validated_graph; + QCHECK_OK(validated_graph.Initialize(graph_config)); + profiler->Initialize(validated_graph); + + int id = profiler->GetGraphId(); + ASSERT_THAT(seen_ids, testing::Not(testing::Contains(id))); + seen_ids.insert(id); + } +} // Tests that Pause(), Resume(), and Reset() works. 
TEST_F(GraphProfilerTestPeer, PauseResumeReset) { InitializeProfilerWithGraphConfig(R"( diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc index 014cc1c69..7cac32b7f 100644 --- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc +++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc @@ -74,42 +74,51 @@ GlTextureView GpuBufferStorageCvPixelBuffer::GetReadView( static void ViewDoneWritingSimulatorWorkaround(CVPixelBufferRef pixel_buffer, const GlTextureView& view) { CHECK(pixel_buffer); - CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0); - CHECK(err == kCVReturnSuccess) - << "CVPixelBufferLockBaseAddress failed: " << err; - OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer); - size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer); - uint8_t* pixel_ptr = - static_cast(CVPixelBufferGetBaseAddress(pixel_buffer)); - if (pixel_format == kCVPixelFormatType_32BGRA) { - // TODO: restore previous framebuffer? Move this to helper so we - // can use BindFramebuffer? 
- glViewport(0, 0, view.width(), view.height()); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), - view.name(), 0); + auto ctx = GlContext::GetCurrent().get(); + if (!ctx) ctx = view.gl_context(); + ctx->Run([pixel_buffer, &view, ctx] { + CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0); + CHECK(err == kCVReturnSuccess) + << "CVPixelBufferLockBaseAddress failed: " << err; + OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer); + size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer); + uint8_t* pixel_ptr = + static_cast(CVPixelBufferGetBaseAddress(pixel_buffer)); + if (pixel_format == kCVPixelFormatType_32BGRA) { + glBindFramebuffer(GL_FRAMEBUFFER, kUtilityFramebuffer.Get(*ctx)); + glViewport(0, 0, view.width(), view.height()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, + view.target(), view.name(), 0); - size_t contiguous_bytes_per_row = view.width() * 4; - if (bytes_per_row == contiguous_bytes_per_row) { - glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE, - pixel_ptr); - } else { - std::vector contiguous_buffer(contiguous_bytes_per_row * - view.height()); - uint8_t* temp_ptr = contiguous_buffer.data(); - glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE, - temp_ptr); - for (int i = 0; i < view.height(); ++i) { - memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row); - temp_ptr += contiguous_bytes_per_row; - pixel_ptr += bytes_per_row; + size_t contiguous_bytes_per_row = view.width() * 4; + if (bytes_per_row == contiguous_bytes_per_row) { + glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, + GL_UNSIGNED_BYTE, pixel_ptr); + } else { + // TODO: use GL_PACK settings for row length. We can expect + // GLES 3.0 on iOS now. 
+ std::vector contiguous_buffer(contiguous_bytes_per_row * + view.height()); + uint8_t* temp_ptr = contiguous_buffer.data(); + glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, + GL_UNSIGNED_BYTE, temp_ptr); + for (int i = 0; i < view.height(); ++i) { + memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row); + temp_ptr += contiguous_bytes_per_row; + pixel_ptr += bytes_per_row; + } } + // TODO: restore previous framebuffer? + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, + view.target(), 0, 0); + glBindFramebuffer(GL_FRAMEBUFFER, 0); + } else { + LOG(ERROR) << "unsupported pixel format: " << pixel_format; } - } else { - LOG(ERROR) << "unsupported pixel format: " << pixel_format; - } - err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0); - CHECK(err == kCVReturnSuccess) - << "CVPixelBufferUnlockBaseAddress failed: " << err; + err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0); + CHECK(err == kCVReturnSuccess) + << "CVPixelBufferUnlockBaseAddress failed: " << err; + }); } #endif // TARGET_IPHONE_SIMULATOR diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py index 7a30d19fd..eb4443b44 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py @@ -71,9 +71,12 @@ class TextClassifierTest(tf.test.TestCase): self.assertTrue(os.path.exists(output_metadata_file)) self.assertGreater(os.path.getsize(output_metadata_file), 0) + filecmp.clear_cache() self.assertTrue( - filecmp.cmp(output_metadata_file, - self._AVERAGE_WORD_EMBEDDING_JSON_FILE)) + filecmp.cmp( + output_metadata_file, + self._AVERAGE_WORD_EMBEDDING_JSON_FILE, + shallow=False)) def test_create_and_train_bert(self): train_data, validation_data = self._get_data() diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py 
b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py index 6ca21d334..afda8643b 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py @@ -135,7 +135,10 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): self.assertTrue(os.path.exists(output_metadata_file)) self.assertGreater(os.path.getsize(output_metadata_file), 0) - self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file)) + filecmp.clear_cache() + self.assertTrue( + filecmp.cmp( + output_metadata_file, expected_metadata_file, shallow=False)) def test_continual_training_by_loading_checkpoint(self): mock_stdout = io.StringIO() diff --git a/mediapipe/objc/MPPGraph.mm b/mediapipe/objc/MPPGraph.mm index 1bd177e80..3123eb863 100644 --- a/mediapipe/objc/MPPGraph.mm +++ b/mediapipe/objc/MPPGraph.mm @@ -230,16 +230,17 @@ if ([wrapper.delegate } - (absl::Status)performStart { - absl::Status status = _graph->Initialize(_config); - if (!status.ok()) { - return status; - } + absl::Status status; for (const auto& service_packet : _servicePackets) { status = _graph->SetServicePacket(*service_packet.first, service_packet.second); if (!status.ok()) { return status; } } + status = _graph->Initialize(_config); + if (!status.ok()) { + return status; + } status = _graph->StartRun(_inputSidePackets, _streamHeaders); if (!status.ok()) { return status; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc index 01f444742..91a5ec213 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc @@ -151,11 +151,11 @@ ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) { auto custom_gestures_classifier_options_proto = std::make_unique( 
components::processors::ConvertClassifierOptionsToProto( - &(options->canned_gestures_classifier_options))); + &(options->custom_gestures_classifier_options))); hand_gesture_recognizer_graph_options ->mutable_custom_gesture_classifier_graph_options() ->mutable_classifier_options() - ->Swap(canned_gestures_classifier_options_proto.get()); + ->Swap(custom_gestures_classifier_options_proto.get()); return options_proto; } diff --git a/mediapipe/tasks/ios/common/utils/BUILD b/mediapipe/tasks/ios/common/utils/BUILD index f2ffda39e..a29c700da 100644 --- a/mediapipe/tasks/ios/common/utils/BUILD +++ b/mediapipe/tasks/ios/common/utils/BUILD @@ -38,4 +38,3 @@ objc_library( "-std=c++17", ], ) - diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h index 407d87aba..5404a074d 100644 --- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h @@ -13,6 +13,7 @@ // limitations under the License. #import + #include "mediapipe/tasks/cc/common.h" NS_ASSUME_NONNULL_BEGIN @@ -56,6 +57,7 @@ extern NSString *const MPPTasksErrorDomain; * @param status absl::Status. * @param error Pointer to the memory location where the created error should be saved. If `nil`, * no error will be saved. + * @return YES when there is no error, NO otherwise. */ + (BOOL)checkCppError:(const absl::Status &)status toError:(NSError **)error; diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm index 4d4880a87..8234ac6d3 100644 --- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm @@ -20,7 +20,6 @@ #include "absl/status/status.h" // from @com_google_absl #include "absl/strings/cord.h" // from @com_google_absl - #include "mediapipe/tasks/cc/common.h" /** Error domain of MediaPipe task library errors. 
*/ @@ -96,8 +95,8 @@ NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks"; // appropriate MPPTasksErrorCode in default cases. Note: // The mapping to absl::Status::code() is done to generate a more specific error code than // MPPTasksErrorCodeError in cases when the payload can't be mapped to - // MPPTasksErrorCode. This can happen when absl::Status returned by TFLite library are in turn returned - // without modification by Mediapipe cc library methods. + // MPPTasksErrorCode. This can happen when absl::Status returned by TFLite library are in turn + // returned without modification by Mediapipe cc library methods. if (errorCode > MPPTasksErrorCodeLast || errorCode <= MPPTasksErrorCodeFirst) { switch (status.code()) { case absl::StatusCode::kInternal: diff --git a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h index aac7485da..66f9c5ccc 100644 --- a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h +++ b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h @@ -13,13 +13,14 @@ // limitations under the License. 
#import + #include NS_ASSUME_NONNULL_BEGIN @interface NSString (Helpers) -@property(readonly) std::string cppString; +@property(readonly, nonatomic) std::string cppString; + (NSString *)stringWithCppString:(std::string)text; diff --git a/mediapipe/tasks/ios/components/processors/BUILD b/mediapipe/tasks/ios/components/processors/BUILD index 6d1cfdf59..165145076 100644 --- a/mediapipe/tasks/ios/components/processors/BUILD +++ b/mediapipe/tasks/ios/components/processors/BUILD @@ -21,4 +21,3 @@ objc_library( srcs = ["sources/MPPClassifierOptions.m"], hdrs = ["sources/MPPClassifierOptions.h"], ) - diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h index 7bf5744f7..13dca4030 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h @@ -22,29 +22,34 @@ NS_ASSUME_NONNULL_BEGIN NS_SWIFT_NAME(ClassifierOptions) @interface MPPClassifierOptions : NSObject -/** The locale to use for display names specified through the TFLite Model - * Metadata, if any. Defaults to English. +/** + * The locale to use for display names specified through the TFLite Model + * Metadata, if any. Defaults to English. */ @property(nonatomic, copy) NSString *displayNamesLocale; -/** The maximum number of top-scored classification results to return. If < 0, +/** + * The maximum number of top-scored classification results to return. If < 0, * all available results will be returned. If 0, an invalid argument error is - * returned. + * returned. */ @property(nonatomic) NSInteger maxResults; -/** Score threshold to override the one provided in the model metadata (if any). - * Results below this value are rejected. +/** + * Score threshold to override the one provided in the model metadata (if any). + * Results below this value are rejected. 
*/ @property(nonatomic) float scoreThreshold; -/** The allowlist of category names. If non-empty, detection results whose +/** + * The allowlist of category names. If non-empty, detection results whose * category name is not in this set will be filtered out. Duplicate or unknown * category names are ignored. Mutually exclusive with categoryDenylist. */ @property(nonatomic, copy) NSArray *categoryAllowlist; -/** The denylist of category names. If non-empty, detection results whose +/** + * The denylist of category names. If non-empty, detection results whose * category name is in this set will be filtered out. Duplicate or unknown * category names are ignored. Mutually exclusive with categoryAllowlist. */ diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m index accb6c7dd..01f498184 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m @@ -19,8 +19,8 @@ - (instancetype)init { self = [super init]; if (self) { - self.maxResults = -1; - self.scoreThreshold = 0; + _maxResults = -1; + _scoreThreshold = 0; } return self; } diff --git a/mediapipe/tasks/ios/components/processors/utils/BUILD b/mediapipe/tasks/ios/components/processors/utils/BUILD index 820c6bb56..5344c5fdf 100644 --- a/mediapipe/tasks/ios/components/processors/utils/BUILD +++ b/mediapipe/tasks/ios/components/processors/utils/BUILD @@ -21,9 +21,8 @@ objc_library( srcs = ["sources/MPPClassifierOptions+Helpers.mm"], hdrs = ["sources/MPPClassifierOptions+Helpers.h"], deps = [ - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/ios/components/processors:MPPClassifierOptions", - "//mediapipe/tasks/ios/common/utils:NSStringHelpers", - ] + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + 
"//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/components/processors:MPPClassifierOptions", + ], ) - diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h index 6644a6255..e156020df 100644 --- a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h +++ b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h @@ -13,6 +13,7 @@ // limitations under the License. #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" + #import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h" NS_ASSUME_NONNULL_BEGIN diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm index efe9572e1..24b54fd6a 100644 --- a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm +++ b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm @@ -23,13 +23,12 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto - (void)copyToProto:(ClassifierOptionsProto *)classifierOptionsProto { classifierOptionsProto->Clear(); - + if (self.displayNamesLocale) { classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString); } classifierOptionsProto->set_max_results((int)self.maxResults); - classifierOptionsProto->set_score_threshold(self.scoreThreshold); for (NSString *category in self.categoryAllowlist) { diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD index adc37d901..434d20085 100644 --- a/mediapipe/tasks/ios/core/BUILD +++ b/mediapipe/tasks/ios/core/BUILD @@ -54,14 +54,14 @@ objc_library( "-std=c++17", ], deps = [ - 
"//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", - "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", ":MPPTaskOptions", ":MPPTaskOptionsProtocol", + "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/tasks/ios/common:MPPCommon", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", "//mediapipe/tasks/ios/common/utils:NSStringHelpers", - "//mediapipe/tasks/ios/common:MPPCommon", ], ) @@ -83,9 +83,13 @@ objc_library( name = "MPPTaskRunner", srcs = ["sources/MPPTaskRunner.mm"], hdrs = ["sources/MPPTaskRunner.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], deps = [ - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", ], ) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h index ae4c9eba1..b94e704d1 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h @@ -13,7 +13,9 @@ // limitations under the License. #import + #include "mediapipe/framework/calculator.pb.h" + #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" @@ -59,7 +61,7 @@ NS_ASSUME_NONNULL_BEGIN /** * Creates a MediaPipe Task protobuf message from the MPPTaskInfo instance. 
*/ -- (mediapipe::CalculatorGraphConfig)generateGraphConfig; +- (::mediapipe::CalculatorGraphConfig)generateGraphConfig; - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm index be3c8cbf7..5f2290497 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm @@ -24,7 +24,6 @@ namespace { using CalculatorGraphConfig = ::mediapipe::CalculatorGraphConfig; using Node = ::mediapipe::CalculatorGraphConfig::Node; -using ::mediapipe::CalculatorOptions; using ::mediapipe::FlowLimiterCalculatorOptions; using ::mediapipe::InputStreamInfo; } // namespace diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h b/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h index 44fba4c0b..c03165c1d 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h @@ -13,6 +13,7 @@ // limitations under the License. #import + #include "mediapipe/framework/calculator_options.pb.h" NS_ASSUME_NONNULL_BEGIN @@ -25,7 +26,7 @@ NS_ASSUME_NONNULL_BEGIN /** * Copies the iOS MediaPipe task options to an object of mediapipe::CalculatorOptions proto. */ -- (void)copyToProto:(mediapipe::CalculatorOptions *)optionsProto; +- (void)copyToProto:(::mediapipe::CalculatorOptions *)optionsProto; @end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h index 6561e136d..2b9f2ecdb 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h @@ -22,7 +22,6 @@ NS_ASSUME_NONNULL_BEGIN /** * This class is used to create and call appropriate methods on the C++ Task Runner. 
*/ - @interface MPPTaskRunner : NSObject /** @@ -35,11 +34,10 @@ NS_ASSUME_NONNULL_BEGIN - (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig error:(NSError **)error NS_DESIGNATED_INITIALIZER; -- (absl::StatusOr) - process:(const mediapipe::tasks::core::PacketMap &)packetMap - error:(NSError **)error; +- (absl::StatusOr)process: + (const mediapipe::tasks::core::PacketMap &)packetMap; -- (void)close; +- (absl::Status)close; - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm index e08d0bc1b..c5c307fd5 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm @@ -17,7 +17,6 @@ namespace { using ::mediapipe::CalculatorGraphConfig; -using ::mediapipe::Packet; using ::mediapipe::tasks::core::PacketMap; using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; } // namespace @@ -49,8 +48,8 @@ using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; return _cppTaskRunner->Process(packetMap); } -- (void)close { - _cppTaskRunner->Close(); +- (absl::Status)close { + return _cppTaskRunner->Close(); } @end diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 7bfca680a..51573f50a 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -119,11 +119,10 @@ export class AudioClassifier extends AudioTaskRunner { * * @param options The options for the audio classifier. 
*/ - override async setOptions(options: AudioClassifierOptions): Promise { - await super.setOptions(options); + override setOptions(options: AudioClassifierOptions): Promise { this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -171,7 +170,7 @@ export class AudioClassifier extends AudioTaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(AUDIO_STREAM); graphConfig.addInputStream(SAMPLE_RATE_STREAM); diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts index d5c0a9429..2089f184f 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts @@ -79,7 +79,8 @@ describe('AudioClassifier', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); audioClassifier = new AudioClassifierFake(); - await audioClassifier.setOptions({}); // Initialize graph + await audioClassifier.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index 246cba883..6a4b8ce39 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -121,11 +121,10 @@ export class AudioEmbedder extends AudioTaskRunner { * * @param options The options for the audio embedder. 
*/ - override async setOptions(options: AudioEmbedderOptions): Promise { - await super.setOptions(options); + override setOptions(options: AudioEmbedderOptions): Promise { this.options.setEmbedderOptions(convertEmbedderOptionsToProto( options, this.options.getEmbedderOptions())); - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -171,7 +170,7 @@ export class AudioEmbedder extends AudioTaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(AUDIO_STREAM); graphConfig.addInputStream(SAMPLE_RATE_STREAM); diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts index 2f605ff98..dde61a6e9 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts @@ -70,7 +70,8 @@ describe('AudioEmbedder', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); audioEmbedder = new AudioEmbedderFake(); - await audioEmbedder.setOptions({}); // Initialize graph + await audioEmbedder.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', () => { diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD index 148a08238..cab24293d 100644 --- a/mediapipe/tasks/web/components/processors/BUILD +++ b/mediapipe/tasks/web/components/processors/BUILD @@ -103,29 +103,3 @@ jasmine_node_test( name = "embedder_options_test", deps = [":embedder_options_test_lib"], ) - -mediapipe_ts_library( - name = "base_options", - srcs = [ - "base_options.ts", - ], - deps = [ - "//mediapipe/calculators/tensor:inference_calculator_jspb_proto", - "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto", - 
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", - "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", - "//mediapipe/tasks/web/core", - ], -) - -mediapipe_ts_library( - name = "base_options_test_lib", - testonly = True, - srcs = ["base_options.test.ts"], - deps = [":base_options"], -) - -jasmine_node_test( - name = "base_options_test", - deps = [":base_options_test_lib"], -) diff --git a/mediapipe/tasks/web/components/processors/base_options.test.ts b/mediapipe/tasks/web/components/processors/base_options.test.ts deleted file mode 100644 index 6d58be68f..000000000 --- a/mediapipe/tasks/web/components/processors/base_options.test.ts +++ /dev/null @@ -1,127 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -import 'jasmine'; - -// Placeholder for internal dependency on encodeByteArray -// Placeholder for internal dependency on trusted resource URL builder - -import {convertBaseOptionsToProto} from './base_options'; - -describe('convertBaseOptionsToProto()', () => { - const mockBytes = new Uint8Array([0, 1, 2, 3]); - const mockBytesResult = { - modelAsset: { - fileContent: Buffer.from(mockBytes).toString('base64'), - fileName: undefined, - fileDescriptorMeta: undefined, - filePointerMeta: undefined, - }, - useStreamMode: false, - acceleration: { - xnnpack: undefined, - gpu: undefined, - tflite: {}, - }, - }; - - let fetchSpy: jasmine.Spy; - - beforeEach(() => { - fetchSpy = jasmine.createSpy().and.callFake(async url => { - expect(url).toEqual('foo'); - return { - arrayBuffer: () => mockBytes.buffer, - } as unknown as Response; - }); - global.fetch = fetchSpy; - }); - - it('verifies that at least one model asset option is provided', async () => { - await expectAsync(convertBaseOptionsToProto({})) - .toBeRejectedWithError( - /Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/); - }); - - it('verifies that no more than one model asset option is provided', async () => { - await expectAsync(convertBaseOptionsToProto({ - modelAssetPath: `foo`, - modelAssetBuffer: new Uint8Array([]) - })) - .toBeRejectedWithError( - /Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/); - }); - - it('downloads model', async () => { - const baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetPath: `foo`, - }); - - expect(fetchSpy).toHaveBeenCalled(); - expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); - }); - - it('does not download model when bytes are provided', async () => { - const baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetBuffer: new Uint8Array(mockBytes), - }); - - expect(fetchSpy).not.toHaveBeenCalled(); - expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); - }); - - 
it('can enable CPU delegate', async () => { - const baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'CPU', - }); - expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); - }); - - it('can enable GPU delegate', async () => { - const baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'GPU', - }); - expect(baseOptionsProto.toObject()).toEqual({ - ...mockBytesResult, - acceleration: { - xnnpack: undefined, - gpu: { - useAdvancedGpuApi: false, - api: 0, - allowPrecisionLoss: true, - cachedKernelPath: undefined, - serializedModelDir: undefined, - modelToken: undefined, - usage: 2, - }, - tflite: undefined, - }, - }); - }); - - it('can reset delegate', async () => { - let baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'GPU', - }); - // Clear backend - baseOptionsProto = - await convertBaseOptionsToProto({delegate: undefined}, baseOptionsProto); - expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); - }); -}); diff --git a/mediapipe/tasks/web/components/processors/base_options.ts b/mediapipe/tasks/web/components/processors/base_options.ts deleted file mode 100644 index 97b62b784..000000000 --- a/mediapipe/tasks/web/components/processors/base_options.ts +++ /dev/null @@ -1,80 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb'; -import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb'; -import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; -import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb'; -import {BaseOptions} from '../../../../tasks/web/core/task_runner_options'; - -// The OSS JS API does not support the builder pattern. -// tslint:disable:jspb-use-builder-pattern - -/** - * Converts a BaseOptions API object to its Protobuf representation. - * @throws If neither a model assset path or buffer is provided - */ -export async function convertBaseOptionsToProto( - updatedOptions: BaseOptions, - currentOptions?: BaseOptionsProto): Promise { - const result = - currentOptions ? currentOptions.clone() : new BaseOptionsProto(); - - await configureExternalFile(updatedOptions, result); - configureAcceleration(updatedOptions, result); - - return result; -} - -/** - * Configues the `externalFile` option and validates that a single model is - * provided. 
- */ -async function configureExternalFile( - options: BaseOptions, proto: BaseOptionsProto) { - const externalFile = proto.getModelAsset() || new ExternalFile(); - proto.setModelAsset(externalFile); - - if (options.modelAssetPath || options.modelAssetBuffer) { - if (options.modelAssetPath && options.modelAssetBuffer) { - throw new Error( - 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); - } - - let modelAssetBuffer = options.modelAssetBuffer; - if (!modelAssetBuffer) { - const response = await fetch(options.modelAssetPath!.toString()); - modelAssetBuffer = new Uint8Array(await response.arrayBuffer()); - } - externalFile.setFileContent(modelAssetBuffer); - } - - if (!externalFile.hasFileContent()) { - throw new Error( - 'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set'); - } -} - -/** Configues the `acceleration` option. */ -function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) { - const acceleration = proto.getAcceleration() ?? 
new Acceleration(); - if (options.delegate === 'GPU') { - acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu()); - } else { - acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite()); - } - proto.setAcceleration(acceleration); -} diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index 1721661f5..c0d10d28b 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -18,8 +18,10 @@ mediapipe_ts_library( srcs = ["task_runner.ts"], deps = [ ":core", + "//mediapipe/calculators/tensor:inference_calculator_jspb_proto", + "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", - "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", @@ -53,6 +55,7 @@ mediapipe_ts_library( "task_runner_test.ts", ], deps = [ + ":core", ":task_runner", ":task_runner_test_utils", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index 2011fadef..ffb538b52 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -14,9 +14,11 @@ * limitations under the License. 
*/ +import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb'; +import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb'; import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; -import {convertBaseOptionsToProto} from '../../../tasks/web/components/processors/base_options'; -import {TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options'; +import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb'; +import {BaseOptions, TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options'; import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner'; import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib'; import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; @@ -91,14 +93,52 @@ export abstract class TaskRunner { this.graphRunner.registerModelResourcesGraphService(); } - /** Configures the shared options of a MediaPipe Task. */ - async setOptions(options: TaskRunnerOptions): Promise { - if (options.baseOptions) { - this.baseOptions = await convertBaseOptionsToProto( - options.baseOptions, this.baseOptions); + /** Configures the task with custom options. */ + abstract setOptions(options: TaskRunnerOptions): Promise; + + /** + * Applies the current set of options, including any base options that have + * not been processed by the task implementation. The options are applied + * synchronously unless a `modelAssetPath` is provided. This ensures that + * for most use cases options are applied directly and immediately affect + * the next inference. 
+ */ + protected applyOptions(options: TaskRunnerOptions): Promise { + const baseOptions: BaseOptions = options.baseOptions || {}; + + // Validate that exactly one model is configured + if (options.baseOptions?.modelAssetBuffer && + options.baseOptions?.modelAssetPath) { + throw new Error( + 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); + } else if (!(this.baseOptions.getModelAsset()?.hasFileContent() || + options.baseOptions?.modelAssetBuffer || + options.baseOptions?.modelAssetPath)) { + throw new Error( + 'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set'); + } + + this.setAcceleration(baseOptions); + if (baseOptions.modelAssetPath) { + // We don't use `await` here since we want to apply most settings + // synchronously. + return fetch(baseOptions.modelAssetPath.toString()) + .then(response => response.arrayBuffer()) + .then(buffer => { + this.setExternalFile(new Uint8Array(buffer)); + this.refreshGraph(); + }); + } else { + // Apply the setting synchronously. + this.setExternalFile(baseOptions.modelAssetBuffer); + this.refreshGraph(); + return Promise.resolve(); } } + /** Appliest the current options to the MediaPipe graph. */ + protected abstract refreshGraph(): void; + /** * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run * over the video stream. Will replace the previously running MediaPipe graph, @@ -140,6 +180,27 @@ export abstract class TaskRunner { } this.processingErrors = []; } + + /** Configures the `externalFile` option */ + private setExternalFile(modelAssetBuffer?: Uint8Array): void { + const externalFile = this.baseOptions.getModelAsset() || new ExternalFile(); + if (modelAssetBuffer) { + externalFile.setFileContent(modelAssetBuffer); + } + this.baseOptions.setModelAsset(externalFile); + } + + /** Configures the `acceleration` option. */ + private setAcceleration(options: BaseOptions) { + const acceleration = + this.baseOptions.getAcceleration() ?? 
new Acceleration(); + if (options.delegate === 'GPU') { + acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu()); + } else { + acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite()); + } + this.baseOptions.setAcceleration(acceleration); + } } diff --git a/mediapipe/tasks/web/core/task_runner_test.ts b/mediapipe/tasks/web/core/task_runner_test.ts index c9aad9d25..a55ac04d7 100644 --- a/mediapipe/tasks/web/core/task_runner_test.ts +++ b/mediapipe/tasks/web/core/task_runner_test.ts @@ -15,18 +15,22 @@ */ import 'jasmine'; +// Placeholder for internal dependency on encodeByteArray import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; import {TaskRunner} from '../../../tasks/web/core/task_runner'; import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils'; import {ErrorListener} from '../../../web/graph_runner/graph_runner'; +// Placeholder for internal dependency on trusted resource URL builder import {GraphRunnerImageLib} from './task_runner'; +import {TaskRunnerOptions} from './task_runner_options.d'; class TaskRunnerFake extends TaskRunner { - protected baseOptions = new BaseOptionsProto(); private errorListener: ErrorListener|undefined; private errors: string[] = []; + baseOptions = new BaseOptionsProto(); + static createFake(): TaskRunnerFake { const wasmModule = createSpyWasmModule(); return new TaskRunnerFake(wasmModule); @@ -61,10 +65,16 @@ class TaskRunnerFake extends TaskRunner { super.finishProcessing(); } + override refreshGraph(): void {} + override setGraph(graphData: Uint8Array, isBinary: boolean): void { super.setGraph(graphData, isBinary); } + setOptions(options: TaskRunnerOptions): Promise { + return this.applyOptions(options); + } + private throwErrors(): void { expect(this.errorListener).toBeDefined(); for (const error of this.errors) { @@ -75,8 +85,38 @@ class TaskRunnerFake extends TaskRunner { } describe('TaskRunner', () => { + const 
mockBytes = new Uint8Array([0, 1, 2, 3]); + const mockBytesResult = { + modelAsset: { + fileContent: Buffer.from(mockBytes).toString('base64'), + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined, + }, + useStreamMode: false, + acceleration: { + xnnpack: undefined, + gpu: undefined, + tflite: {}, + }, + }; + + let fetchSpy: jasmine.Spy; + let taskRunner: TaskRunnerFake; + + beforeEach(() => { + fetchSpy = jasmine.createSpy().and.callFake(async url => { + expect(url).toEqual('foo'); + return { + arrayBuffer: () => mockBytes.buffer, + } as unknown as Response; + }); + global.fetch = fetchSpy; + + taskRunner = TaskRunnerFake.createFake(); + }); + it('handles errors during graph update', () => { - const taskRunner = TaskRunnerFake.createFake(); taskRunner.enqueueError('Test error'); expect(() => { @@ -85,7 +125,6 @@ describe('TaskRunner', () => { }); it('handles errors during graph execution', () => { - const taskRunner = TaskRunnerFake.createFake(); taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); taskRunner.enqueueError('Test error'); @@ -96,7 +135,6 @@ describe('TaskRunner', () => { }); it('can handle multiple errors', () => { - const taskRunner = TaskRunnerFake.createFake(); taskRunner.enqueueError('Test error 1'); taskRunner.enqueueError('Test error 2'); @@ -104,4 +142,106 @@ describe('TaskRunner', () => { taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); }).toThrowError(/Test error 1, Test error 2/); }); + + it('verifies that at least one model asset option is provided', () => { + expect(() => { + taskRunner.setOptions({}); + }) + .toThrowError( + /Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/); + }); + + it('verifies that no more than one model asset option is provided', () => { + expect(() => { + taskRunner.setOptions({ + baseOptions: { + modelAssetPath: `foo`, + modelAssetBuffer: new Uint8Array([]) + } + }); + }) + .toThrowError( + /Cannot set both 
baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/); + }); + + it('doesn\'t require model once it is configured', async () => { + await taskRunner.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}}); + expect(() => { + taskRunner.setOptions({}); + }).not.toThrowError(); + }); + + it('downloads model', async () => { + await taskRunner.setOptions( + {baseOptions: {modelAssetPath: `foo`}}); + + expect(fetchSpy).toHaveBeenCalled(); + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + }); + + it('does not download model when bytes are provided', async () => { + await taskRunner.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}}); + + expect(fetchSpy).not.toHaveBeenCalled(); + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + }); + + it('changes model synchronously when bytes are provided', () => { + const resolvedPromise = taskRunner.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}}); + + // Check that the change has been applied even though we do not await the + // above Promise + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + return resolvedPromise; + }); + + it('can enable CPU delegate', async () => { + await taskRunner.setOptions({ + baseOptions: { + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'CPU', + } + }); + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + }); + + it('can enable GPU delegate', async () => { + await taskRunner.setOptions({ + baseOptions: { + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'GPU', + } + }); + expect(taskRunner.baseOptions.toObject()).toEqual({ + ...mockBytesResult, + acceleration: { + xnnpack: undefined, + gpu: { + useAdvancedGpuApi: false, + api: 0, + allowPrecisionLoss: true, + cachedKernelPath: undefined, + serializedModelDir: undefined, + modelToken: undefined, + usage: 2, + }, + tflite: undefined, + }, + }); + }); + + it('can reset delegate', 
async () => { + await taskRunner.setOptions({ + baseOptions: { + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'GPU', + } + }); + // Clear backend + await taskRunner.setOptions({baseOptions: {delegate: undefined}}); + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + }); }); diff --git a/mediapipe/tasks/web/core/task_runner_test_utils.ts b/mediapipe/tasks/web/core/task_runner_test_utils.ts index 2a1161a55..838b3f585 100644 --- a/mediapipe/tasks/web/core/task_runner_test_utils.ts +++ b/mediapipe/tasks/web/core/task_runner_test_utils.ts @@ -44,10 +44,10 @@ export function createSpyWasmModule(): SpyWasmModule { * Sets up our equality testing to use a custom float equality checking function * to avoid incorrect test results due to minor floating point inaccuracies. */ -export function addJasmineCustomFloatEqualityTester() { +export function addJasmineCustomFloatEqualityTester(tolerance = 5e-8) { jasmine.addCustomEqualityTester((a, b) => { // Custom float equality if (a === +a && b === +b && (a !== (a | 0) || b !== (b | 0))) { - return Math.abs(a - b) < 5e-8; + return Math.abs(a - b) < tolerance; } return; }); diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index 62708700a..981438625 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -109,11 +109,10 @@ export class TextClassifier extends TaskRunner { * * @param options The options for the text classifier. 
*/ - override async setOptions(options: TextClassifierOptions): Promise { - await super.setOptions(options); + override setOptions(options: TextClassifierOptions): Promise { this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); - this.refreshGraph(); + return this.applyOptions(options); } protected override get baseOptions(): BaseOptionsProto { @@ -141,7 +140,7 @@ export class TextClassifier extends TaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(INPUT_STREAM); graphConfig.addOutputStream(CLASSIFICATIONS_STREAM); diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts index 841bf8c48..5578362cb 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts @@ -56,7 +56,8 @@ describe('TextClassifier', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); textClassifier = new TextClassifierFake(); - await textClassifier.setOptions({}); // Initialize graph + await textClassifier.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 611233e02..7aa0aa6b9 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -113,11 +113,10 @@ export class TextEmbedder extends TaskRunner { * * @param options The options for the text embedder. 
*/ - override async setOptions(options: TextEmbedderOptions): Promise { - await super.setOptions(options); + override setOptions(options: TextEmbedderOptions): Promise { this.options.setEmbedderOptions(convertEmbedderOptionsToProto( options, this.options.getEmbedderOptions())); - this.refreshGraph(); + return this.applyOptions(options); } protected override get baseOptions(): BaseOptionsProto { @@ -157,7 +156,7 @@ export class TextEmbedder extends TaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(INPUT_STREAM); graphConfig.addOutputStream(EMBEDDINGS_STREAM); diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts index 04a9b371a..2804e4deb 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts @@ -56,7 +56,8 @@ describe('TextEmbedder', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); textEmbedder = new TextEmbedderFake(); - await textEmbedder.setOptions({}); // Initialize graph + await textEmbedder.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index e4ea3036f..03958a819 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -29,6 +29,7 @@ mediapipe_ts_library( testonly = True, srcs = ["vision_task_runner.test.ts"], deps = [ + ":vision_task_options", ":vision_task_runner", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/web/core:task_runner_test_utils", diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts index 
6cc9ea328..d77cc4fed 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts @@ -20,6 +20,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils'; import {ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {VisionTaskOptions} from './vision_task_options'; import {VisionTaskRunner} from './vision_task_runner'; class VisionTaskRunnerFake extends VisionTaskRunner { @@ -31,6 +32,12 @@ class VisionTaskRunnerFake extends VisionTaskRunner { protected override process(): void {} + protected override refreshGraph(): void {} + + override setOptions(options: VisionTaskOptions): Promise { + return this.applyOptions(options); + } + override processImageData(image: ImageSource): void { super.processImageData(image); } @@ -41,32 +48,24 @@ class VisionTaskRunnerFake extends VisionTaskRunner { } describe('VisionTaskRunner', () => { - const streamMode = { - modelAsset: undefined, - useStreamMode: true, - acceleration: undefined, - }; - - const imageMode = { - modelAsset: undefined, - useStreamMode: false, - acceleration: undefined, - }; - let visionTaskRunner: VisionTaskRunnerFake; - beforeEach(() => { + beforeEach(async () => { visionTaskRunner = new VisionTaskRunnerFake(); + await visionTaskRunner.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('can enable image mode', async () => { await visionTaskRunner.setOptions({runningMode: 'image'}); - expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode); + expect(visionTaskRunner.baseOptions.toObject()) + .toEqual(jasmine.objectContaining({useStreamMode: false})); }); it('can enable video mode', async () => { await visionTaskRunner.setOptions({runningMode: 'video'}); - expect(visionTaskRunner.baseOptions.toObject()).toEqual(streamMode); + 
expect(visionTaskRunner.baseOptions.toObject()) + .toEqual(jasmine.objectContaining({useStreamMode: true})); }); it('can clear running mode', async () => { @@ -74,7 +73,8 @@ describe('VisionTaskRunner', () => { // Clear running mode await visionTaskRunner.setOptions({runningMode: undefined}); - expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode); + expect(visionTaskRunner.baseOptions.toObject()) + .toEqual(jasmine.objectContaining({useStreamMode: false})); }); it('cannot process images with video mode', async () => { diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index 3432b521b..952990326 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -22,13 +22,13 @@ import {VisionTaskOptions} from './vision_task_options'; /** Base class for all MediaPipe Vision Tasks. */ export abstract class VisionTaskRunner extends TaskRunner { /** Configures the shared options of a vision task. */ - override async setOptions(options: VisionTaskOptions): Promise { - await super.setOptions(options); + override applyOptions(options: VisionTaskOptions): Promise { if ('runningMode' in options) { const useStreamMode = !!options.runningMode && options.runningMode !== 'image'; this.baseOptions.setUseStreamMode(useStreamMode); } + return super.applyOptions(options); } /** Sends an image packet to the graph and awaits results. */ diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index b6b795076..c77f2c67a 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -169,9 +169,7 @@ export class GestureRecognizer extends * * @param options The options for the gesture recognizer. 
*/ - override async setOptions(options: GestureRecognizerOptions): Promise { - await super.setOptions(options); - + override setOptions(options: GestureRecognizerOptions): Promise { if ('numHands' in options) { this.handDetectorGraphOptions.setNumHands( options.numHands ?? DEFAULT_NUM_HANDS); @@ -221,7 +219,7 @@ export class GestureRecognizer extends ?.clearClassifierOptions(); } - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -265,12 +263,22 @@ export class GestureRecognizer extends NORM_RECT_STREAM, timestamp); this.finishProcessing(); - return { - gestures: this.gestures, - landmarks: this.landmarks, - worldLandmarks: this.worldLandmarks, - handednesses: this.handednesses - }; + if (this.gestures.length === 0) { + // If no gestures are detected in the image, just return an empty list + return { + gestures: [], + landmarks: [], + worldLandmarks: [], + handednesses: [], + }; + } else { + return { + gestures: this.gestures, + landmarks: this.landmarks, + worldLandmarks: this.worldLandmarks, + handednesses: this.handednesses + }; + } } /** Sets the default values for the graph. */ @@ -285,15 +293,19 @@ export class GestureRecognizer extends } /** Converts the proto data to a Category[][] structure. */ - private toJsCategories(data: Uint8Array[]): Category[][] { + private toJsCategories(data: Uint8Array[], populateIndex = true): + Category[][] { const result: Category[][] = []; for (const binaryProto of data) { const inputList = ClassificationList.deserializeBinary(binaryProto); const outputList: Category[] = []; for (const classification of inputList.getClassificationList()) { + const index = populateIndex && classification.hasIndex() ? + classification.getIndex()! : + DEFAULT_CATEGORY_INDEX; outputList.push({ score: classification.getScore() ?? 0, - index: classification.getIndex() ?? DEFAULT_CATEGORY_INDEX, + index, categoryName: classification.getLabel() ?? '', displayName: classification.getDisplayName() ?? 
'', }); @@ -342,7 +354,7 @@ export class GestureRecognizer extends } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(IMAGE_STREAM); graphConfig.addInputStream(NORM_RECT_STREAM); @@ -377,7 +389,10 @@ export class GestureRecognizer extends }); this.graphRunner.attachProtoVectorListener( HAND_GESTURES_STREAM, binaryProto => { - this.gestures.push(...this.toJsCategories(binaryProto)); + // Gesture index is not used, because the final gesture result comes + // from multiple classifiers. + this.gestures.push( + ...this.toJsCategories(binaryProto, /* populateIndex= */ false)); }); this.graphRunner.attachProtoVectorListener( HANDEDNESS_STREAM, binaryProto => { diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts index e570270b2..323290008 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts @@ -17,6 +17,8 @@ import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; +export {Category, Landmark, NormalizedLandmark}; + /** * Represents the gesture recognition results generated by `GestureRecognizer`. */ @@ -30,6 +32,10 @@ export declare interface GestureRecognizerResult { /** Handedness of detected hands. */ handednesses: Category[][]; - /** Recognized hand gestures of detected hands */ + /** + * Recognized hand gestures of detected hands. Note that the index of the + * gesture is always -1, because the raw indices from multiple gesture + * classifiers cannot consolidate to a meaningful index. 
+ */ gestures: Category[][]; } diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts index c0f0d1554..ee51fd32a 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts @@ -109,7 +109,8 @@ describe('GestureRecognizer', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); gestureRecognizer = new GestureRecognizerFake(); - await gestureRecognizer.setOptions({}); // Initialize graph + await gestureRecognizer.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { @@ -271,7 +272,7 @@ describe('GestureRecognizer', () => { expect(gestures).toEqual({ 'gestures': [[{ 'score': 0.2, - 'index': 2, + 'index': -1, 'categoryName': 'gesture_label', 'displayName': 'gesture_display_name' }]], @@ -304,4 +305,25 @@ describe('GestureRecognizer', () => { // gestures. expect(gestures2).toEqual(gestures1); }); + + it('returns empty results when no gestures are detected', async () => { + // Pass the test data to our listener + gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(gestureRecognizer); + gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks()); + gestureRecognizer.listeners.get('world_hand_landmarks')! 
+ (createWorldLandmarks()); + gestureRecognizer.listeners.get('handedness')!(createHandednesses()); + gestureRecognizer.listeners.get('hand_gestures')!([]); + }); + + // Invoke the gesture recognizer + const gestures = gestureRecognizer.recognize({} as HTMLImageElement); + expect(gestures).toEqual({ + 'gestures': [], + 'landmarks': [], + 'worldLandmarks': [], + 'handednesses': [] + }); + }); }); diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 2a0e8286c..24cf9a402 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -150,9 +150,7 @@ export class HandLandmarker extends VisionTaskRunner { * * @param options The options for the hand landmarker. */ - override async setOptions(options: HandLandmarkerOptions): Promise { - await super.setOptions(options); - + override setOptions(options: HandLandmarkerOptions): Promise { // Configure hand detector options. if ('numHands' in options) { this.handDetectorGraphOptions.setNumHands( @@ -173,7 +171,7 @@ export class HandLandmarker extends VisionTaskRunner { options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD); } - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -291,7 +289,7 @@ export class HandLandmarker extends VisionTaskRunner { } /** Updates the MediaPipe graph configuration. 
*/ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(IMAGE_STREAM); graphConfig.addInputStream(NORM_RECT_STREAM); diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts index 89f867d69..8a6d9bfa6 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts @@ -17,6 +17,8 @@ import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; +export {Landmark, NormalizedLandmark, Category}; + /** * Represents the hand landmarks deection results generated by `HandLandmarker`. */ diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts index fc26680e0..76e77b4bf 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts @@ -98,7 +98,8 @@ describe('HandLandmarker', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); handLandmarker = new HandLandmarkerFake(); - await handLandmarker.setOptions({}); // Initialize graph + await handLandmarker.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 36e7311fb..9298a860c 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -118,11 +118,10 @@ export class ImageClassifier extends VisionTaskRunner { * * @param options The 
options for the image classifier. */ - override async setOptions(options: ImageClassifierOptions): Promise { - await super.setOptions(options); + override setOptions(options: ImageClassifierOptions): Promise { this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -163,7 +162,7 @@ export class ImageClassifier extends VisionTaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(INPUT_STREAM); graphConfig.addOutputStream(CLASSIFICATIONS_STREAM); diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts index 2041a0cef..da4a01d02 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts @@ -61,7 +61,8 @@ describe('ImageClassifier', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); imageClassifier = new ImageClassifierFake(); - await imageClassifier.setOptions({}); // Initialize graph + await imageClassifier.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index 0c45ba5e7..cf0bd8c5d 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -120,11 +120,10 @@ export class ImageEmbedder extends VisionTaskRunner { * * @param options The options for the image embedder. 
*/ - override async setOptions(options: ImageEmbedderOptions): Promise { - await super.setOptions(options); + override setOptions(options: ImageEmbedderOptions): Promise { this.options.setEmbedderOptions(convertEmbedderOptionsToProto( options, this.options.getEmbedderOptions())); - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -186,7 +185,7 @@ export class ImageEmbedder extends VisionTaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(INPUT_STREAM); graphConfig.addOutputStream(EMBEDDINGS_STREAM); diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts index cafe0f3d8..b63bb374c 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts @@ -57,7 +57,8 @@ describe('ImageEmbedder', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); imageEmbedder = new ImageEmbedderFake(); - await imageEmbedder.setOptions({}); // Initialize graph + await imageEmbedder.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index fbfaced12..e4c51de08 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -117,9 +117,7 @@ export class ObjectDetector extends VisionTaskRunner { * * @param options The options for the object detector. 
*/ - override async setOptions(options: ObjectDetectorOptions): Promise { - await super.setOptions(options); - + override setOptions(options: ObjectDetectorOptions): Promise { // Note that we have to support both JSPB and ProtobufJS, hence we // have to expliclity clear the values instead of setting them to // `undefined`. @@ -153,7 +151,7 @@ export class ObjectDetector extends VisionTaskRunner { this.options.clearCategoryDenylistList(); } - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -226,7 +224,7 @@ export class ObjectDetector extends VisionTaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(INPUT_STREAM); graphConfig.addOutputStream(DETECTIONS_STREAM); diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts index fff1a1c48..43b7035d5 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts @@ -61,7 +61,8 @@ describe('ObjectDetector', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); objectDetector = new ObjectDetectorFake(); - await objectDetector.setOptions({}); // Initialize graph + await objectDetector.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { From c8ebd21bd5698a9384d79084c2e90bf655aee9b1 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 5 Jan 2023 18:09:29 +0530 Subject: [PATCH 17/18] Updated implementation of iOS Text Classifier --- .../tasks/ios/components/containers/BUILD | 2 +- .../containers/sources/MPPCategory.h | 35 +-- .../containers/sources/MPPCategory.m | 6 +- .../sources/MPPClassificationResult.h | 68 +++-- .../sources/MPPClassificationResult.m | 18 +- 
.../ios/components/containers/utils/BUILD | 2 +- .../utils/sources/MPPCategory+Helpers.h | 26 +- .../utils/sources/MPPCategory+Helpers.mm | 32 +-- .../sources/MPPClassificationResult+Helpers.h | 26 +- .../MPPClassificationResult+Helpers.mm | 37 ++- mediapipe/tasks/ios/core/BUILD | 7 + .../tasks/ios/core/sources/MPPBaseOptions.h | 3 +- .../tasks/ios/core/sources/MPPBaseOptions.m | 4 +- .../ios/core/sources/MPPResultCallback.h | 21 ++ .../tasks/ios/core/sources/MPPTaskResult.h | 4 +- .../tasks/ios/core/sources/MPPTaskResult.m | 6 +- .../tasks/ios/core/sources/MPPTaskRunner.h | 46 ++- .../tasks/ios/core/sources/MPPTaskRunner.mm | 10 +- .../ios/test/{text/text_classifier => }/BUILD | 36 ++- .../tasks/ios/test/MPPTextClassifierTests.m | 110 +++++++ .../tasks/ios/test/TextClassifierTests.swift | 272 ++++++++++++++++++ .../text_classifier/MPPTextClassifierTests.m | 89 ------ mediapipe/tasks/ios/text/core/BUILD | 10 +- .../text/core/sources/MPPBaseTextTaskApi.h | 48 ---- .../text/core/sources/MPPBaseTextTaskApi.mm | 52 ---- .../ios/text/core/sources/MPPTextTaskRunner.h | 37 +++ .../text/core/sources/MPPTextTaskRunner.mm | 29 ++ mediapipe/tasks/ios/text/core/utils/BUILD | 33 --- .../tasks/ios/text/text_classifier/BUILD | 9 +- .../sources/MPPTextClassifier.h | 75 +++-- .../sources/MPPTextClassifier.mm | 61 ++-- .../sources/MPPTextClassifierOptions.h | 46 +-- .../sources/MPPTextClassifierOptions.m | 40 +-- .../sources/MPPTextClassifierResult.h | 20 +- .../sources/MPPTextClassifierResult.m | 6 +- .../ios/text/text_classifier/utils/BUILD | 3 +- .../MPPTextClassifierOptions+Helpers.h | 4 +- .../MPPTextClassifierOptions+Helpers.mm | 2 +- .../sources/MPPTextClassifierResult+Helpers.h | 10 +- .../MPPTextClassifierResult+Helpers.mm | 29 +- 40 files changed, 885 insertions(+), 489 deletions(-) create mode 100644 mediapipe/tasks/ios/core/sources/MPPResultCallback.h rename mediapipe/tasks/ios/test/{text/text_classifier => }/BUILD (56%) create mode 100644 
mediapipe/tasks/ios/test/MPPTextClassifierTests.m create mode 100644 mediapipe/tasks/ios/test/TextClassifierTests.swift delete mode 100644 mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m delete mode 100644 mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h delete mode 100644 mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.mm create mode 100644 mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h create mode 100644 mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.mm delete mode 100644 mediapipe/tasks/ios/text/core/utils/BUILD diff --git a/mediapipe/tasks/ios/components/containers/BUILD b/mediapipe/tasks/ios/components/containers/BUILD index ce80571e9..9d82fc55a 100644 --- a/mediapipe/tasks/ios/components/containers/BUILD +++ b/mediapipe/tasks/ios/components/containers/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h index 431b8a705..035cde09d 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h +++ b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h @@ -1,4 +1,4 @@ -// Copyright 2022 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,41 +16,44 @@ NS_ASSUME_NONNULL_BEGIN -/** Encapsulates information about a class in the classification results. */ +/** Category is a util class, contains a label, its display name, a float value as score, and the + * index of the label in the corresponding label file. 
Typically it's used as the result of + * classification tasks. */ NS_SWIFT_NAME(ClassificationCategory) @interface MPPCategory : NSObject -/** Index of the class in the corresponding label map, usually packed in the TFLite Model - * Metadata. */ +/** The index of the label in the corresponding label file. It takes the value -1 if the index is + * not set. */ @property(nonatomic, readonly) NSInteger index; /** Confidence score for this class . */ @property(nonatomic, readonly) float score; -/** Class name of the class. */ -@property(nonatomic, readonly, nullable) NSString *label; +/** The label of this category object. */ +@property(nonatomic, readonly, nullable) NSString *categoryName; -/** Display name of the class. */ +/** The display name of the label, which may be translated for different locales. For example, a + * label, "apple", may be translated into Spanish for display purpose, so that the display name is + * "manzana". */ @property(nonatomic, readonly, nullable) NSString *displayName; /** - * Initializes a new `TFLCategory` with the given index, score, label and display name. + * Initializes a new `MPPCategory` with the given index, score, category name and display name. * - * @param index Index of the class in the corresponding label map, usually packed in the TFLite - * Model Metadata. + * @param index The index of the label in the corresponding label file. * - * @param score Confidence score for this class. + * @param score The probability score of this label category. * - * @param label Class name of the class. + * @param categoryName The label of this category object.. * - * @param displayName Display name of the class. + * @param displayName The display name of the label. * - * @return An instance of `TFLCategory` initialized with the given index, score, label and display - * name. + * @return An instance of `MPPCategory` initialized with the given index, score, category name and + * display name. 
*/ - (instancetype)initWithIndex:(NSInteger)index score:(float)score - label:(nullable NSString *)label + categoryName:(nullable NSString *)categoryName displayName:(nullable NSString *)displayName; - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m index 20f745582..824fae65e 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m +++ b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m @@ -1,4 +1,4 @@ -// Copyright 2022 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,13 +18,13 @@ - (instancetype)initWithIndex:(NSInteger)index score:(float)score - label:(nullable NSString *)label + categoryName:(nullable NSString *)categoryName displayName:(nullable NSString *)displayName { self = [super init]; if (self) { _index = index; _score = score; - _label = label; + _categoryName = categoryName; _displayName = displayName; } return self; diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h index 24f99bfde..732d1f899 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h @@ -1,4 +1,4 @@ -// Copyright 2022 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,32 +17,27 @@ NS_ASSUME_NONNULL_BEGIN -/** Encapsulates list of predicted classes (aka labels) for a given image classifier head. */ +/** Represents the list of classification for a given classifier head. 
Typically used as a result + * for classification tasks. */ NS_SWIFT_NAME(Classifications) @interface MPPClassifications : NSObject -/** - * The index of the classifier head these classes refer to. This is useful for multi-head - * models. +/** The index of the classifier head these entries refer to. This is useful for multi-head models. */ @property(nonatomic, readonly) NSInteger headIndex; -/** The name of the classifier head, which is the corresponding tensor metadata - * name. - */ -@property(nonatomic, readonly) NSString *headName; +/** The optional name of the classifier head, which is the corresponding tensor metadata name. */ +@property(nonatomic, readonly, nullable) NSString *headName; -/** The array of predicted classes, usually sorted by descending scores (e.g.from high to low - * probability). */ +/** An array of `MPPCategory` objects containing the predicted categories. */ @property(nonatomic, readonly) NSArray *categories; /** - * Initializes a new `MPPClassifications` with the given head index and array of categories. - * head name is initialized to `nil`. + * Initializes a new `MPPClassifications` object with the given head index and array of categories. + * Head name is initialized to `nil`. * - * @param headIndex The index of the image classifier head these classes refer to. - * @param categories An array of `MPPCategory` objects encapsulating a list of - * predictions usually sorted by descending scores (e.g. from high to low probability). + * @param headIndex The index of the classifier head. + * @param categories An array of `MPPCategory` objects containing the predicted categories. * * @return An instance of `MPPClassifications` initialized with the given head index and * array of categories. @@ -54,11 +49,10 @@ NS_SWIFT_NAME(Classifications) * Initializes a new `MPPClassifications` with the given head index, head name and array of * categories. * - * @param headIndex The index of the classifier head these classes refer to. 
+ * @param headIndex The index of the classifier head. * @param headName The name of the classifier head, which is the corresponding tensor metadata * name. - * @param categories An array of `MPPCategory` objects encapsulating a list of - * predictions usually sorted by descending scores (e.g. from high to low probability). + * @param categories An array of `MPPCategory` objects containing the predicted categories. * * @return An object of `MPPClassifications` initialized with the given head index, head name and * array of categories. @@ -69,17 +63,27 @@ NS_SWIFT_NAME(Classifications) @end -/** Encapsulates results of any classification task. */ +/** + * Represents the classification results of a model. Typically used as a result for classification + * tasks. + */ NS_SWIFT_NAME(ClassificationResult) @interface MPPClassificationResult : NSObject -/** Array of MPPClassifications objects containing classifier predictions per image classifier - * head. - */ +/** An Array of `MPPClassifications` objects containing the predicted categories for each head of + * the model. */ @property(nonatomic, readonly) NSArray *classifications; +/** The optional timestamp (in milliseconds) of the start of the chunk of data corresponding to + * these results. If it is set to the value -1, it signifies the absence of a time stamp. This is + * only used for classification on time series (e.g. audio classification). In these use cases, the + * amount of data to process might exceed the maximum size that the model can process: to solve + * this, the input data is split into multiple chunks starting at different timestamps. */ +@property(nonatomic, readonly) NSInteger timestampMs; + /** - * Initializes a new `MPPClassificationResult` with the given array of classifications. + * Initializes a new `MPPClassificationResult` with the given array of classifications. This method + * must be used when no time stamp needs to be specified. It sets the property `timestampMs` to -1. 
* * @param classifications An Aaray of `MPPClassifications` objects containing classifier * predictions per classifier head. @@ -89,6 +93,22 @@ NS_SWIFT_NAME(ClassificationResult) */ - (instancetype)initWithClassifications:(NSArray *)classifications; +/** + * Initializes a new `MPPClassificationResult` with the given array of classifications and time + * stamp (in milliseconds). + * + * @param classifications An Array of `MPPClassifications` objects containing the predicted + * categories for each head of the model. + * + * @param timestampMs The timestamp (in milliseconds) of the start of the chunk of data + * corresponding to these results. + * + * @return An instance of `MPPClassificationResult` initialized with the given array of + * classifications and timestampMs. + */ +- (instancetype)initWithClassifications:(NSArray *)classifications + timestampMs:(NSInteger)timestampMs; + @end NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m index dd9c4e024..6cf75234e 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m @@ -1,4 +1,4 @@ -// Copyright 2022 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -35,16 +35,22 @@ @end -@implementation MPPClassificationResult { - NSArray *_classifications; -} +@implementation MPPClassificationResult -- (instancetype)initWithClassifications:(NSArray *)classifications { - +- (instancetype)initWithClassifications:(NSArray *)classifications + timestampMs:(NSInteger)timestampMs { self = [super init]; if (self) { _classifications = classifications; + _timestampMs = timestampMs; } + + return self; +} + +- (instancetype)initWithClassifications:(NSArray *)classifications { + return [self initWithClassifications:classifications timestampMs:-1]; + return self; } diff --git a/mediapipe/tasks/ios/components/containers/utils/BUILD b/mediapipe/tasks/ios/components/containers/utils/BUILD index a61dd6ca0..e4c76ac4b 100644 --- a/mediapipe/tasks/ios/components/containers/utils/BUILD +++ b/mediapipe/tasks/ios/components/containers/utils/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h index 874c751ac..7580cfeeb 100644 --- a/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ #include "mediapipe/framework/formats/classification.pb.h" #import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h" diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm index 24d250795..f729d9720 100644 --- a/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" #import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h" @@ -22,11 +22,11 @@ using ClassificationProto = ::mediapipe::Classification; @implementation MPPCategory (Helpers) + (MPPCategory *)categoryWithProto:(const ClassificationProto &)clasificationProto { - NSString *label; + NSString *categoryName; NSString *displayName; if (clasificationProto.has_label()) { - label = [NSString stringWithCppString:clasificationProto.label()]; + categoryName = [NSString stringWithCppString:clasificationProto.label()]; } if (clasificationProto.has_display_name()) { @@ -35,7 +35,7 @@ using ClassificationProto = ::mediapipe::Classification; return [[MPPCategory alloc] initWithIndex:clasificationProto.index() score:clasificationProto.score() - label:label + categoryName:categoryName displayName:displayName]; } diff --git 
a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h index 5b19447ac..fde436feb 100644 --- a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- ==============================================================================*/ #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h" diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm index 84d5872d7..9ad284790 100644 --- a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- ==============================================================================*/ #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" #import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h" #import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h" @@ -53,7 +53,16 @@ using ClassificationResultProto = [classifications addObject:[MPPClassifications classificationsWithProto:classifications_proto]]; } - return [[MPPClassificationResult alloc] initWithClassifications:classifications]; + MPPClassificationResult *classificationResult; + + if (classificationResultProto.has_timestamp_ms()) { + classificationResult = [[MPPClassificationResult alloc] initWithClassifications:classifications timestampMs:(NSInteger)classificationResultProto.timestamp_ms()]; + } + else { + classificationResult = [[MPPClassificationResult alloc] initWithClassifications:classifications]; + } + + return classificationResult; } @end diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD index 434d20085..efe481ea8 100644 --- a/mediapipe/tasks/ios/core/BUILD +++ b/mediapipe/tasks/ios/core/BUILD @@ -90,6 +90,13 @@ objc_library( deps = [ "//mediapipe/framework:calculator_cc_proto", "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", ], ) + +objc_library( + name = "MPPResultCallback", + hdrs = ["sources/MPPResultCallback.h"], +) diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h index 9c6595cfc..088d2d5da 100644 --- a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h +++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h @@ -22,7 +22,7 @@ NS_ASSUME_NONNULL_BEGIN typedef NS_ENUM(NSUInteger, MPPDelegate) { /** CPU. */ MPPDelegateCPU, - + /** GPU. 
*/ MPPDelegateGPU } NS_SWIFT_NAME(Delegate); @@ -46,4 +46,3 @@ NS_SWIFT_NAME(BaseOptions) @end NS_ASSUME_NONNULL_END - diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m index b2b027da7..eaf2aa895 100644 --- a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m +++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m @@ -26,10 +26,10 @@ - (id)copyWithZone:(NSZone *)zone { MPPBaseOptions *baseOptions = [[MPPBaseOptions alloc] init]; - + baseOptions.modelAssetPath = self.modelAssetPath; baseOptions.delegate = self.delegate; - + return baseOptions; } diff --git a/mediapipe/tasks/ios/core/sources/MPPResultCallback.h b/mediapipe/tasks/ios/core/sources/MPPResultCallback.h new file mode 100644 index 000000000..908f9edd9 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPResultCallback.h @@ -0,0 +1,21 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import + +NS_ASSUME_NONNULL_BEGIN + +typedef void (^MPPResultCallback)(id output, id input, NSError *error); + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h index d15d4f258..4ee7b2fc6 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h @@ -26,11 +26,11 @@ NS_SWIFT_NAME(TaskResult) /** * Timestamp that is associated with the task result object. */ -@property(nonatomic, assign, readonly) long timestamp; +@property(nonatomic, assign, readonly) NSInteger timestampMs; - (instancetype)init NS_UNAVAILABLE; -- (instancetype)initWithTimestamp:(long)timestamp NS_DESIGNATED_INITIALIZER; +- (instancetype)initWithTimestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER; @end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.m b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m index 7088eb246..6c08014ff 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskResult.m +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m @@ -16,16 +16,16 @@ @implementation MPPTaskResult -- (instancetype)initWithTimestamp:(long)timestamp { +- (instancetype)initWithTimestampMs:(NSInteger)timestampMs { self = [super init]; if (self) { - _timestamp = timestamp; + _timestampMs = timestampMs; } return self; } - (id)copyWithZone:(NSZone *)zone { - return [[MPPTaskResult alloc] initWithTimestamp:self.timestamp]; + return [[MPPTaskResult alloc] initWithTimestampMs:self.timestampMs]; } @end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h index 2b9f2ecdb..97255234f 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h @@ -20,23 +20,63 @@ NS_ASSUME_NONNULL_BEGIN /** - * This class is used to create and call appropriate methods on the C++ Task Runner. 
+ * This class is used to create and call appropriate methods on the C++ Task Runner to initialize, + * execute and terminate any Mediapipe task. + * + * An instance of the newly created C++ task runner will + * be stored until this class is destroyed. When methods are called for processing (performing + * inference), closing etc., on this class, internally the appropriate methods will be called on the + * C++ task runner instance to execute the appropriate actions. For each type of task, a subclass of + * this class must be defined to add any additional functionality. For example, vision tasks must create + * an `MPPVisionTaskRunner` and provide additional functionality. An instance of + * `MPPVisionTaskRunner` can in turn be used by each vision task for creation and execution of + * the task. Please see the documentation for the C++ Task Runner for more details on how the task + * runner operates. */ @interface MPPTaskRunner : NSObject /** - * Initializes a new `MPPTaskRunner` with the mediapipe task graph config proto. + * Initializes a new `MPPTaskRunner` with the mediapipe task graph config proto and an optional C++ + * packets callback. + * + * You can pass `nullptr` for `packetsCallback` in case the mode of operation + * requested by the user is synchronous. + * + * If the task is operating in asynchronous mode, any iOS Mediapipe task that uses the `MPPTaskRunner` + * must define a C++ callback function to obtain the results of inference asynchronously and deliver + * the results to the user. To accomplish this, the callback function will in turn invoke the block + * provided by the user in the task options supplied to create the task. + * Please see the documentation of the C++ Task Runner for more information on the synchronous and + * asynchronous modes of operation. * * @param graphConfig A mediapipe task graph config proto. * - * @return An instance of `MPPTaskRunner` initialized to the given graph config proto. 
+ * @param packetsCallback An optional C++ callback function that takes a list of output packets as + * the input argument. If provided, the callback must in turn call the block provided by the user in + * the appropriate task options. + * + * @return An instance of `MPPTaskRunner` initialized to the given graph config proto and optional + * packetsCallback. */ - (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig + packetsCallback: + (mediapipe::tasks::core::PacketsCallback)packetsCallback error:(NSError **)error NS_DESIGNATED_INITIALIZER; +/** A synchronous method for processing batch data or offline streaming data. This method is +designed for processing either batch data such as unrelated images and texts or offline streaming +data such as the decoded frames from a video file and an audio file. The call blocks the current +thread until a failure status or a successful result is returned. If the input packets have no +timestamp, an internal timestamp will be assigned per invocation. Otherwise, when the timestamp is +set in the input packets, the caller must ensure that the input packet timestamps are greater than +the timestamps of the previous invocation. This method is thread-unsafe and it is the caller's +responsibility to synchronize access to this method across multiple threads and to ensure that the +input packet timestamps are in order.*/ - (absl::StatusOr)process: (const mediapipe::tasks::core::PacketMap &)packetMap; +/** Shuts down the C++ task runner. After the runner is closed, any calls that send input data to + * the runner are illegal and will receive errors. 
*/ - (absl::Status)close; - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm index c5c307fd5..fd3f780fa 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm @@ -13,11 +13,15 @@ // limitations under the License. #import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#include "tensorflow/lite/core/api/op_resolver.h" namespace { using ::mediapipe::CalculatorGraphConfig; +using ::mediapipe::tasks::core::MediaPipeBuiltinOpResolver; using ::mediapipe::tasks::core::PacketMap; +using ::mediapipe::tasks::core::PacketsCallback; using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; } // namespace @@ -30,15 +34,17 @@ using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; @implementation MPPTaskRunner - (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig + packetsCallback:(PacketsCallback)packetsCallback error:(NSError **)error { self = [super init]; if (self) { - auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig)); + auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig), + absl::make_unique(), + std::move(packetsCallback)); if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) { return nil; } - _cppTaskRunner = std::move(taskRunnerResult.value()); } return self; diff --git a/mediapipe/tasks/ios/test/text/text_classifier/BUILD b/mediapipe/tasks/ios/test/BUILD similarity index 56% rename from mediapipe/tasks/ios/test/text/text_classifier/BUILD rename to mediapipe/tasks/ios/test/BUILD index 2202ff1a6..9df178a19 100644 --- a/mediapipe/tasks/ios/test/text/text_classifier/BUILD +++ b/mediapipe/tasks/ios/test/BUILD @@ -1,18 +1,15 @@ load( - "//mediapipe/tasks:ios/ios.bzl", - 
"MPP_TASK_MINIMUM_OS_VERSION", - "MPP_TASK_DEFAULT_TAGS", - "MPP_TASK_DISABLED_SANITIZER_TAGS", -) -load( - "@build_bazel_rules_apple//apple:ios.bzl", + "@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test", ) load( - "@org_tensorflow//tensorflow/lite:special_rules.bzl", + "@org_tensorflow//tensorflow/lite:special_rules.bzl", "tflite_ios_lab_runner" ) +load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library") + + package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) @@ -25,7 +22,7 @@ objc_library( "//mediapipe/tasks/testdata/text:bert_text_classifier_models", "//mediapipe/tasks/testdata/text:text_classifier_models", ], - tags = MPP_TASK_DEFAULT_TAGS, + tags = [], copts = [ "-ObjC++", "-std=c++17", @@ -38,10 +35,27 @@ objc_library( ios_unit_test( name = "MPPTextClassifierObjcTest", - minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION, + minimum_os_version = "11.0", runner = tflite_ios_lab_runner("IOS_LATEST"), - tags = MPP_TASK_DEFAULT_TAGS + MPP_TASK_DISABLED_SANITIZER_TAGS, + tags =[], deps = [ ":MPPTextClassifierObjcTestLibrary", ], ) + +swift_library( + name = "MPPTextClassifierSwiftTestLibrary", + testonly = 1, + srcs = ["TextClassifierTests.swift"], + tags = [], +) + +ios_unit_test( + name = "MPPTextClassifierSwiftTest", + minimum_os_version = "11.0", + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags = [], + deps = [ + ":MPPTextClassifierSwiftTestLibrary", + ], +) diff --git a/mediapipe/tasks/ios/test/MPPTextClassifierTests.m b/mediapipe/tasks/ios/test/MPPTextClassifierTests.m new file mode 100644 index 000000000..d213fd97c --- /dev/null +++ b/mediapipe/tasks/ios/test/MPPTextClassifierTests.m @@ -0,0 +1,110 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +#import + +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" + +static NSString *const kBertTextClassifierModelName = @"bert_text_classifier"; +static NSString *const kNegativeText = @"unflinchingly bleak and desperate"; +static NSString *const kPositiveText = @"it's a charming and often affecting journey"; + +#define AssertCategoriesAre(categories, expectedCategories) \ + XCTAssertEqual(categories.count, expectedCategories.count); \ + for (int i = 0; i < categories.count; i++) { \ + XCTAssertEqual(categories[i].index, expectedCategories[i].index); \ + XCTAssertEqualWithAccuracy(categories[i].score, expectedCategories[i].score, 1e-6); \ + XCTAssertEqualObjects(categories[i].categoryName, expectedCategories[i].categoryName); \ + XCTAssertEqualObjects(categories[i].displayName, expectedCategories[i].displayName); \ + } + +#define AssertHasOneHead(textClassifierResult) \ + XCTAssertNotNil(textClassifierResult); \ + XCTAssertNotNil(textClassifierResult.classificationResult); \ + XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1); \ + XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0); + +@interface MPPTextClassifierTests : XCTestCase +@end + +@implementation MPPTextClassifierTests + +- (void)setUp { +} + +- (void)tearDown { + // Put teardown code here. This method is called after the invocation of each test method in the class. 
+} + +- (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension { + NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName + ofType:extension]; + XCTAssertNotNil(filePath); + + return filePath; +} + +- (MPPTextClassifierOptions *)textClassifierOptionsWithModelName:(NSString *)modelName { + NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"]; + MPPTextClassifierOptions *textClassifierOptions = + [[MPPTextClassifierOptions alloc] init]; + textClassifierOptions.baseOptions.modelAssetPath = modelPath; + + return textClassifierOptions; +} + +- (MPPTextClassifier *)createTextClassifierFromOptionsWithModelName:(NSString *)modelName { + MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:modelName]; + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil]; + XCTAssertNotNil(textClassifier); + + return textClassifier; +} + +- (void)testClassifyWithBertSucceeds { + MPPTextClassifier *textClassifier = [self createTextClassifierFromOptionsWithModelName:kBertTextClassifierModelName]; + + MPPTextClassifierResult *negativeResult = [textClassifier classifyWithText:kNegativeText error:nil]; + AssertHasOneHead(negativeResult); + + NSArray *expectedNegativeCategories = @[[[MPPCategory alloc] initWithIndex:0 + score:0.956187f + categoryName:@"negative" + displayName:nil], + [[MPPCategory alloc] initWithIndex:1 + score:0.043812f + categoryName:@"positive" + displayName:nil]]; + + AssertCategoriesAre(negativeResult.classificationResult.classifications[0].categories, + expectedNegativeCategories + ); + + // MPPTextClassifierResult *positiveResult = [textClassifier classifyWithText:kPositiveText error:nil]; + // AssertHasOneHead(positiveResult); + // NSArray *expectedPositiveCategories = @[[[MPPCategory alloc] initWithIndex:0 + // score:0.99997187f + // label:@"positive" + // displayName:nil], + // [[MPPCategory alloc] initWithIndex:1 + // 
score:2.8132641E-5f + // label:@"negative" + // displayName:nil]]; + // AssertCategoriesAre(negativeResult.classificationResult.classifications[0].categories, + // expectedPositiveCategories + // ); + +} +@end diff --git a/mediapipe/tasks/ios/test/TextClassifierTests.swift b/mediapipe/tasks/ios/test/TextClassifierTests.swift new file mode 100644 index 000000000..2dca9c4b5 --- /dev/null +++ b/mediapipe/tasks/ios/test/TextClassifierTests.swift @@ -0,0 +1,272 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +// import GMLImageUtils +import XCTest + +// @testable import TFLImageSegmenter + +class TextClassifierTests: XCTestCase { + + func testExample() throws { + XCTAssertEqual(1, 1) + } + + // static let bundle = Bundle(for: TextClassifierTests.self) + // static let modelPath = bundle.path( + // forResource: "deeplabv3", + // ofType: "tflite") + + // // The maximum fraction of pixels in the candidate mask that can have a + // // different class than the golden mask for the test to pass. + // let kGoldenMaskTolerance: Float = 1e-2 + + // // Magnification factor used when creating the golden category masks to make + // // them more human-friendly. Each pixel in the golden masks has its value + // // multiplied by this factor, i.e. a value of 10 means class index 1, a value of + // // 20 means class index 2, etc. 
+ // let kGoldenMaskMagnificationFactor: UInt8 = 10 + + // let deepLabV3SegmentationWidth = 257 + + // let deepLabV3SegmentationHeight = 257 + + // func verifyDeeplabV3PartialSegmentationResult(_ coloredLabels: [ColoredLabel]) { + + // self.verifyColoredLabel( + // coloredLabels[0], + // expectedR: 0, + // expectedG: 0, + // expectedB: 0, + // expectedLabel: "background") + + // self.verifyColoredLabel( + // coloredLabels[1], + // expectedR: 128, + // expectedG: 0, + // expectedB: 0, + // expectedLabel: "aeroplane") + + // self.verifyColoredLabel( + // coloredLabels[2], + // expectedR: 0, + // expectedG: 128, + // expectedB: 0, + // expectedLabel: "bicycle") + + // self.verifyColoredLabel( + // coloredLabels[3], + // expectedR: 128, + // expectedG: 128, + // expectedB: 0, + // expectedLabel: "bird") + + // self.verifyColoredLabel( + // coloredLabels[4], + // expectedR: 0, + // expectedG: 0, + // expectedB: 128, + // expectedLabel: "boat") + + // self.verifyColoredLabel( + // coloredLabels[5], + // expectedR: 128, + // expectedG: 0, + // expectedB: 128, + // expectedLabel: "bottle") + + // self.verifyColoredLabel( + // coloredLabels[6], + // expectedR: 0, + // expectedG: 128, + // expectedB: 128, + // expectedLabel: "bus") + + // self.verifyColoredLabel( + // coloredLabels[7], + // expectedR: 128, + // expectedG: 128, + // expectedB: 128, + // expectedLabel: "car") + + // self.verifyColoredLabel( + // coloredLabels[8], + // expectedR: 64, + // expectedG: 0, + // expectedB: 0, + // expectedLabel: "cat") + + // self.verifyColoredLabel( + // coloredLabels[9], + // expectedR: 192, + // expectedG: 0, + // expectedB: 0, + // expectedLabel: "chair") + + // self.verifyColoredLabel( + // coloredLabels[10], + // expectedR: 64, + // expectedG: 128, + // expectedB: 0, + // expectedLabel: "cow") + + // self.verifyColoredLabel( + // coloredLabels[11], + // expectedR: 192, + // expectedG: 128, + // expectedB: 0, + // expectedLabel: "dining table") + + // self.verifyColoredLabel( + 
// coloredLabels[12], + // expectedR: 64, + // expectedG: 0, + // expectedB: 128, + // expectedLabel: "dog") + + // self.verifyColoredLabel( + // coloredLabels[13], + // expectedR: 192, + // expectedG: 0, + // expectedB: 128, + // expectedLabel: "horse") + + // self.verifyColoredLabel( + // coloredLabels[14], + // expectedR: 64, + // expectedG: 128, + // expectedB: 128, + // expectedLabel: "motorbike") + + // self.verifyColoredLabel( + // coloredLabels[15], + // expectedR: 192, + // expectedG: 128, + // expectedB: 128, + // expectedLabel: "person") + + // self.verifyColoredLabel( + // coloredLabels[16], + // expectedR: 0, + // expectedG: 64, + // expectedB: 0, + // expectedLabel: "potted plant") + + // self.verifyColoredLabel( + // coloredLabels[17], + // expectedR: 128, + // expectedG: 64, + // expectedB: 0, + // expectedLabel: "sheep") + + // self.verifyColoredLabel( + // coloredLabels[18], + // expectedR: 0, + // expectedG: 192, + // expectedB: 0, + // expectedLabel: "sofa") + + // self.verifyColoredLabel( + // coloredLabels[19], + // expectedR: 128, + // expectedG: 192, + // expectedB: 0, + // expectedLabel: "train") + + // self.verifyColoredLabel( + // coloredLabels[20], + // expectedR: 0, + // expectedG: 64, + // expectedB: 128, + // expectedLabel: "tv") + // } + + // func verifyColoredLabel( + // _ coloredLabel: ColoredLabel, + // expectedR: UInt, + // expectedG: UInt, + // expectedB: UInt, + // expectedLabel: String + // ) { + // XCTAssertEqual( + // coloredLabel.r, + // expectedR) + // XCTAssertEqual( + // coloredLabel.g, + // expectedG) + // XCTAssertEqual( + // coloredLabel.b, + // expectedB) + // XCTAssertEqual( + // coloredLabel.label, + // expectedLabel) + // } + + // func testSuccessfullInferenceOnMLImageWithUIImage() throws { + + // let modelPath = try XCTUnwrap(ImageSegmenterTests.modelPath) + + // let imageSegmenterOptions = ImageSegmenterOptions(modelPath: modelPath) + + // let imageSegmenter = + // try ImageSegmenter.segmenter(options: 
imageSegmenterOptions) + + // let gmlImage = try XCTUnwrap( + // MLImage.imageFromBundle( + // class: type(of: self), + // filename: "segmentation_input_rotation0", + // type: "jpg")) + // let segmentationResult: SegmentationResult = + // try XCTUnwrap(imageSegmenter.segment(mlImage: gmlImage)) + + // XCTAssertEqual(segmentationResult.segmentations.count, 1) + + // let coloredLabels = try XCTUnwrap(segmentationResult.segmentations[0].coloredLabels) + // verifyDeeplabV3PartialSegmentationResult(coloredLabels) + + // let categoryMask = try XCTUnwrap(segmentationResult.segmentations[0].categoryMask) + // XCTAssertEqual(deepLabV3SegmentationWidth, categoryMask.width) + // XCTAssertEqual(deepLabV3SegmentationHeight, categoryMask.height) + + // let goldenMaskImage = try XCTUnwrap( + // MLImage.imageFromBundle( + // class: type(of: self), + // filename: "segmentation_golden_rotation0", + // type: "png")) + + // let pixelBuffer = goldenMaskImage.grayScalePixelBuffer().takeRetainedValue() + + // CVPixelBufferLockBaseAddress(pixelBuffer, CVPixelBufferLockFlags.readOnly) + + // let pixelBufferBaseAddress = (try XCTUnwrap(CVPixelBufferGetBaseAddress(pixelBuffer))) + // .assumingMemoryBound(to: UInt8.self) + + // let numPixels = deepLabV3SegmentationWidth * deepLabV3SegmentationHeight + + // let mask = try XCTUnwrap(categoryMask.mask) + + // var inconsistentPixels: Float = 0.0 + + // for i in 0.. 
- -#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" - -NS_ASSUME_NONNULL_BEGIN - -static NSString *const kBertTextClassifierModelName = @"bert_text_classifier"; -static NSString *const kNegativeText = @"unflinchingly bleak and desperate"; - -#define VerifyCategory(category, expectedIndex, expectedScore, expectedLabel, expectedDisplayName) \ - XCTAssertEqual(category.index, expectedIndex); \ - XCTAssertEqualWithAccuracy(category.score, expectedScore, 1e-6); \ - XCTAssertEqualObjects(category.label, expectedLabel); \ - XCTAssertEqualObjects(category.displayName, expectedDisplayName); - -#define VerifyClassifications(classifications, expectedHeadIndex, expectedCategoryCount) \ - XCTAssertEqual(classifications.categories.count, expectedCategoryCount); - -#define VerifyClassificationResult(classificationResult, expectedClassificationsCount) \ - XCTAssertNotNil(classificationResult); \ - XCTAssertEqual(classificationResult.classifications.count, expectedClassificationsCount) - -#define AssertClassificationResultHasOneHead(classificationResult) \ - XCTAssertNotNil(classificationResult); \ - XCTAssertEqual(classificationResult.classifications.count, 1); - XCTAssertEqual(classificationResult.classifications[0].headIndex, 1); - -#define AssertTextClassifierResultIsNotNil(textClassifierResult) \ - XCTAssertNotNil(textClassifierResult); - -@interface MPPTextClassifierTests : XCTestCase -@end - -@implementation MPPTextClassifierTests - -- (void)setUp { - [super setUp]; - -} - -- (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension { - NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName - ofType:extension]; - XCTAssertNotNil(filePath); - - return filePath; -} - -- (MPPTextClassifierOptions *)textClassifierOptionsWithModelName:(NSString *)modelName { - NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"]; - MPPTextClassifierOptions *textClassifierOptions = - 
[[MPPTextClassifierOptions alloc] init]; - textClassifierOptions.baseOptions.modelAssetPath = modelPath; - - return textClassifierOptions; -} - -kBertTextClassifierModelName - -- (MPPTextClassifier *)createTextClassifierFromOptionsWithModelName:(NSString *)modelName { - MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:modelName]; - MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil]; - XCTAssertNotNil(textClassifier); - - return textClassifier -} - -- (void)classifyWithBertSucceeds { - MPPTextClassifier *textClassifier = [self createTextClassifierWithModelName:kBertTextClassifierModelName]; - MPPTextClassifierResult *textClassifierResult = [textClassifier classifyWithText:kNegativeText]; -} - -@end - -NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/core/BUILD b/mediapipe/tasks/ios/text/core/BUILD index abb8edc71..6d558b22b 100644 --- a/mediapipe/tasks/ios/text/core/BUILD +++ b/mediapipe/tasks/ios/text/core/BUILD @@ -17,17 +17,15 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) objc_library( - name = "MPPBaseTextTaskApi", - srcs = ["sources/MPPBaseTextTaskApi.mm"], - hdrs = ["sources/MPPBaseTextTaskApi.h"], + name = "MPPTextTaskRunner", + srcs = ["sources/MPPTextTaskRunner.mm"], + hdrs = ["sources/MPPTextTaskRunner.h"], copts = [ "-ObjC++", "-std=c++17", ], deps = [ - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + "//mediapipe/tasks/ios/core:MPPTaskRunner", ], ) diff --git a/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h b/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h deleted file mode 100644 index 405d25a81..000000000 --- a/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
- - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ -#import - -#include "mediapipe/framework/calculator.pb.h" -#include "mediapipe/tasks/cc/core/task_runner.h" - -NS_ASSUME_NONNULL_BEGIN - -/** - * The base class of the user-facing iOS mediapipe text task api classes. - */ -NS_SWIFT_NAME(BaseTextTaskApi) -@interface MPPBaseTextTaskApi : NSObject { - @protected - std::unique_ptr cppTaskRunner; -} - -/** - * Initializes a new `MPPBaseTextTaskApi` with the mediapipe text task graph config proto. - * - * @param graphConfig A mediapipe text task graph config proto. - * - * @return An instance of `MPPBaseTextTaskApi` initialized to the given graph config proto. - */ -- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig - error:(NSError **)error; -- (void)close; - -- (instancetype)init NS_UNAVAILABLE; - -+ (instancetype)new NS_UNAVAILABLE; - -@end - -NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.mm b/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.mm deleted file mode 100644 index 5c05797da..000000000 --- a/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.mm +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ -#import "mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h" -#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" - -namespace { -using ::mediapipe::CalculatorGraphConfig; -using ::mediapipe::Packet; -using ::mediapipe::tasks::core::PacketMap; -using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; -} // namespace - -@interface MPPBaseTextTaskApi () { - /** TextSearcher backed by C++ API */ - std::unique_ptr _cppTaskRunner; -} -@end - -@implementation MPPBaseTextTaskApi - -- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig - error:(NSError **)error { - self = [super init]; - if (self) { - auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig)); - - if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) { - return nil; - } - - _cppTaskRunner = std::move(taskRunnerResult.value()); - } - return self; -} - -- (void)close { - _cppTaskRunner->Close(); -} - -@end diff --git a/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h b/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h new file mode 100644 index 000000000..dd5d96ce6 --- /dev/null +++ b/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h @@ -0,0 +1,37 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * This class is used to create and call appropriate methods on the C++ Task Runner to initialize, execute and terminate any Mediapipe text task. + */ +@interface MPPTextTaskRunner : MPPTaskRunner + +/** + * Initializes a new `MPPTextTaskRunner` with the mediapipe task graph config proto. + * + * @param graphConfig A mediapipe task graph config proto. + * + * @return An instance of `MPPTextTaskRunner` initialized to the given graph config proto. + */ +- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig + error:(NSError **)error NS_DESIGNATED_INITIALIZER; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.mm b/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.mm new file mode 100644 index 000000000..956448c17 --- /dev/null +++ b/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.mm @@ -0,0 +1,29 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h" + +namespace { +using ::mediapipe::CalculatorGraphConfig; +} // namespace + +@implementation MPPTextTaskRunner + +- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig + error:(NSError **)error { + self = [super initWithCalculatorGraphConfig:graphConfig packetsCallback:nullptr error:error]; + return self; +} + +@end diff --git a/mediapipe/tasks/ios/text/core/utils/BUILD b/mediapipe/tasks/ios/text/core/utils/BUILD deleted file mode 100644 index abb8edc71..000000000 --- a/mediapipe/tasks/ios/text/core/utils/BUILD +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -package(default_visibility = ["//mediapipe/tasks:internal"]) - -licenses(["notice"]) - -objc_library( - name = "MPPBaseTextTaskApi", - srcs = ["sources/MPPBaseTextTaskApi.mm"], - hdrs = ["sources/MPPBaseTextTaskApi.h"], - copts = [ - "-ObjC++", - "-std=c++17", - ], - deps = [ - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", - ], -) - diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD index 61eecb9cd..ce118c718 100644 --- a/mediapipe/tasks/ios/text/text_classifier/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,9 +25,11 @@ objc_library( "-std=c++17", ], deps = [ + "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", "//mediapipe/tasks/ios/core:MPPTaskOptions", "//mediapipe/tasks/ios/core:MPPTaskInfo", - "//mediapipe/tasks/ios/core:MPPTaskRunner", + "//mediapipe/tasks/ios/text/core:MPPTextTaskRunner", "//mediapipe/tasks/ios/core:MPPTextPacketCreator", "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers", "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers", @@ -35,6 +37,9 @@ objc_library( "//mediapipe/tasks/ios/common/utils:NSStringHelpers", ":MPPTextClassifierOptions", ], + sdk_frameworks = [ + "MetalKit", + ], ) objc_library( diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h index 19e10e35f..ee6f25100 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h +++ 
b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h @@ -1,34 +1,61 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ #import -#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h" NS_ASSUME_NONNULL_BEGIN /** - * A Mediapipe iOS Text Classifier. 
+ * This API expects a TFLite model with (optional) [TFLite Model + * Metadata](https://www.tensorflow.org/lite/convert/metadata) that contains the mandatory + * (described below) input tensors, output tensor, and the optional (but recommended) label items as + * AssociatedFiles with type TENSOR_AXIS_LABELS per output classification tensor. + * + * Metadata is required for models with int32 input tensors because it contains the input process + * unit for the model's Tokenizer. No metadata is required for models with string input tensors. + * + * Input tensors + * - Three input tensors `kTfLiteInt32` of shape `[batch_size x bert_max_seq_len]` + * representing the input ids, mask ids, and segment ids. This input signature requires a + * Bert Tokenizer process unit in the model metadata. + * - Or one input tensor `kTfLiteInt32` of shape `[batch_size x max_seq_len]` representing + * the input ids. This input signature requires a Regex Tokenizer process unit in the + * model metadata. + * - Or one input tensor (`kTfLiteString`) that is shapeless or has shape `[1]` containing + * the input string. + * + * At least one output tensor `(kTfLiteFloat32/kBool)` with: + * - `N` classes and shape `[1 x N]` + * - optional (but recommended) label map(s) as AssociatedFile-s with type TENSOR_AXIS_LABELS, + * containing one label per line. The first such AssociatedFile (if any) is used to fill the + * `class_name` field of the results. The `display_name` field is filled from the AssociatedFile + * (if any) whose locale matches the `display_names_locale` field of the + * `MPPTextClassifierOptions` used at creation time ("en" by default, i.e. English). If none of + * these are available, only the `index` field of the results will be filled. + * + * @brief Performs classification on text. 
*/ NS_SWIFT_NAME(TextClassifier) @interface MPPTextClassifier : NSObject /** * Creates a new instance of `MPPTextClassifier` from an absolute path to a TensorFlow Lite model - * file stored locally on the device. + * file stored locally on the device and the default `MPPTextClassifierOptions`. * * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. * @@ -41,9 +68,11 @@ NS_SWIFT_NAME(TextClassifier) - (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error; /** - * Creates a new instance of `MPPTextClassifier` from the given text classifier options. + * Creates a new instance of `MPPTextClassifier` from the given `MPPTextClassifierOptions`. + * + * @param options The options of type `MPPTextClassifierOptions` to use for configuring the + * `MPPTextClassifier`. * - * @param options The options to use for configuring the `MPPTextClassifier`. * @param error An optional error parameter populated when there is an error in initializing * the text classifier. * @@ -52,6 +81,16 @@ NS_SWIFT_NAME(TextClassifier) */ - (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error; +/** + * Performs classification on the input text. + * + * @param text The `NSString` on which classification is to be performed. + * + * @param error An optional error parameter populated when there is an error in performing + * classification on the input text. + * + * @return A `MPPTextClassifierResult` object that contains a list of text classifications. 
+ */ - (nullable MPPTextClassifierResult *)classifyWithText:(NSString *)text error:(NSError **)error; - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm index b9e76fc69..487cdad42 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm @@ -1,25 +1,27 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- ==============================================================================*/ #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" -#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h" -#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h" -#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" +#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h" +#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h" #import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h" +#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h" + +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "absl/status/statusor.h" @@ -37,14 +39,13 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T @interface MPPTextClassifier () { /** TextSearcher backed by C++ API */ - MPPTaskRunner *_taskRunner; + MPPTextTaskRunner *_taskRunner; } @end @implementation MPPTextClassifier - (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error { - MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] initWithTaskGraphName:kTaskGraphName inputStreams:@[ [NSString stringWithFormat:@"%@:%@", kTextTag, kTextInStreamName] ] @@ -58,10 +59,11 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T return nil; } - _taskRunner = [[MPPTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error]; - + _taskRunner = + [[MPPTextTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] + error:error]; self = [super init]; - + return self; } @@ -76,14 +78,23 @@ static 
NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T - (nullable MPPTextClassifierResult *)classifyWithText:(NSString *)text error:(NSError **)error { Packet packet = [MPPTextPacketCreator createWithText:text]; - absl::StatusOr output_packet_map = [_taskRunner process:{{kTextInStreamName.cppString, packet}} error:error]; - if (![MPPCommonUtils checkCppError:output_packet_map.status() toError:error]) { + std::map packet_map = {{kTextInStreamName.cppString, packet}}; + absl::StatusOr status_or_output_packet_map = [_taskRunner process:packet_map]; + + if (![MPPCommonUtils checkCppError:status_or_output_packet_map.status() toError:error]) { return nil; } + Packet classifications_packet = + status_or_output_packet_map.value()[kClassificationsStreamName.cppString]; + return [MPPTextClassifierResult - textClassifierResultWithProto:output_packet_map.value()[kClassificationsStreamName.cppString] - .Get()]; + textClassifierResultWithClassificationsPacket:status_or_output_packet_map.value() + [kClassificationsStreamName.cppString]]; + + // return [MPPTextClassifierResult + // textClassifierResultWithClassificationsPacket:output_packet_map.value()[kClassificationsStreamName.cppString] + // .Get()]; } @end diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h index 374226998..25189578b 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ #import #import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h" @@ -20,32 +20,16 @@ NS_ASSUME_NONNULL_BEGIN /** - * Options to configure MPPTextClassifierOptions. + * Options for setting up a `MPPTextClassifier`. */ NS_SWIFT_NAME(TextClassifierOptions) @interface MPPTextClassifierOptions : MPPTaskOptions /** - * Options controlling the behavior of the embedding model specified in the - * base options. + * Options for configuring the classifier behavior, such as score threshold, number of results, etc. */ @property(nonatomic, copy) MPPClassifierOptions *classifierOptions; -// /** -// * Initializes a new `MPPTextClassifierOptions` with the absolute path to the model file -// * stored locally on the device, set to the given the model path. -// * -// * @discussion The external model file must be a single standalone TFLite file. 
It could be packed -// * with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the -// * necessary metadata and associated files might result in errors. Check the [documentation] -// * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement. -// * -// * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. -// * -// * @return An instance of `MPPTextClassifierOptions` initialized to the given model path. -// */ -// - (instancetype)initWithModelPath:(NSString *)modelPath; - @end NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m index 82e9bed64..8d4ffd36f 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m @@ -1,27 +1,27 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" @implementation MPPTextClassifierOptions -// - (instancetype)initWithModelPath:(NSString *)modelPath { -// self = [super initWithModelPath:modelPath]; -// if (self) { -// _classifierOptions = [[MPPClassifierOptions alloc] init]; -// } -// return self; -// } +- (instancetype)init { + self = [super init]; + if (self) { + _classifierOptions = [[MPPClassifierOptions alloc] init]; + } + return self; +} @end \ No newline at end of file diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h index 414e6d9c6..6926757e4 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h @@ -1,4 +1,4 @@ -// Copyright 2022 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,23 +18,27 @@ NS_ASSUME_NONNULL_BEGIN -/** Encapsulates list of predicted classes (aka labels) for a given image classifier head. */ +/** Represents the classification results generated by `MPPTextClassifier`. 
*/ NS_SWIFT_NAME(TextClassifierResult) @interface MPPTextClassifierResult : MPPTaskResult +/** The `MPPClassificationResult` instance containing one set of results per classifier head. */ @property(nonatomic, readonly) MPPClassificationResult *classificationResult; /** - * Initializes a new `MPPClassificationResult` with the given array of classifications. + * Initializes a new `MPPTextClassifierResult` with the given `MPPClassificationResult` and time + * stamp (in milliseconds). * - * @param classifications An Aaray of `MPPClassifications` objects containing classifier - * predictions per classifier head. + * @param classificationResult The `MPPClassificationResult` instance containing one set of results + * per classifier head. * - * @return An instance of MPPClassificationResult initialized with the given array of - * classifications. + * @param timeStampMs The time stamp for this result. + * + * @return An instance of `MPPTextClassifierResult` initialized with the given + * `MPPClassificationResult` and time stamp (in milliseconds). */ - (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult - timeStamp:(long)timeStamp; + timestampMs:(NSInteger)timestampMs; @end diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m index b99ee3b19..4d5c1104a 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m @@ -1,4 +1,4 @@ -// Copyright 2022 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -17,8 +17,8 @@ @implementation MPPTextClassifierResult - (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult - timeStamp:(long)timeStamp { - self = [super initWithTimestamp:timeStamp]; + timestampMs:(NSInteger)timestampMs { + self = [super initWithTimestampMs:timestampMs]; if (self) { _classificationResult = classificationResult; } diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/BUILD b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD index d6a371137..abc1fc23b 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,5 +36,6 @@ objc_library( deps = [ "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierResult", "//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers", + "//mediapipe/framework:packet", ], ) diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h index 0771eafce..1e52e5c87 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h @@ -1,4 +1,4 @@ -// Copyright 2022 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" NS_ASSUME_NONNULL_BEGIN diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm index aa11384d2..728000b44 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm @@ -1,4 +1,4 @@ -// Copyright 2022 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h index d3fb04d69..f1b728b0a 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h @@ -1,4 +1,4 @@ -// Copyright 2022 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,16 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h" +#include "mediapipe/framework/packet.h" + NS_ASSUME_NONNULL_BEGIN @interface MPPTextClassifierResult (Helpers) -+ (MPPTextClassifierResult *)textClassifierResultWithProto: - (const mediapipe::tasks::components::containers::proto::ClassificationResult &) - classificationResultProto; ++ (MPPTextClassifierResult *)textClassifierResultWithClassificationsPacket: + (const mediapipe::Packet &)packet; @end diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm index 2fc2d751d..f5d6aa1d3 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm @@ -1,4 +1,4 @@ -// Copyright 2022 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,28 +12,31 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h" #import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h" +#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h" + +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" + +static const int kMicroSecondsPerMilliSecond = 1000; namespace { using ClassificationResultProto = ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::mediapipe::Packet; } // namespace +// Packet timestamps are in microseconds; MPPTextClassifierResult reports milliseconds. + @implementation MPPTextClassifierResult (Helpers) -+ (MPPTextClassifierResult *)textClassifierResultWithProto: - (const ClassificationResultProto &)classificationResultProto { - long timeStamp; ++ (MPPTextClassifierResult *)textClassifierResultWithClassificationsPacket:(const Packet &)packet { + MPPClassificationResult *classificationResult = [MPPClassificationResult + classificationResultWithProto:packet.Get<ClassificationResultProto>()]; - if (classificationResultProto.has_timestamp_ms()) { - timeStamp = classificationResultProto.timestamp_ms(); - } - - MPPClassificationResult *classificationResult = [MPPClassificationResult classificationResultWithProto:classificationResultProto]; - - return [[MPPTextClassifierResult alloc] initWithClassificationResult:classificationResult - timeStamp:timeStamp]; + return [[MPPTextClassifierResult alloc] + initWithClassificationResult:classificationResult + timestampMs:(NSInteger)(packet.Timestamp().Value() / + kMicroSecondsPerMilliSecond)]; } @end From 3d634a48e3f6ad9ea5982d0fd5e947b762dd1852 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 5 Jan 2023 18:23:32 +0530 Subject: [PATCH 18/18] Removed comments --- .../ios/text/text_classifier/sources/MPPTextClassifier.mm | 7 ------- 1 file changed, 7 deletions(-) diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm 
b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm index 487cdad42..31dd69413 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm @@ -85,16 +85,9 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T return nil; } - Packet classifications_packet = - status_or_output_packet_map.value()[kClassificationsStreamName.cppString]; - return [MPPTextClassifierResult textClassifierResultWithClassificationsPacket:status_or_output_packet_map.value() [kClassificationsStreamName.cppString]]; - - // return [MPPTextClassifierResult - // textClassifierResultWithClassificationsPacket:output_packet_map.value()[kClassificationsStreamName.cppString] - // .Get()]; } @end