Added APIs for text classification
This commit is contained in:
parent
683c2b1f09
commit
8d9c1b8a0f
|
@ -23,6 +23,9 @@ objc_library(
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
"//mediapipe/tasks/ios/common:MPPCommon",
|
"//mediapipe/tasks/ios/common:MPPCommon",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/strings:cord",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -13,8 +13,16 @@
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
||||||
|
|
||||||
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
|
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/status/status.h" // from @com_google_absl
|
||||||
|
#include "absl/strings/cord.h" // from @com_google_absl
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
|
|
||||||
/** Error domain of TensorFlow Lite Support related errors. */
|
/** Error domain of TensorFlow Lite Support related errors. */
|
||||||
NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks";
|
NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks";
|
||||||
|
|
||||||
|
@ -60,8 +68,8 @@ NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks";
|
||||||
if (status.ok()) {
|
if (status.ok()) {
|
||||||
return YES;
|
return YES;
|
||||||
}
|
}
|
||||||
// Payload of absl::Status created by the Media Pipe task library stores an appropriate value of the
|
// Payload of absl::Status created by the Media Pipe task library stores an appropriate value of
|
||||||
// enum MPPiteSupportStatus. The integer value corresponding to the MPPiteSupportStatus enum
|
// the enum MPPiteSupportStatus. The integer value corresponding to the MPPiteSupportStatus enum
|
||||||
// stored in the payload is extracted here to later map to the appropriate error code to be
|
// stored in the payload is extracted here to later map to the appropriate error code to be
|
||||||
// returned. In cases where the enum is not stored in (payload is NULL or the payload string
|
// returned. In cases where the enum is not stored in (payload is NULL or the payload string
|
||||||
// cannot be converted to an integer), we set the error code value to be 1
|
// cannot be converted to an integer), we set the error code value to be 1
|
||||||
|
|
|
@ -22,6 +22,8 @@ NS_ASSUME_NONNULL_BEGIN
|
||||||
|
|
||||||
@property(readonly) std::string cppString;
|
@property(readonly) std::string cppString;
|
||||||
|
|
||||||
|
+ (NSString *)stringWithCppString:(std::string)text;
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
||||||
NS_ASSUME_NONNULL_END
|
NS_ASSUME_NONNULL_END
|
||||||
|
|
|
@ -20,4 +20,8 @@
|
||||||
return std::string(self.UTF8String, [self lengthOfBytesUsingEncoding:NSUTF8StringEncoding]);
|
return std::string(self.UTF8String, [self lengthOfBytesUsingEncoding:NSUTF8StringEncoding]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
+ (NSString *)stringWithCppString:(std::string)text {
|
||||||
|
return [NSString stringWithCString:text.c_str() encoding:[NSString defaultCStringEncoding]];
|
||||||
|
}
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
|
@ -28,5 +28,6 @@ objc_library(
|
||||||
hdrs = ["sources/MPPClassificationResult.h"],
|
hdrs = ["sources/MPPClassificationResult.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":MPPCategory",
|
":MPPCategory",
|
||||||
|
"//mediapipe/tasks/ios/core:MPPTaskResult",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,7 +18,7 @@ NS_ASSUME_NONNULL_BEGIN
|
||||||
|
|
||||||
/** Encapsulates information about a class in the classification results. */
|
/** Encapsulates information about a class in the classification results. */
|
||||||
NS_SWIFT_NAME(ClassificationCategory)
|
NS_SWIFT_NAME(ClassificationCategory)
|
||||||
@interface TFLCategory : NSObject
|
@interface MPPCategory : NSObject
|
||||||
|
|
||||||
/** Index of the class in the corresponding label map, usually packed in the TFLite Model
|
/** Index of the class in the corresponding label map, usually packed in the TFLite Model
|
||||||
* Metadata. */
|
* Metadata. */
|
||||||
|
|
|
@ -12,9 +12,9 @@
|
||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#import "mediapipe/tasks/ios/components/containers/sources/TFLCategory.h"
|
#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h"
|
||||||
|
|
||||||
@implementation TFLCategory
|
@implementation MPPCategory
|
||||||
|
|
||||||
- (instancetype)initWithIndex:(NSInteger)index
|
- (instancetype)initWithIndex:(NSInteger)index
|
||||||
score:(float)score
|
score:(float)score
|
||||||
|
|
|
@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#import <Foundation/Foundation.h>
|
#import <Foundation/Foundation.h>
|
||||||
#import "mediapipe/tasks/ios/task/components/containers/sources/MPPCategory.h"
|
#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h"
|
||||||
|
#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h"
|
||||||
|
|
||||||
NS_ASSUME_NONNULL_BEGIN
|
NS_ASSUME_NONNULL_BEGIN
|
||||||
|
|
||||||
|
@ -71,7 +72,7 @@ NS_SWIFT_NAME(Classifications)
|
||||||
|
|
||||||
/** Encapsulates results of any classification task. */
|
/** Encapsulates results of any classification task. */
|
||||||
NS_SWIFT_NAME(ClassificationResult)
|
NS_SWIFT_NAME(ClassificationResult)
|
||||||
@interface MPPClassificationResult : NSObject
|
@interface MPPClassificationResult : MPPTaskResult
|
||||||
|
|
||||||
/** Array of MPPClassifications objects containing classifier predictions per image classifier
|
/** Array of MPPClassifications objects containing classifier predictions per image classifier
|
||||||
* head.
|
* head.
|
||||||
|
@ -87,7 +88,8 @@ NS_SWIFT_NAME(ClassificationResult)
|
||||||
* @return An instance of MPPClassificationResult initialized with the given array of
|
* @return An instance of MPPClassificationResult initialized with the given array of
|
||||||
* classifications.
|
* classifications.
|
||||||
*/
|
*/
|
||||||
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications;
|
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
|
||||||
|
timeStamp:(long)timeStamp;
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
||||||
|
|
|
@ -39,8 +39,9 @@ limitations under the License.
|
||||||
NSArray<MPPClassifications *> *_classifications;
|
NSArray<MPPClassifications *> *_classifications;
|
||||||
}
|
}
|
||||||
|
|
||||||
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications {
|
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
|
||||||
self = [super init];
|
timeStamp:(long)timeStamp {
|
||||||
|
self = [super initWithTimeStamp:timeStamp];
|
||||||
if (self) {
|
if (self) {
|
||||||
_classifications = classifications;
|
_classifications = classifications;
|
||||||
}
|
}
|
||||||
|
|
40
mediapipe/tasks/ios/components/containers/utils/BUILD
Normal file
40
mediapipe/tasks/ios/components/containers/utils/BUILD
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "MPPCategoryHelpers",
|
||||||
|
srcs = ["sources/MPPCategory+Helpers.mm"],
|
||||||
|
hdrs = ["sources/MPPCategory+Helpers.h"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/ios/components/containers:MPPCategory",
|
||||||
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
|
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "MPPClassificationResultHelpers",
|
||||||
|
srcs = ["sources/MPPClassificationResult+Helpers.mm"],
|
||||||
|
hdrs = ["sources/MPPClassificationResult+Helpers.h"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/ios/components/containers:MPPClassificationResult",
|
||||||
|
":MPPCategoryHelpers",
|
||||||
|
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,26 @@
|
||||||
|
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
|
#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h"
|
||||||
|
|
||||||
|
NS_ASSUME_NONNULL_BEGIN
|
||||||
|
|
||||||
|
@interface MPPCategory (Helpers)
|
||||||
|
|
||||||
|
+ (MPPCategory *)categoryWithProto:(const mediapipe::Classification &)classificationProto;
|
||||||
|
|
||||||
|
@end
|
||||||
|
|
||||||
|
NS_ASSUME_NONNULL_END
|
|
@ -0,0 +1,42 @@
|
||||||
|
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
||||||
|
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
using ClassificationProto = ::mediapipe::Classification;
|
||||||
|
}
|
||||||
|
|
||||||
|
@implementation MPPCategory (Helpers)
|
||||||
|
|
||||||
|
+ (MPPCategory *)categoryWithProto:(const ClassificationProto &)clasificationProto {
|
||||||
|
NSString *label;
|
||||||
|
NSString *displayName;
|
||||||
|
|
||||||
|
if (clasificationProto.has_label()) {
|
||||||
|
label = [NSString stringWithCppString:clasificationProto.label()];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (clasificationProto.has_display_name()) {
|
||||||
|
displayName = [NSString stringWithCppString:clasificationProto.display_name()];
|
||||||
|
}
|
||||||
|
|
||||||
|
return [[MPPCategory alloc] initWithIndex:clasificationProto.index()
|
||||||
|
score:clasificationProto.score()
|
||||||
|
label:label
|
||||||
|
displayName:displayName];
|
||||||
|
}
|
||||||
|
|
||||||
|
@end
|
|
@ -0,0 +1,35 @@
|
||||||
|
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
|
#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h"
|
||||||
|
|
||||||
|
NS_ASSUME_NONNULL_BEGIN
|
||||||
|
|
||||||
|
@interface MPPClassifications (Helpers)
|
||||||
|
|
||||||
|
+ (MPPClassifications *)classificationsWithProto:
|
||||||
|
(const mediapipe::tasks::components::containers::proto::Classifications &)classificationsProto;
|
||||||
|
|
||||||
|
@end
|
||||||
|
|
||||||
|
@interface MPPClassificationResult (Helpers)
|
||||||
|
|
||||||
|
+ (MPPClassificationResult *)classificationResultWithProto:
|
||||||
|
(const mediapipe::tasks::components::containers::proto::ClassificationResult &)
|
||||||
|
classificationResultProto;
|
||||||
|
|
||||||
|
@end
|
||||||
|
|
||||||
|
NS_ASSUME_NONNULL_END
|
|
@ -0,0 +1,66 @@
|
||||||
|
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
||||||
|
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h"
|
||||||
|
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
using ClassificationsProto = ::mediapipe::tasks::components::containers::proto::Classifications;
|
||||||
|
using ClassificationResultProto =
|
||||||
|
::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
@implementation MPPClassifications (Helpers)
|
||||||
|
|
||||||
|
+ (MPPClassifications *)classificationsWithProto:
|
||||||
|
(const ClassificationsProto &)classificationsProto {
|
||||||
|
NSMutableArray *categories = [[NSMutableArray alloc] init];
|
||||||
|
for (const auto &classification : classificationsProto.classification_list().classification()) {
|
||||||
|
[categories addObject:[MPPCategory categoryWithProto:classification]];
|
||||||
|
}
|
||||||
|
|
||||||
|
NSString *headName;
|
||||||
|
|
||||||
|
if (classificationsProto.has_head_name()) {
|
||||||
|
headName = [NSString stringWithCppString:classificationsProto.head_name()];
|
||||||
|
}
|
||||||
|
|
||||||
|
return [[MPPClassifications alloc] initWithHeadIndex:(NSInteger)classificationsProto.head_index()
|
||||||
|
headName:headName
|
||||||
|
categories:categories];
|
||||||
|
}
|
||||||
|
|
||||||
|
@end
|
||||||
|
|
||||||
|
@implementation MPPClassificationResult (Helpers)
|
||||||
|
|
||||||
|
+ (MPPClassificationResult *)classificationResultWithProto:
|
||||||
|
(const ClassificationResultProto &)classificationResultProto {
|
||||||
|
NSMutableArray *classifications = [[NSMutableArray alloc] init];
|
||||||
|
for (const auto &classifications_proto : classificationResultProto.classifications()) {
|
||||||
|
[classifications addObject:[MPPClassifications classificationsWithProto:classifications_proto]];
|
||||||
|
}
|
||||||
|
|
||||||
|
long timeStamp;
|
||||||
|
|
||||||
|
if (classificationResultProto.has_timestamp_ms()) {
|
||||||
|
timeStamp = classificationResultProto.timestamp_ms();
|
||||||
|
}
|
||||||
|
|
||||||
|
return [[MPPClassificationResult alloc] initWithClassifications:classifications
|
||||||
|
timeStamp:timeStamp];
|
||||||
|
}
|
||||||
|
|
||||||
|
@end
|
|
@ -65,3 +65,23 @@ objc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "MPPPacketCreator",
|
||||||
|
srcs = ["sources/MPPPacketCreator.mm"],
|
||||||
|
hdrs = ["sources/MPPPacketCreator.h"],
|
||||||
|
copts = [
|
||||||
|
"-ObjC++",
|
||||||
|
"-std=c++17",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:packet",
|
||||||
|
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "MPPTaskResult",
|
||||||
|
srcs = ["sources/MPPTaskResult.m"],
|
||||||
|
hdrs = ["sources/MPPTaskResult.h"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
29
mediapipe/tasks/ios/core/sources/MPPPacketCreator.h
Normal file
29
mediapipe/tasks/ios/core/sources/MPPPacketCreator.h
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#import <Foundation/Foundation.h>
|
||||||
|
|
||||||
|
#ifndef __cplusplus
|
||||||
|
#error This header can only be included by an Objective-C++ file.
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "mediapipe/framework/packet.h"
|
||||||
|
|
||||||
|
/// This class is an Objective-C wrapper around a MediaPipe graph object, and
|
||||||
|
/// helps interface it with iOS technologies such as AVFoundation.
|
||||||
|
@interface MPPPacketCreator : NSObject
|
||||||
|
|
||||||
|
+ (mediapipe::Packet)createWithText:(NSString *)text;
|
||||||
|
|
||||||
|
@end
|
29
mediapipe/tasks/ios/core/sources/MPPPacketCreator.mm
Normal file
29
mediapipe/tasks/ios/core/sources/MPPPacketCreator.mm
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
// Copyright 2019 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h"
|
||||||
|
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
using ::mediapipe::MakePacket;
|
||||||
|
using ::mediapipe::Packet;
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
@implementation MPPPacketCreator
|
||||||
|
|
||||||
|
+ (Packet)createWithText:(NSString *)text {
|
||||||
|
return MakePacket<std::string>(text.cppString);
|
||||||
|
}
|
||||||
|
|
||||||
|
@end
|
31
mediapipe/tasks/ios/core/sources/MPPTaskResult.h
Normal file
31
mediapipe/tasks/ios/core/sources/MPPTaskResult.h
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#import <Foundation/Foundation.h>
|
||||||
|
|
||||||
|
NS_ASSUME_NONNULL_BEGIN
|
||||||
|
|
||||||
|
/**
|
||||||
|
* MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend
|
||||||
|
* this class.
|
||||||
|
*/
|
||||||
|
NS_SWIFT_NAME(TaskResult)
|
||||||
|
@interface MPPTaskResult : NSObject <NSCopying>
|
||||||
|
/**
|
||||||
|
* Base options for configuring the Mediapipe task.
|
||||||
|
*/
|
||||||
|
@property(nonatomic, assign, readonly) long timeStamp;
|
||||||
|
|
||||||
|
- (instancetype)initWithTimeStamp:(long)timeStamp;
|
||||||
|
|
||||||
|
@end
|
||||||
|
|
||||||
|
NS_ASSUME_NONNULL_END
|
27
mediapipe/tasks/ios/core/sources/MPPTaskResult.m
Normal file
27
mediapipe/tasks/ios/core/sources/MPPTaskResult.m
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h"
|
||||||
|
|
||||||
|
@implementation MPPTaskResult
|
||||||
|
|
||||||
|
- (instancetype)initWithTimeStamp:(long)timeStamp {
|
||||||
|
self = [self init];
|
||||||
|
if (self) {
|
||||||
|
_timeStamp = timeStamp;
|
||||||
|
}
|
||||||
|
return self;
|
||||||
|
}
|
||||||
|
|
||||||
|
@end
|
|
@ -15,6 +15,7 @@
|
||||||
#import <Foundation/Foundation.h>
|
#import <Foundation/Foundation.h>
|
||||||
|
|
||||||
#include "mediapipe/framework/calculator.pb.h"
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||||
|
|
||||||
NS_ASSUME_NONNULL_BEGIN
|
NS_ASSUME_NONNULL_BEGIN
|
||||||
|
|
||||||
|
@ -22,7 +23,10 @@ NS_ASSUME_NONNULL_BEGIN
|
||||||
* The base class of the user-facing iOS mediapipe text task api classes.
|
* The base class of the user-facing iOS mediapipe text task api classes.
|
||||||
*/
|
*/
|
||||||
NS_SWIFT_NAME(BaseTextTaskApi)
|
NS_SWIFT_NAME(BaseTextTaskApi)
|
||||||
@interface MPPBaseTextTaskApi : NSObject
|
@interface MPPBaseTextTaskApi : NSObject {
|
||||||
|
@protected
|
||||||
|
std::unique_ptr<mediapipe::tasks::core::TaskRunner> cppTaskRunner;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Initializes a new `MPPBaseTextTaskApi` with the mediapipe text task graph config proto.
|
* Initializes a new `MPPBaseTextTaskApi` with the mediapipe text task graph config proto.
|
||||||
|
|
|
@ -13,18 +13,18 @@
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#import "mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h"
|
#import "mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h"
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
|
||||||
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
using ::mediapipe::CalculatorGraphConfig;
|
using ::mediapipe::CalculatorGraphConfig;
|
||||||
|
using ::mediapipe::Packet;
|
||||||
|
using ::mediapipe::tasks::core::PacketMap;
|
||||||
using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
|
using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
@interface MPPBaseTextTaskApi () {
|
@interface MPPBaseTextTaskApi () {
|
||||||
/** TextSearcher backed by C++ API */
|
/** TextSearcher backed by C++ API */
|
||||||
std::unique_ptr<TaskRunnerCpp> _taskRunner;
|
std::unique_ptr<TaskRunnerCpp> _cppTaskRunner;
|
||||||
}
|
}
|
||||||
@end
|
@end
|
||||||
|
|
||||||
|
@ -40,13 +40,13 @@ using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
|
||||||
return nil;
|
return nil;
|
||||||
}
|
}
|
||||||
|
|
||||||
_taskRunner = std::move(taskRunnerResult.value());
|
_cppTaskRunner = std::move(taskRunnerResult.value());
|
||||||
}
|
}
|
||||||
return self;
|
return self;
|
||||||
}
|
}
|
||||||
|
|
||||||
- (void)close {
|
- (void)close {
|
||||||
_taskRunner->Close();
|
_cppTaskRunner->Close();
|
||||||
}
|
}
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
33
mediapipe/tasks/ios/text/core/utils/BUILD
Normal file
33
mediapipe/tasks/ios/text/core/utils/BUILD
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "MPPBaseTextTaskApi",
|
||||||
|
srcs = ["sources/MPPBaseTextTaskApi.mm"],
|
||||||
|
hdrs = ["sources/MPPBaseTextTaskApi.h"],
|
||||||
|
copts = [
|
||||||
|
"-ObjC++",
|
||||||
|
"-std=c++17",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/cc/core:task_runner",
|
||||||
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
|
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -28,8 +28,11 @@ objc_library(
|
||||||
"//mediapipe/tasks/ios/core:MPPTaskOptions",
|
"//mediapipe/tasks/ios/core:MPPTaskOptions",
|
||||||
"//mediapipe/tasks/ios/core:MPPTaskInfo",
|
"//mediapipe/tasks/ios/core:MPPTaskInfo",
|
||||||
"//mediapipe/tasks/ios/text/core:MPPBaseTextTaskApi",
|
"//mediapipe/tasks/ios/text/core:MPPBaseTextTaskApi",
|
||||||
|
"//mediapipe/tasks/ios/core:MPPPacketCreator",
|
||||||
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers",
|
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers",
|
||||||
|
"//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers",
|
||||||
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
||||||
|
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||||
":MPPTextClassifierOptions",
|
":MPPTextClassifierOptions",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#import <Foundation/Foundation.h>
|
#import <Foundation/Foundation.h>
|
||||||
|
|
||||||
|
#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h"
|
||||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
|
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
|
||||||
#import "mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h"
|
#import "mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h"
|
||||||
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h"
|
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h"
|
||||||
|
@ -52,6 +53,8 @@ NS_SWIFT_NAME(TextClassifier)
|
||||||
*/
|
*/
|
||||||
- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error;
|
- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error;
|
||||||
|
|
||||||
|
- (MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error;
|
||||||
|
|
||||||
- (instancetype)init NS_UNAVAILABLE;
|
- (instancetype)init NS_UNAVAILABLE;
|
||||||
|
|
||||||
+ (instancetype)new NS_UNAVAILABLE;
|
+ (instancetype)new NS_UNAVAILABLE;
|
||||||
|
|
|
@ -15,9 +15,20 @@
|
||||||
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h"
|
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h"
|
||||||
|
|
||||||
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
||||||
|
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
||||||
|
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
|
||||||
|
#import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h"
|
||||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
|
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
|
||||||
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h"
|
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h"
|
||||||
|
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
using ::mediapipe::Packet;
|
||||||
|
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
|
using ::mediapipe::tasks::core::PacketMap;
|
||||||
|
} // namespace
|
||||||
|
|
||||||
static NSString *const kClassificationsStreamName = @"classifications_out";
|
static NSString *const kClassificationsStreamName = @"classifications_out";
|
||||||
static NSString *const kClassificationsTag = @"classifications";
|
static NSString *const kClassificationsTag = @"classifications";
|
||||||
static NSString *const kTextInStreamName = @"text_in";
|
static NSString *const kTextInStreamName = @"text_in";
|
||||||
|
@ -50,4 +61,18 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T
|
||||||
return [self initWithOptions:options error:error];
|
return [self initWithOptions:options error:error];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
- (MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error {
|
||||||
|
Packet packet = [MPPPacketCreator createWithText:text];
|
||||||
|
absl::StatusOr<PacketMap> output_packet_map =
|
||||||
|
cppTaskRunner->Process({{kTextInStreamName.cppString, packet}});
|
||||||
|
|
||||||
|
if (![MPPCommonUtils checkCppError:output_packet_map.status() toError:error]) {
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
|
||||||
|
return [MPPClassificationResult
|
||||||
|
classificationResultWithProto:output_packet_map.value()[kClassificationsStreamName.cppString]
|
||||||
|
.Get<ClassificationResult>()];
|
||||||
|
}
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
Loading…
Reference in New Issue
Block a user