Updated text classifier to use task manager
This commit is contained in:
parent
96247ccce4
commit
781f7adf26
|
@ -13,11 +13,11 @@
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h"
|
||||
|
||||
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
||||
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
||||
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskManager.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
|
||||
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h"
|
||||
|
||||
|
@ -35,9 +35,16 @@ static NSString *const kTextInStreamName = @"text_in";
|
|||
static NSString *const kTextTag = @"TEXT";
|
||||
static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.TextClassifierGraph";
|
||||
|
||||
@interface MPPTextClassifier () {
|
||||
/** TextSearcher backed by C++ API */
|
||||
MPPTaskManager *_taskManager;
|
||||
}
|
||||
@end
|
||||
|
||||
@implementation MPPTextClassifier
|
||||
|
||||
- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error {
|
||||
|
||||
MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc]
|
||||
initWithTaskGraphName:kTaskGraphName
|
||||
inputStreams:@[ [NSString stringWithFormat:@"@:@", kTextTag, kTextInStreamName] ]
|
||||
|
@ -51,7 +58,11 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T
|
|||
return nil;
|
||||
}
|
||||
|
||||
return [super initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error];
|
||||
_taskManager = [[MPPTaskManager alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error];
|
||||
|
||||
self = [super init];
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error {
|
||||
|
@ -61,11 +72,10 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T
|
|||
return [self initWithOptions:options error:error];
|
||||
}
|
||||
|
||||
- (MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error {
|
||||
- (nullable MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error {
|
||||
Packet packet = [MPPPacketCreator createWithText:text];
|
||||
absl::StatusOr<PacketMap> output_packet_map =
|
||||
cppTaskRunner->Process({{kTextInStreamName.cppString, packet}});
|
||||
|
||||
absl::StatusOr<PacketMap> output_packet_map = [_taskManager process:{{kTextInStreamName.cppString, packet}} error:error];
|
||||
if (![MPPCommonUtils checkCppError:output_packet_map.status() toError:error]) {
|
||||
return nil;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user