Added APIs for text classification
This commit is contained in:
parent
683c2b1f09
commit
8d9c1b8a0f
|
@ -23,6 +23,9 @@ objc_library(
|
|||
deps = [
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/ios/common:MPPCommon",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:cord",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -13,8 +13,16 @@
|
|||
// limitations under the License.
|
||||
|
||||
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
||||
|
||||
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/status/status.h" // from @com_google_absl
|
||||
#include "absl/strings/cord.h" // from @com_google_absl
|
||||
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
|
||||
/** Error domain of TensorFlow Lite Support related errors. */
|
||||
NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks";
|
||||
|
||||
|
@ -60,8 +68,8 @@ NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks";
|
|||
if (status.ok()) {
|
||||
return YES;
|
||||
}
|
||||
// Payload of absl::Status created by the Media Pipe task library stores an appropriate value of the
|
||||
// enum MPPiteSupportStatus. The integer value corresponding to the MPPiteSupportStatus enum
|
||||
// Payload of absl::Status created by the Media Pipe task library stores an appropriate value of
|
||||
// the enum MPPiteSupportStatus. The integer value corresponding to the MPPiteSupportStatus enum
|
||||
// stored in the payload is extracted here to later map to the appropriate error code to be
|
||||
// returned. In cases where the enum is not stored in (payload is NULL or the payload string
|
||||
// cannot be converted to an integer), we set the error code value to be 1
|
||||
|
|
|
@ -22,6 +22,8 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
|
||||
@property(readonly) std::string cppString;
|
||||
|
||||
+ (NSString *)stringWithCppString:(std::string)text;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
||||
|
|
|
@ -20,4 +20,8 @@
|
|||
return std::string(self.UTF8String, [self lengthOfBytesUsingEncoding:NSUTF8StringEncoding]);
|
||||
}
|
||||
|
||||
+ (NSString *)stringWithCppString:(std::string)text {
|
||||
return [NSString stringWithCString:text.c_str() encoding:[NSString defaultCStringEncoding]];
|
||||
}
|
||||
|
||||
@end
|
||||
|
|
|
@ -28,5 +28,6 @@ objc_library(
|
|||
hdrs = ["sources/MPPClassificationResult.h"],
|
||||
deps = [
|
||||
":MPPCategory",
|
||||
"//mediapipe/tasks/ios/core:MPPTaskResult",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -18,7 +18,7 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
|
||||
/** Encapsulates information about a class in the classification results. */
|
||||
NS_SWIFT_NAME(ClassificationCategory)
|
||||
@interface TFLCategory : NSObject
|
||||
@interface MPPCategory : NSObject
|
||||
|
||||
/** Index of the class in the corresponding label map, usually packed in the TFLite Model
|
||||
* Metadata. */
|
||||
|
|
|
@ -12,9 +12,9 @@
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#import "mediapipe/tasks/ios/components/containers/sources/TFLCategory.h"
|
||||
#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h"
|
||||
|
||||
@implementation TFLCategory
|
||||
@implementation MPPCategory
|
||||
|
||||
- (instancetype)initWithIndex:(NSInteger)index
|
||||
score:(float)score
|
||||
|
|
|
@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#import <Foundation/Foundation.h>
|
||||
#import "mediapipe/tasks/ios/task/components/containers/sources/MPPCategory.h"
|
||||
#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
|
@ -71,7 +72,7 @@ NS_SWIFT_NAME(Classifications)
|
|||
|
||||
/** Encapsulates results of any classification task. */
|
||||
NS_SWIFT_NAME(ClassificationResult)
|
||||
@interface MPPClassificationResult : NSObject
|
||||
@interface MPPClassificationResult : MPPTaskResult
|
||||
|
||||
/** Array of MPPClassifications objects containing classifier predictions per image classifier
|
||||
* head.
|
||||
|
@ -87,7 +88,8 @@ NS_SWIFT_NAME(ClassificationResult)
|
|||
* @return An instance of MPPClassificationResult initialized with the given array of
|
||||
* classifications.
|
||||
*/
|
||||
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications;
|
||||
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
|
||||
timeStamp:(long)timeStamp;
|
||||
|
||||
@end
|
||||
|
||||
|
|
|
@ -39,8 +39,9 @@ limitations under the License.
|
|||
NSArray<MPPClassifications *> *_classifications;
|
||||
}
|
||||
|
||||
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications {
|
||||
self = [super init];
|
||||
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
|
||||
timeStamp:(long)timeStamp {
|
||||
self = [super initWithTimeStamp:timeStamp];
|
||||
if (self) {
|
||||
_classifications = classifications;
|
||||
}
|
||||
|
|
40
mediapipe/tasks/ios/components/containers/utils/BUILD
Normal file
40
mediapipe/tasks/ios/components/containers/utils/BUILD
Normal file
|
@ -0,0 +1,40 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
objc_library(
|
||||
name = "MPPCategoryHelpers",
|
||||
srcs = ["sources/MPPCategory+Helpers.mm"],
|
||||
hdrs = ["sources/MPPCategory+Helpers.h"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/ios/components/containers:MPPCategory",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "MPPClassificationResultHelpers",
|
||||
srcs = ["sources/MPPClassificationResult+Helpers.mm"],
|
||||
hdrs = ["sources/MPPClassificationResult+Helpers.h"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/ios/components/containers:MPPClassificationResult",
|
||||
":MPPCategoryHelpers",
|
||||
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
],
|
||||
)
|
|
@ -0,0 +1,26 @@
|
|||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
@interface MPPCategory (Helpers)
|
||||
|
||||
+ (MPPCategory *)categoryWithProto:(const mediapipe::Classification &)classificationProto;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
|
@ -0,0 +1,42 @@
|
|||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
||||
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h"
|
||||
|
||||
namespace {
|
||||
using ClassificationProto = ::mediapipe::Classification;
|
||||
}
|
||||
|
||||
@implementation MPPCategory (Helpers)
|
||||
|
||||
+ (MPPCategory *)categoryWithProto:(const ClassificationProto &)clasificationProto {
|
||||
NSString *label;
|
||||
NSString *displayName;
|
||||
|
||||
if (clasificationProto.has_label()) {
|
||||
label = [NSString stringWithCppString:clasificationProto.label()];
|
||||
}
|
||||
|
||||
if (clasificationProto.has_display_name()) {
|
||||
displayName = [NSString stringWithCppString:clasificationProto.display_name()];
|
||||
}
|
||||
|
||||
return [[MPPCategory alloc] initWithIndex:clasificationProto.index()
|
||||
score:clasificationProto.score()
|
||||
label:label
|
||||
displayName:displayName];
|
||||
}
|
||||
|
||||
@end
|
|
@ -0,0 +1,35 @@
|
|||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
@interface MPPClassifications (Helpers)
|
||||
|
||||
+ (MPPClassifications *)classificationsWithProto:
|
||||
(const mediapipe::tasks::components::containers::proto::Classifications &)classificationsProto;
|
||||
|
||||
@end
|
||||
|
||||
@interface MPPClassificationResult (Helpers)
|
||||
|
||||
+ (MPPClassificationResult *)classificationResultWithProto:
|
||||
(const mediapipe::tasks::components::containers::proto::ClassificationResult &)
|
||||
classificationResultProto;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
|
@ -0,0 +1,66 @@
|
|||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
||||
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h"
|
||||
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
|
||||
|
||||
namespace {
|
||||
using ClassificationsProto = ::mediapipe::tasks::components::containers::proto::Classifications;
|
||||
using ClassificationResultProto =
|
||||
::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
} // namespace
|
||||
|
||||
@implementation MPPClassifications (Helpers)
|
||||
|
||||
+ (MPPClassifications *)classificationsWithProto:
|
||||
(const ClassificationsProto &)classificationsProto {
|
||||
NSMutableArray *categories = [[NSMutableArray alloc] init];
|
||||
for (const auto &classification : classificationsProto.classification_list().classification()) {
|
||||
[categories addObject:[MPPCategory categoryWithProto:classification]];
|
||||
}
|
||||
|
||||
NSString *headName;
|
||||
|
||||
if (classificationsProto.has_head_name()) {
|
||||
headName = [NSString stringWithCppString:classificationsProto.head_name()];
|
||||
}
|
||||
|
||||
return [[MPPClassifications alloc] initWithHeadIndex:(NSInteger)classificationsProto.head_index()
|
||||
headName:headName
|
||||
categories:categories];
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
@implementation MPPClassificationResult (Helpers)
|
||||
|
||||
+ (MPPClassificationResult *)classificationResultWithProto:
|
||||
(const ClassificationResultProto &)classificationResultProto {
|
||||
NSMutableArray *classifications = [[NSMutableArray alloc] init];
|
||||
for (const auto &classifications_proto : classificationResultProto.classifications()) {
|
||||
[classifications addObject:[MPPClassifications classificationsWithProto:classifications_proto]];
|
||||
}
|
||||
|
||||
long timeStamp;
|
||||
|
||||
if (classificationResultProto.has_timestamp_ms()) {
|
||||
timeStamp = classificationResultProto.timestamp_ms();
|
||||
}
|
||||
|
||||
return [[MPPClassificationResult alloc] initWithClassifications:classifications
|
||||
timeStamp:timeStamp];
|
||||
}
|
||||
|
||||
@end
|
|
@ -65,3 +65,23 @@ objc_library(
|
|||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "MPPPacketCreator",
|
||||
srcs = ["sources/MPPPacketCreator.mm"],
|
||||
hdrs = ["sources/MPPPacketCreator.h"],
|
||||
copts = [
|
||||
"-ObjC++",
|
||||
"-std=c++17",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:packet",
|
||||
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "MPPTaskResult",
|
||||
srcs = ["sources/MPPTaskResult.m"],
|
||||
hdrs = ["sources/MPPTaskResult.h"],
|
||||
)
|
||||
|
||||
|
|
29
mediapipe/tasks/ios/core/sources/MPPPacketCreator.h
Normal file
29
mediapipe/tasks/ios/core/sources/MPPPacketCreator.h
Normal file
|
@ -0,0 +1,29 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
#ifndef __cplusplus
|
||||
#error This header can only be included by an Objective-C++ file.
|
||||
#endif
|
||||
|
||||
#include "mediapipe/framework/packet.h"
|
||||
|
||||
/// This class is an Objective-C wrapper around a MediaPipe graph object, and
|
||||
/// helps interface it with iOS technologies such as AVFoundation.
|
||||
@interface MPPPacketCreator : NSObject
|
||||
|
||||
+ (mediapipe::Packet)createWithText:(NSString *)text;
|
||||
|
||||
@end
|
29
mediapipe/tasks/ios/core/sources/MPPPacketCreator.mm
Normal file
29
mediapipe/tasks/ios/core/sources/MPPPacketCreator.mm
Normal file
|
@ -0,0 +1,29 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h"
|
||||
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
||||
|
||||
namespace {
|
||||
using ::mediapipe::MakePacket;
|
||||
using ::mediapipe::Packet;
|
||||
} // namespace
|
||||
|
||||
@implementation MPPPacketCreator
|
||||
|
||||
+ (Packet)createWithText:(NSString *)text {
|
||||
return MakePacket<std::string>(text.cppString);
|
||||
}
|
||||
|
||||
@end
|
31
mediapipe/tasks/ios/core/sources/MPPTaskResult.h
Normal file
31
mediapipe/tasks/ios/core/sources/MPPTaskResult.h
Normal file
|
@ -0,0 +1,31 @@
|
|||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/**
|
||||
* MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend
|
||||
* this class.
|
||||
*/
|
||||
NS_SWIFT_NAME(TaskResult)
|
||||
@interface MPPTaskResult : NSObject <NSCopying>
|
||||
/**
|
||||
* Base options for configuring the Mediapipe task.
|
||||
*/
|
||||
@property(nonatomic, assign, readonly) long timeStamp;
|
||||
|
||||
- (instancetype)initWithTimeStamp:(long)timeStamp;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
27
mediapipe/tasks/ios/core/sources/MPPTaskResult.m
Normal file
27
mediapipe/tasks/ios/core/sources/MPPTaskResult.m
Normal file
|
@ -0,0 +1,27 @@
|
|||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h"
|
||||
|
||||
@implementation MPPTaskResult
|
||||
|
||||
- (instancetype)initWithTimeStamp:(long)timeStamp {
|
||||
self = [self init];
|
||||
if (self) {
|
||||
_timeStamp = timeStamp;
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
@end
|
|
@ -15,6 +15,7 @@
|
|||
#import <Foundation/Foundation.h>
|
||||
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
|
@ -22,7 +23,10 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
* The base class of the user-facing iOS mediapipe text task api classes.
|
||||
*/
|
||||
NS_SWIFT_NAME(BaseTextTaskApi)
|
||||
@interface MPPBaseTextTaskApi : NSObject
|
||||
@interface MPPBaseTextTaskApi : NSObject {
|
||||
@protected
|
||||
std::unique_ptr<mediapipe::tasks::core::TaskRunner> cppTaskRunner;
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a new `MPPBaseTextTaskApi` with the mediapipe text task graph config proto.
|
||||
|
|
|
@ -13,18 +13,18 @@
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#import "mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h"
|
||||
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
||||
|
||||
namespace {
|
||||
using ::mediapipe::CalculatorGraphConfig;
|
||||
using ::mediapipe::Packet;
|
||||
using ::mediapipe::tasks::core::PacketMap;
|
||||
using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
|
||||
} // namespace
|
||||
|
||||
@interface MPPBaseTextTaskApi () {
|
||||
/** TextSearcher backed by C++ API */
|
||||
std::unique_ptr<TaskRunnerCpp> _taskRunner;
|
||||
std::unique_ptr<TaskRunnerCpp> _cppTaskRunner;
|
||||
}
|
||||
@end
|
||||
|
||||
|
@ -40,13 +40,13 @@ using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
|
|||
return nil;
|
||||
}
|
||||
|
||||
_taskRunner = std::move(taskRunnerResult.value());
|
||||
_cppTaskRunner = std::move(taskRunnerResult.value());
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
- (void)close {
|
||||
_taskRunner->Close();
|
||||
_cppTaskRunner->Close();
|
||||
}
|
||||
|
||||
@end
|
||||
|
|
33
mediapipe/tasks/ios/text/core/utils/BUILD
Normal file
33
mediapipe/tasks/ios/text/core/utils/BUILD
Normal file
|
@ -0,0 +1,33 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
objc_library(
|
||||
name = "MPPBaseTextTaskApi",
|
||||
srcs = ["sources/MPPBaseTextTaskApi.mm"],
|
||||
hdrs = ["sources/MPPBaseTextTaskApi.h"],
|
||||
copts = [
|
||||
"-ObjC++",
|
||||
"-std=c++17",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc/core:task_runner",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
||||
],
|
||||
)
|
||||
|
|
@ -28,8 +28,11 @@ objc_library(
|
|||
"//mediapipe/tasks/ios/core:MPPTaskOptions",
|
||||
"//mediapipe/tasks/ios/core:MPPTaskInfo",
|
||||
"//mediapipe/tasks/ios/text/core:MPPBaseTextTaskApi",
|
||||
"//mediapipe/tasks/ios/core:MPPPacketCreator",
|
||||
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers",
|
||||
"//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers",
|
||||
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
||||
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||
":MPPTextClassifierOptions",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
==============================================================================*/
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
|
||||
#import "mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h"
|
||||
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h"
|
||||
|
@ -52,6 +53,8 @@ NS_SWIFT_NAME(TextClassifier)
|
|||
*/
|
||||
- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error;
|
||||
|
||||
- (MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error;
|
||||
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
+ (instancetype)new NS_UNAVAILABLE;
|
||||
|
|
|
@ -15,9 +15,20 @@
|
|||
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h"
|
||||
|
||||
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
||||
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
||||
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
|
||||
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h"
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
|
||||
namespace {
|
||||
using ::mediapipe::Packet;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::mediapipe::tasks::core::PacketMap;
|
||||
} // namespace
|
||||
|
||||
static NSString *const kClassificationsStreamName = @"classifications_out";
|
||||
static NSString *const kClassificationsTag = @"classifications";
|
||||
static NSString *const kTextInStreamName = @"text_in";
|
||||
|
@ -50,4 +61,18 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T
|
|||
return [self initWithOptions:options error:error];
|
||||
}
|
||||
|
||||
- (MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error {
|
||||
Packet packet = [MPPPacketCreator createWithText:text];
|
||||
absl::StatusOr<PacketMap> output_packet_map =
|
||||
cppTaskRunner->Process({{kTextInStreamName.cppString, packet}});
|
||||
|
||||
if (![MPPCommonUtils checkCppError:output_packet_map.status() toError:error]) {
|
||||
return nil;
|
||||
}
|
||||
|
||||
return [MPPClassificationResult
|
||||
classificationResultWithProto:output_packet_map.value()[kClassificationsStreamName.cppString]
|
||||
.Get<ClassificationResult>()];
|
||||
}
|
||||
|
||||
@end
|
||||
|
|
Loading…
Reference in New Issue
Block a user