From 781f7adf26d21df1b40beb6bb7518f886e7c9645 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 14 Dec 2022 18:59:56 +0530 Subject: [PATCH] Updated text classifier to use task manager --- .../sources/MPPTextClassifier.mm | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm index e61e6998d..b4cd66f70 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm @@ -13,11 +13,11 @@ limitations under the License. ==============================================================================*/ #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" - #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" #import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h" #import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskManager.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" #import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h" @@ -35,9 +35,16 @@ static NSString *const kTextInStreamName = @"text_in"; static NSString *const kTextTag = @"TEXT"; static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.TextClassifierGraph"; +@interface MPPTextClassifier () { + /** TextSearcher backed by C++ API */ + MPPTaskManager *_taskManager; +} +@end + @implementation MPPTextClassifier - (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error { + MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] initWithTaskGraphName:kTaskGraphName inputStreams:@[ [NSString stringWithFormat:@"@:@", kTextTag, kTextInStreamName] ] @@ -51,7 +58,11 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T return nil; } - return [super initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error]; + _taskManager = [[MPPTaskManager alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error]; + + self = [super init]; + + return self; } - (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error { @@ -61,11 +72,10 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T return [self initWithOptions:options error:error]; } -- (MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error { +- (nullable MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error { Packet packet = [MPPPacketCreator createWithText:text]; - absl::StatusOr output_packet_map = - cppTaskRunner->Process({{kTextInStreamName.cppString, packet}}); + absl::StatusOr output_packet_map = [_taskManager process:{{kTextInStreamName.cppString, packet}} error:error]; if (![MPPCommonUtils checkCppError:output_packet_map.status() toError:error]) { return nil; }