Updated text classifier to use task manager

This commit is contained in:
Prianka Liz Kariat 2022-12-14 18:59:56 +05:30
parent 96247ccce4
commit 781f7adf26

View File

@ -13,11 +13,11 @@
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h"
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h" #import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
#import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h" #import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskManager.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h" #import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h"
@ -35,9 +35,16 @@ static NSString *const kTextInStreamName = @"text_in";
static NSString *const kTextTag = @"TEXT"; static NSString *const kTextTag = @"TEXT";
static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.TextClassifierGraph"; static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.TextClassifierGraph";
@interface MPPTextClassifier () {
/** TextSearcher backed by C++ API */
MPPTaskManager *_taskManager;
}
@end
@implementation MPPTextClassifier @implementation MPPTextClassifier
- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error { - (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error {
MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc]
initWithTaskGraphName:kTaskGraphName initWithTaskGraphName:kTaskGraphName
inputStreams:@[ [NSString stringWithFormat:@"@:@", kTextTag, kTextInStreamName] ] inputStreams:@[ [NSString stringWithFormat:@"@:@", kTextTag, kTextInStreamName] ]
@ -51,7 +58,11 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T
return nil; return nil;
} }
return [super initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error]; _taskManager = [[MPPTaskManager alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error];
self = [super init];
return self;
} }
- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error { - (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error {
@ -61,11 +72,10 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T
return [self initWithOptions:options error:error]; return [self initWithOptions:options error:error];
} }
- (MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error { - (nullable MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error {
Packet packet = [MPPPacketCreator createWithText:text]; Packet packet = [MPPPacketCreator createWithText:text];
absl::StatusOr<PacketMap> output_packet_map =
cppTaskRunner->Process({{kTextInStreamName.cppString, packet}});
absl::StatusOr<PacketMap> output_packet_map = [_taskManager process:{{kTextInStreamName.cppString, packet}} error:error];
if (![MPPCommonUtils checkCppError:output_packet_map.status() toError:error]) { if (![MPPCommonUtils checkCppError:output_packet_map.status() toError:error]) {
return nil; return nil;
} }