diff --git a/mediapipe/tasks/ios/text/language_detector/BUILD b/mediapipe/tasks/ios/text/language_detector/BUILD index 3b59fbd59..4df278037 100644 --- a/mediapipe/tasks/ios/text/language_detector/BUILD +++ b/mediapipe/tasks/ios/text/language_detector/BUILD @@ -31,3 +31,28 @@ objc_library( "//mediapipe/tasks/ios/core:MPPTaskResult", ], ) + +objc_library( + name = "MPPLanguageDetector", + srcs = ["sources/MPPLanguageDetector.mm"], + hdrs = ["sources/MPPLanguageDetector.h"], + copts = [ + "-ObjC++", + "-std=c++17", + "-x objective-c++", + ], + module_name = "MPPLanguageDetector", + deps = [ + ":MPPLanguageDetectorOptions", + ":MPPLanguageDetectorResult", + "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph", + "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/core:MPPTaskInfo", + "//mediapipe/tasks/ios/core:MPPTaskOptions", + "//mediapipe/tasks/ios/core:MPPTextPacketCreator", + "//mediapipe/tasks/ios/text/core:MPPTextTaskRunner", + "//mediapipe/tasks/ios/text/language_detector/utils:MPPLanguageDetectorOptionsHelpers", + "//mediapipe/tasks/ios/text/language_detector/utils:MPPLanguageDetectorResultHelpers", + ], +) diff --git a/mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetector.h b/mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetector.h new file mode 100644 index 000000000..7213a8e5f --- /dev/null +++ b/mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetector.h @@ -0,0 +1,88 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" +#import "mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetectorOptions.h" +#import "mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetectorResult.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * @brief Predicts the language of an input text. + * + * This API expects a TFLite model with [TFLite Model + * Metadata](https://www.tensorflow.org/lite/convert/metadata")that contains the mandatory + * (described below) input tensor, output tensor, and the language codes in an AssociatedFile. + * + * Metadata is required for models with int32 input tensors because it contains the input + * process unit for the model's Tokenizer. No metadata is required for models with string + * input tensors. + * + * Input tensor + * - One input tensor (`kTfLiteString`) of shape `[1]` containing the input string. + * + * Output tensor + * - One output tensor (`kTfLiteFloat32`) of shape `[1 x N]` where `N` is the number of languages. + */ +NS_SWIFT_NAME(LanguageDetector) +@interface MPPLanguageDetector : NSObject + +/** + * Creates a new instance of `LanguageDetector` from an absolute path to a TensorFlow Lite + * model file stored locally on the device and the default `LanguageDetectorOptions`. + * + * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. + * @param error An optional error parameter populated when there is an error in initializing the + * language detector. + * + * @return A new instance of `LanguageDetector` with the given model path. `nil` if there is an + * error in initializing the language detector. + */ +- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error; + +/** + * Creates a new instance of `LanguageDetector` from the given `LanguageDetectorOptions`. + * + * @param options The options of type `LanguageDetectorOptions` to use for configuring the + * `LanguageDetector`. + * @param error An optional error parameter populated when there is an error in initializing the + * language detector. + * + * @return A new instance of `LanguageDetector` with the given options. `nil` if there is an + * error in initializing the language detector. + */ +- (nullable instancetype)initWithOptions:(MPPLanguageDetectorOptions *)options + error:(NSError **)error NS_DESIGNATED_INITIALIZER; + +/** + * Predicts the language of the input text. + * + * @param text The `NSString` for which language is to be predicted. + * @param error An optional error parameter populated when there is an error in performing + * language prediction on the input text. + * + * @return A `LanguageDetectorResult` object that contains a list of language predictions. + */ +- (nullable MPPLanguageDetectorResult *)detectText:(NSString *)text + error:(NSError **)error NS_SWIFT_NAME(detect(text:)); + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetector.mm b/mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetector.mm new file mode 100644 index 000000000..4c9628c82 --- /dev/null +++ b/mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetector.mm @@ -0,0 +1,96 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetector.h" + +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" +#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h" +#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h" +#import "mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorOptions+Helpers.h" +#import "mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorResult+Helpers.h" + +namespace { +using ::mediapipe::Packet; +using ::mediapipe::tasks::core::PacketMap; +} // namespace + +static NSString *const kClassificationsStreamName = @"classifications_out"; +static NSString *const kClassificationsTag = @"CLASSIFICATIONS"; +static NSString *const kTextInStreamName = @"text_in"; +static NSString *const kTextTag = @"TEXT"; +static NSString *const kTaskGraphName = + @"mediapipe.tasks.text.language_detector.LanguageDetectorGraph"; + +@interface MPPLanguageDetector () { + /** iOS Text Task Runner */ + MPPTextTaskRunner *_textTaskRunner; +} +@end + +@implementation MPPLanguageDetector + +- (instancetype)initWithOptions:(MPPLanguageDetectorOptions *)options error:(NSError **)error { + self = [super init]; + if (self) { + MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] + initWithTaskGraphName:kTaskGraphName + inputStreams:@[ [NSString stringWithFormat:@"%@:%@", kTextTag, kTextInStreamName] ] + outputStreams:@[ [NSString stringWithFormat:@"%@:%@", kClassificationsTag, + kClassificationsStreamName] ] + taskOptions:options + enableFlowLimiting:NO + error:error]; + + if (!taskInfo) { + return nil; + } + + _textTaskRunner = + [[MPPTextTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] + error:error]; + + if (!_textTaskRunner) { + return nil; + } + } + return self; +} + +- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error { + MPPLanguageDetectorOptions *options = [[MPPLanguageDetectorOptions alloc] init]; + + options.baseOptions.modelAssetPath = modelPath; + + return [self initWithOptions:options error:error]; +} + +- (nullable MPPLanguageDetectorResult *)detectText:(NSString *)text error:(NSError **)error { + Packet packet = [MPPTextPacketCreator createWithText:text]; + + std::map packetMap = {{kTextInStreamName.cppString, packet}}; + std::optional outputPacketMap = [_textTaskRunner processPacketMap:packetMap + error:error]; + + if (!outputPacketMap.has_value()) { + return nil; + } + + return + [MPPLanguageDetectorResult languageDetectorResultWithClassificationsPacket: + outputPacketMap.value()[kClassificationsStreamName.cppString]]; +} + +@end