Updated text classifier to use task manager
This commit is contained in:
parent
96247ccce4
commit
781f7adf26
|
@ -13,11 +13,11 @@
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h"
|
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h"
|
||||||
|
|
||||||
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
||||||
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
||||||
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
|
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
|
||||||
#import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h"
|
#import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h"
|
||||||
|
#import "mediapipe/tasks/ios/core/sources/MPPTaskManager.h"
|
||||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
|
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
|
||||||
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h"
|
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h"
|
||||||
|
|
||||||
|
@ -35,9 +35,16 @@ static NSString *const kTextInStreamName = @"text_in";
|
||||||
static NSString *const kTextTag = @"TEXT";
|
static NSString *const kTextTag = @"TEXT";
|
||||||
static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.TextClassifierGraph";
|
static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.TextClassifierGraph";
|
||||||
|
|
||||||
|
@interface MPPTextClassifier () {
|
||||||
|
/** TextSearcher backed by C++ API */
|
||||||
|
MPPTaskManager *_taskManager;
|
||||||
|
}
|
||||||
|
@end
|
||||||
|
|
||||||
@implementation MPPTextClassifier
|
@implementation MPPTextClassifier
|
||||||
|
|
||||||
- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error {
|
- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error {
|
||||||
|
|
||||||
MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc]
|
MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc]
|
||||||
initWithTaskGraphName:kTaskGraphName
|
initWithTaskGraphName:kTaskGraphName
|
||||||
inputStreams:@[ [NSString stringWithFormat:@"@:@", kTextTag, kTextInStreamName] ]
|
inputStreams:@[ [NSString stringWithFormat:@"@:@", kTextTag, kTextInStreamName] ]
|
||||||
|
@ -51,7 +58,11 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T
|
||||||
return nil;
|
return nil;
|
||||||
}
|
}
|
||||||
|
|
||||||
return [super initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error];
|
_taskManager = [[MPPTaskManager alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error];
|
||||||
|
|
||||||
|
self = [super init];
|
||||||
|
|
||||||
|
return self;
|
||||||
}
|
}
|
||||||
|
|
||||||
- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error {
|
- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error {
|
||||||
|
@ -61,11 +72,10 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T
|
||||||
return [self initWithOptions:options error:error];
|
return [self initWithOptions:options error:error];
|
||||||
}
|
}
|
||||||
|
|
||||||
- (MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error {
|
- (nullable MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error {
|
||||||
Packet packet = [MPPPacketCreator createWithText:text];
|
Packet packet = [MPPPacketCreator createWithText:text];
|
||||||
absl::StatusOr<PacketMap> output_packet_map =
|
|
||||||
cppTaskRunner->Process({{kTextInStreamName.cppString, packet}});
|
|
||||||
|
|
||||||
|
absl::StatusOr<PacketMap> output_packet_map = [_taskManager process:{{kTextInStreamName.cppString, packet}} error:error];
|
||||||
if (![MPPCommonUtils checkCppError:output_packet_map.status() toError:error]) {
|
if (![MPPCommonUtils checkCppError:output_packet_map.status() toError:error]) {
|
||||||
return nil;
|
return nil;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user