Added APIs for text classification

This commit is contained in:
Prianka Liz Kariat 2022-12-01 09:13:05 +05:30
parent 683c2b1f09
commit 8d9c1b8a0f
25 changed files with 450 additions and 16 deletions

View File

@ -23,6 +23,9 @@ objc_library(
deps = [ deps = [
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/ios/common:MPPCommon", "//mediapipe/tasks/ios/common:MPPCommon",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
], ],
) )

View File

@ -13,8 +13,16 @@
// limitations under the License. // limitations under the License.
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h" #import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
#include <string>
#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/cord.h" // from @com_google_absl
#include "mediapipe/tasks/cc/common.h"
/** Error domain of TensorFlow Lite Support related errors. */ /** Error domain of TensorFlow Lite Support related errors. */
NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks"; NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks";
@ -60,8 +68,8 @@ NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks";
if (status.ok()) { if (status.ok()) {
return YES; return YES;
} }
// Payload of absl::Status created by the Media Pipe task library stores an appropriate value of the // Payload of absl::Status created by the Media Pipe task library stores an appropriate value of
// enum MPPiteSupportStatus. The integer value corresponding to the MPPiteSupportStatus enum // the enum MPPiteSupportStatus. The integer value corresponding to the MPPiteSupportStatus enum
// stored in the payload is extracted here to later map to the appropriate error code to be // stored in the payload is extracted here to later map to the appropriate error code to be
// returned. In cases where the enum is not stored in (payload is NULL or the payload string // returned. In cases where the enum is not stored in (payload is NULL or the payload string
// cannot be converted to an integer), we set the error code value to be 1 // cannot be converted to an integer), we set the error code value to be 1

View File

@ -22,6 +22,8 @@ NS_ASSUME_NONNULL_BEGIN
@property(readonly) std::string cppString; @property(readonly) std::string cppString;
+ (NSString *)stringWithCppString:(std::string)text;
@end @end
NS_ASSUME_NONNULL_END NS_ASSUME_NONNULL_END

View File

@ -20,4 +20,8 @@
return std::string(self.UTF8String, [self lengthOfBytesUsingEncoding:NSUTF8StringEncoding]); return std::string(self.UTF8String, [self lengthOfBytesUsingEncoding:NSUTF8StringEncoding]);
} }
+ (NSString *)stringWithCppString:(std::string)text {
return [NSString stringWithCString:text.c_str() encoding:[NSString defaultCStringEncoding]];
}
@end @end

View File

@ -28,5 +28,6 @@ objc_library(
hdrs = ["sources/MPPClassificationResult.h"], hdrs = ["sources/MPPClassificationResult.h"],
deps = [ deps = [
":MPPCategory", ":MPPCategory",
"//mediapipe/tasks/ios/core:MPPTaskResult",
], ],
) )

View File

@ -18,7 +18,7 @@ NS_ASSUME_NONNULL_BEGIN
/** Encapsulates information about a class in the classification results. */ /** Encapsulates information about a class in the classification results. */
NS_SWIFT_NAME(ClassificationCategory) NS_SWIFT_NAME(ClassificationCategory)
@interface TFLCategory : NSObject @interface MPPCategory : NSObject
/** Index of the class in the corresponding label map, usually packed in the TFLite Model /** Index of the class in the corresponding label map, usually packed in the TFLite Model
* Metadata. */ * Metadata. */

View File

@ -12,9 +12,9 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#import "mediapipe/tasks/ios/components/containers/sources/TFLCategory.h" #import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h"
@implementation TFLCategory @implementation MPPCategory
- (instancetype)initWithIndex:(NSInteger)index - (instancetype)initWithIndex:(NSInteger)index
score:(float)score score:(float)score

View File

@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#import <Foundation/Foundation.h> #import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/task/components/containers/sources/MPPCategory.h" #import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h"
NS_ASSUME_NONNULL_BEGIN NS_ASSUME_NONNULL_BEGIN
@ -71,7 +72,7 @@ NS_SWIFT_NAME(Classifications)
/** Encapsulates results of any classification task. */ /** Encapsulates results of any classification task. */
NS_SWIFT_NAME(ClassificationResult) NS_SWIFT_NAME(ClassificationResult)
@interface MPPClassificationResult : NSObject @interface MPPClassificationResult : MPPTaskResult
/** Array of MPPClassifications objects containing classifier predictions per image classifier /** Array of MPPClassifications objects containing classifier predictions per image classifier
* head. * head.
@ -87,7 +88,8 @@ NS_SWIFT_NAME(ClassificationResult)
* @return An instance of MPPClassificationResult initialized with the given array of * @return An instance of MPPClassificationResult initialized with the given array of
* classifications. * classifications.
*/ */
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications; - (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
timeStamp:(long)timeStamp;
@end @end

View File

@ -39,8 +39,9 @@ limitations under the License.
NSArray<MPPClassifications *> *_classifications; NSArray<MPPClassifications *> *_classifications;
} }
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications { - (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
self = [super init]; timeStamp:(long)timeStamp {
self = [super initWithTimeStamp:timeStamp];
if (self) { if (self) {
_classifications = classifications; _classifications = classifications;
} }

View File

@ -0,0 +1,40 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
objc_library(
name = "MPPCategoryHelpers",
srcs = ["sources/MPPCategory+Helpers.mm"],
hdrs = ["sources/MPPCategory+Helpers.h"],
deps = [
"//mediapipe/tasks/ios/components/containers:MPPCategory",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
],
)
objc_library(
name = "MPPClassificationResultHelpers",
srcs = ["sources/MPPClassificationResult+Helpers.mm"],
hdrs = ["sources/MPPClassificationResult+Helpers.h"],
deps = [
"//mediapipe/tasks/ios/components/containers:MPPClassificationResult",
":MPPCategoryHelpers",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
],
)

View File

@ -0,0 +1,26 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/framework/formats/classification.pb.h"
#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h"
NS_ASSUME_NONNULL_BEGIN
@interface MPPCategory (Helpers)
+ (MPPCategory *)categoryWithProto:(const mediapipe::Classification &)classificationProto;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,42 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h"
namespace {
using ClassificationProto = ::mediapipe::Classification;
}
@implementation MPPCategory (Helpers)
+ (MPPCategory *)categoryWithProto:(const ClassificationProto &)clasificationProto {
NSString *label;
NSString *displayName;
if (clasificationProto.has_label()) {
label = [NSString stringWithCppString:clasificationProto.label()];
}
if (clasificationProto.has_display_name()) {
displayName = [NSString stringWithCppString:clasificationProto.display_name()];
}
return [[MPPCategory alloc] initWithIndex:clasificationProto.index()
score:clasificationProto.score()
label:label
displayName:displayName];
}
@end

View File

@ -0,0 +1,35 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h"
NS_ASSUME_NONNULL_BEGIN
@interface MPPClassifications (Helpers)
+ (MPPClassifications *)classificationsWithProto:
(const mediapipe::tasks::components::containers::proto::Classifications &)classificationsProto;
@end
@interface MPPClassificationResult (Helpers)
+ (MPPClassificationResult *)classificationResultWithProto:
(const mediapipe::tasks::components::containers::proto::ClassificationResult &)
classificationResultProto;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,66 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h"
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
namespace {
using ClassificationsProto = ::mediapipe::tasks::components::containers::proto::Classifications;
using ClassificationResultProto =
::mediapipe::tasks::components::containers::proto::ClassificationResult;
} // namespace
@implementation MPPClassifications (Helpers)
+ (MPPClassifications *)classificationsWithProto:
(const ClassificationsProto &)classificationsProto {
NSMutableArray *categories = [[NSMutableArray alloc] init];
for (const auto &classification : classificationsProto.classification_list().classification()) {
[categories addObject:[MPPCategory categoryWithProto:classification]];
}
NSString *headName;
if (classificationsProto.has_head_name()) {
headName = [NSString stringWithCppString:classificationsProto.head_name()];
}
return [[MPPClassifications alloc] initWithHeadIndex:(NSInteger)classificationsProto.head_index()
headName:headName
categories:categories];
}
@end
@implementation MPPClassificationResult (Helpers)
+ (MPPClassificationResult *)classificationResultWithProto:
(const ClassificationResultProto &)classificationResultProto {
NSMutableArray *classifications = [[NSMutableArray alloc] init];
for (const auto &classifications_proto : classificationResultProto.classifications()) {
[classifications addObject:[MPPClassifications classificationsWithProto:classifications_proto]];
}
long timeStamp;
if (classificationResultProto.has_timestamp_ms()) {
timeStamp = classificationResultProto.timestamp_ms();
}
return [[MPPClassificationResult alloc] initWithClassifications:classifications
timeStamp:timeStamp];
}
@end

View File

@ -65,3 +65,23 @@ objc_library(
], ],
) )
objc_library(
name = "MPPPacketCreator",
srcs = ["sources/MPPPacketCreator.mm"],
hdrs = ["sources/MPPPacketCreator.h"],
copts = [
"-ObjC++",
"-std=c++17",
],
deps = [
"//mediapipe/framework:packet",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
],
)
objc_library(
name = "MPPTaskResult",
srcs = ["sources/MPPTaskResult.m"],
hdrs = ["sources/MPPTaskResult.h"],
)

View File

@ -0,0 +1,29 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#ifndef __cplusplus
#error This header can only be included by an Objective-C++ file.
#endif
#include "mediapipe/framework/packet.h"
/// This class is an Objective-C wrapper around a MediaPipe graph object, and
/// helps interface it with iOS technologies such as AVFoundation.
@interface MPPPacketCreator : NSObject
+ (mediapipe::Packet)createWithText:(NSString *)text;
@end

View File

@ -0,0 +1,29 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h"
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
namespace {
using ::mediapipe::MakePacket;
using ::mediapipe::Packet;
} // namespace
@implementation MPPPacketCreator
+ (Packet)createWithText:(NSString *)text {
return MakePacket<std::string>(text.cppString);
}
@end

View File

@ -0,0 +1,31 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import <Foundation/Foundation.h>
NS_ASSUME_NONNULL_BEGIN
/**
* MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend
* this class.
*/
NS_SWIFT_NAME(TaskResult)
@interface MPPTaskResult : NSObject <NSCopying>
/**
* Base options for configuring the Mediapipe task.
*/
@property(nonatomic, assign, readonly) long timeStamp;
- (instancetype)initWithTimeStamp:(long)timeStamp;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,27 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h"
@implementation MPPTaskResult
- (instancetype)initWithTimeStamp:(long)timeStamp {
self = [self init];
if (self) {
_timeStamp = timeStamp;
}
return self;
}
@end

View File

@ -15,6 +15,7 @@
#import <Foundation/Foundation.h> #import <Foundation/Foundation.h>
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
NS_ASSUME_NONNULL_BEGIN NS_ASSUME_NONNULL_BEGIN
@ -22,7 +23,10 @@ NS_ASSUME_NONNULL_BEGIN
* The base class of the user-facing iOS mediapipe text task api classes. * The base class of the user-facing iOS mediapipe text task api classes.
*/ */
NS_SWIFT_NAME(BaseTextTaskApi) NS_SWIFT_NAME(BaseTextTaskApi)
@interface MPPBaseTextTaskApi : NSObject @interface MPPBaseTextTaskApi : NSObject {
@protected
std::unique_ptr<mediapipe::tasks::core::TaskRunner> cppTaskRunner;
}
/** /**
* Initializes a new `MPPBaseTextTaskApi` with the mediapipe text task graph config proto. * Initializes a new `MPPBaseTextTaskApi` with the mediapipe text task graph config proto.

View File

@ -13,18 +13,18 @@
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#import "mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h" #import "mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
namespace { namespace {
using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::Packet;
using ::mediapipe::tasks::core::PacketMap;
using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
} // namespace } // namespace
@interface MPPBaseTextTaskApi () { @interface MPPBaseTextTaskApi () {
/** TextSearcher backed by C++ API */ /** TextSearcher backed by C++ API */
std::unique_ptr<TaskRunnerCpp> _taskRunner; std::unique_ptr<TaskRunnerCpp> _cppTaskRunner;
} }
@end @end
@ -40,13 +40,13 @@ using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
return nil; return nil;
} }
_taskRunner = std::move(taskRunnerResult.value()); _cppTaskRunner = std::move(taskRunnerResult.value());
} }
return self; return self;
} }
- (void)close { - (void)close {
_taskRunner->Close(); _cppTaskRunner->Close();
} }
@end @end

View File

@ -0,0 +1,33 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
objc_library(
name = "MPPBaseTextTaskApi",
srcs = ["sources/MPPBaseTextTaskApi.mm"],
hdrs = ["sources/MPPBaseTextTaskApi.h"],
copts = [
"-ObjC++",
"-std=c++17",
],
deps = [
"//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
],
)

View File

@ -28,8 +28,11 @@ objc_library(
"//mediapipe/tasks/ios/core:MPPTaskOptions", "//mediapipe/tasks/ios/core:MPPTaskOptions",
"//mediapipe/tasks/ios/core:MPPTaskInfo", "//mediapipe/tasks/ios/core:MPPTaskInfo",
"//mediapipe/tasks/ios/text/core:MPPBaseTextTaskApi", "//mediapipe/tasks/ios/text/core:MPPBaseTextTaskApi",
"//mediapipe/tasks/ios/core:MPPPacketCreator",
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers", "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers",
"//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
":MPPTextClassifierOptions", ":MPPTextClassifierOptions",
], ],
) )

View File

@ -14,6 +14,7 @@
==============================================================================*/ ==============================================================================*/
#import <Foundation/Foundation.h> #import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
#import "mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h" #import "mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h"
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h"
@ -52,6 +53,8 @@ NS_SWIFT_NAME(TextClassifier)
*/ */
- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error; - (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error;
- (MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error;
- (instancetype)init NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE;
+ (instancetype)new NS_UNAVAILABLE; + (instancetype)new NS_UNAVAILABLE;

View File

@ -15,9 +15,20 @@
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h"
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
#import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h" #import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h"
#include "absl/status/statusor.h"
namespace {
using ::mediapipe::Packet;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::PacketMap;
} // namespace
static NSString *const kClassificationsStreamName = @"classifications_out"; static NSString *const kClassificationsStreamName = @"classifications_out";
static NSString *const kClassificationsTag = @"classifications"; static NSString *const kClassificationsTag = @"classifications";
static NSString *const kTextInStreamName = @"text_in"; static NSString *const kTextInStreamName = @"text_in";
@ -50,4 +61,18 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T
return [self initWithOptions:options error:error]; return [self initWithOptions:options error:error];
} }
- (MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error {
Packet packet = [MPPPacketCreator createWithText:text];
absl::StatusOr<PacketMap> output_packet_map =
cppTaskRunner->Process({{kTextInStreamName.cppString, packet}});
if (![MPPCommonUtils checkCppError:output_packet_map.status() toError:error]) {
return nil;
}
return [MPPClassificationResult
classificationResultWithProto:output_packet_map.value()[kClassificationsStreamName.cppString]
.Get<ClassificationResult>()];
}
@end @end