From d588f73a6d84d1ef5c36da72442b0db5a01b85dd Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 1 Feb 2023 18:51:30 +0530 Subject: [PATCH] Added MPPTextEmbedder --- mediapipe/tasks/ios/text/text_embedder/BUILD | 26 +++++ .../text_embedder/sources/MPPTextEmbedder.h | 91 ++++++++++++++++++ .../text_embedder/sources/MPPTextEmbedder.mm | 96 +++++++++++++++++++ 3 files changed, 213 insertions(+) create mode 100644 mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h create mode 100644 mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm diff --git a/mediapipe/tasks/ios/text/text_embedder/BUILD b/mediapipe/tasks/ios/text/text_embedder/BUILD index 143f0a587..21226b012 100644 --- a/mediapipe/tasks/ios/text/text_embedder/BUILD +++ b/mediapipe/tasks/ios/text/text_embedder/BUILD @@ -32,3 +32,29 @@ objc_library( "//mediapipe/tasks/ios/core:MPPTaskResult", ], ) + +objc_library( + name = "MPPTextEmbedder", + srcs = ["sources/MPPTextEmbedder.mm"], + hdrs = ["sources/MPPTextEmbedder.h"], + copts = [ + "-ObjC++", + "-std=c++17", + "-x objective-c++", + ], + module_name = "MPPTextEmbedder", + deps = [ + ":MPPTextEmbedderOptions", + ":MPPTextEmbedderResult", + "//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph", + "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/core:MPPTaskInfo", + "//mediapipe/tasks/ios/core:MPPTaskOptions", + "//mediapipe/tasks/ios/core:MPPTextPacketCreator", + "//mediapipe/tasks/ios/text/core:MPPTextTaskRunner", + "//mediapipe/tasks/ios/text/text_embedder/utils:MPPTextEmbedderOptionsHelpers", + "//mediapipe/tasks/ios/text/text_embedder/utils:MPPTextEmbedderResultHelpers", + "@com_google_absl//absl/status:statusor", + ], +) diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h new file mode 100644 index 000000000..d1deb60ed --- /dev/null +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h @@ -0,0 +1,91 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" +#import "mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderOptions.h" +#import "mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * @brief Performs embedding extraction on text. + * + * This API expects a TFLite model with (optional) [TFLite Model Metadata](https://www.tensorflow.org/lite/convert/metadata"). + * + * Metadata is required for models with int32 input tensors because it contains the input process + * unit for the model's Tokenizer. No metadata is required for models with string input tensors. + * + * Input tensors + * - Three input tensors `kTfLiteInt32` of shape `[batch_size x bert_max_seq_len]` + * representing the input ids, mask ids, and segment ids. This input signature requires + * a Bert Tokenizer process unit in the model metadata. + * - Or one input tensor `kTfLiteInt32` of shape `[batch_size x max_seq_len]` representing + * the input ids. This input signature requires a Regex Tokenizer process unit in the + * model metadata. + * - Or one input tensor (`kTfLiteString`) that is shapeless or has shape `[1]` containing + * the input string. + * + * At least one output tensor (`kTfLiteFloat32`/`kTfLiteUint8`) with shape `[1 x N]` where `N` is the number of dimensions in the produced embeddings. + */ +NS_SWIFT_NAME(TextEmbedder) +@interface MPPTextEmbedder : NSObject + +/** + * Creates a new instance of `MPPTextEmbedder` from an absolute path to a TensorFlow Lite + * model file stored locally on the device and the default `MPPTextEmbedderOptions`. + * + * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. + * @param error An optional error parameter populated when there is an error in initializing the + * text embedder. + * + * @return A new instance of `MPPTextEmbedder` with the given model path. `nil` if there is an + * error in initializing the text embedder. + */ +- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error; + +/** + * Creates a new instance of `MPPTextEmbedder` from the given `MPPTextEmbedderOptions`. + * + * @param options The options of type `MPPTextEmbedderOptions` to use for configuring the + * `MPPTextEmbedder. + * @param error An optional error parameter populated when there is an error in initializing the + * text embedder. + * + * @return A new instance of `MPPTextEmbedder` with the given options. `nil` if there is an + * error in initializing the text embedder. + */ +- (nullable instancetype)initWithOptions:(MPPTextEmbedderOptions *)options + error:(NSError **)error NS_DESIGNATED_INITIALIZER; + +/** + * Performs embedding extraction on the input text. + * + * @param text The `NSString` on which embedding extraction is to be performed. + * @param error An optional error parameter populated when there is an error in performing + * embedding extraction on the input text. + * + * @return A `MPPTextEmbedderResult` object that contains a list of embeddings. + */ +- (nullable MPPTextEmbedderResult *)embedText:(NSString *)text + error:(NSError **)error NS_SWIFT_NAME(embed(text:)); + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm new file mode 100644 index 000000000..395ce28f6 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm @@ -0,0 +1,96 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h" + +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" +#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h" +#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h" +#import "mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.h" +#import "mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.h" + +#include "absl/status/statusor.h" + +namespace { +using ::mediapipe::Packet; +using ::mediapipe::tasks::core::PacketMap; +} // namespace + +static NSString *const kEmbeddingsOutStreamName = @"embeddings_out"; +static NSString *const kEmbeddingsTag = @"EMBEDDINGS"; +static NSString *const kTextInStreamName = @"text_in"; +static NSString *const kTextTag = @"TEXT"; +static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_embedder.TextEmbedderGraph"; + +@interface MPPTextEmbedder () { + /** iOS Text Task Runner */ + MPPTextTaskRunner *_textTaskRunner; +} +@end + +@implementation MPPTextEmbedder + +- (instancetype)initWithOptions:(MPPTextEmbedderOptions *)options error:(NSError **)error { + self = [super init]; + if (self) { + MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] + initWithTaskGraphName:kTaskGraphName + inputStreams:@[ [NSString stringWithFormat:@"%@:%@", kTextTag, kTextInStreamName] ] + outputStreams:@[ [NSString stringWithFormat:@"%@:%@", kEmbeddingsTag, + kEmbeddingsOutStreamName] ] + taskOptions:options + enableFlowLimiting:NO + error:error]; + + if (!taskInfo) { + return nil; + } + + _textTaskRunner = + [[MPPTextTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] + error:error]; + + if (!_textTaskRunner) { + return nil; + } + } + return self; +} + +- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error { + MPPTextEmbedderOptions *options = [[MPPTextEmbedderOptions alloc] init]; + + options.baseOptions.modelAssetPath = modelPath; + + return [self initWithOptions:options error:error]; +} + +- (nullable MPPTextEmbedderResult *)embedText:(NSString *)text error:(NSError **)error { + Packet packet = [MPPTextPacketCreator createWithText:text]; + + std::map packetMap = {{kTextInStreamName.cppString, packet}}; + absl::StatusOr statusOrOutputPacketMap = [_textTaskRunner process:packetMap]; + + if (![MPPCommonUtils checkCppError:statusOrOutputPacketMap.status() toError:error]) { + return nil; + } + + return [MPPTextEmbedderResult + textEmbedderResultWithOutputPacket:statusOrOutputPacketMap.value() + [kEmbeddingsOutStreamName.cppString]]; +} + +@end