Added iOS task manager
This commit is contained in:
parent
dea0e21aec
commit
96247ccce4
|
@ -42,6 +42,7 @@ objc_library(
|
|||
"//mediapipe/framework:calculator_options_cc_proto",
|
||||
"//mediapipe/calculators/core:flow_limiter_calculator_cc_proto",
|
||||
":MPPTaskOptions",
|
||||
":MPPTaskOptionsProtocol",
|
||||
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
||||
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||
"//mediapipe/tasks/ios/common:MPPCommon",
|
||||
|
@ -81,3 +82,14 @@ objc_library(
|
|||
"//mediapipe/framework:calculator_options_cc_proto",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "MPPTaskManager",
|
||||
srcs = ["sources/MPPTaskManager.mm"],
|
||||
hdrs = ["sources/MPPTaskManager.h"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc/core:task_runner",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
#import <Foundation/Foundation.h>
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h"
|
||||
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
|
|
|
@ -37,7 +37,6 @@ using FlowLimiterCalculatorOptions = ::mediapipe::FlowLimiterCalculatorOptions;
|
|||
taskOptions:(id<MPPTaskOptionsProtocol>)taskOptions
|
||||
enableFlowLimiting:(BOOL)enableFlowLimiting
|
||||
error:(NSError **)error {
|
||||
self = [super init];
|
||||
if (!taskGraphName || !inputStreams.count || !outputStreams.count) {
|
||||
[MPPCommonUtils
|
||||
createCustomError:error
|
||||
|
@ -46,6 +45,8 @@ using FlowLimiterCalculatorOptions = ::mediapipe::FlowLimiterCalculatorOptions;
|
|||
@"Task graph's name, input streams, and output streams should be non-empty."];
|
||||
}
|
||||
|
||||
self = [super init];
|
||||
|
||||
if (self) {
|
||||
_taskGraphName = taskGraphName;
|
||||
_inputStreams = inputStreams;
|
||||
|
|
47
mediapipe/tasks/ios/core/sources/MPPTaskManager.h
Normal file
47
mediapipe/tasks/ios/core/sources/MPPTaskManager.h
Normal file
|
@ -0,0 +1,47 @@
|
|||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/**
|
||||
* The base class of the user-facing iOS mediapipe text task api classes.
|
||||
*/
|
||||
@interface MPPTaskManager : NSObject
|
||||
/**
|
||||
* Initializes a new `MPPBaseTextTaskApi` with the mediapipe text task graph config proto.
|
||||
*
|
||||
* @param graphConfig A mediapipe text task graph config proto.
|
||||
*
|
||||
* @return An instance of `MPPBaseTextTaskApi` initialized to the given graph config proto.
|
||||
*/
|
||||
- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
|
||||
error:(NSError **)error;
|
||||
|
||||
- (absl::StatusOr<mediapipe::tasks::core::PacketMap>)process:(const mediapipe::tasks::core::PacketMap&)packetMap error:(NSError **)error;
|
||||
|
||||
- (void)close;
|
||||
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
+ (instancetype)new NS_UNAVAILABLE;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
56
mediapipe/tasks/ios/core/sources/MPPTaskManager.mm
Normal file
56
mediapipe/tasks/ios/core/sources/MPPTaskManager.mm
Normal file
|
@ -0,0 +1,56 @@
|
|||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskManager.h"
|
||||
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
||||
|
||||
namespace {
|
||||
using ::mediapipe::CalculatorGraphConfig;
|
||||
using ::mediapipe::Packet;
|
||||
using ::mediapipe::tasks::core::PacketMap;
|
||||
using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
|
||||
} // namespace
|
||||
|
||||
@interface MPPTaskManager () {
|
||||
/** TextSearcher backed by C++ API */
|
||||
std::unique_ptr<TaskRunnerCpp> _cppTaskRunner;
|
||||
}
|
||||
@end
|
||||
|
||||
@implementation MPPTaskManager
|
||||
|
||||
- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig
|
||||
error:(NSError **)error {
|
||||
self = [super init];
|
||||
if (self) {
|
||||
auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig));
|
||||
|
||||
if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) {
|
||||
return nil;
|
||||
}
|
||||
|
||||
_cppTaskRunner = std::move(taskRunnerResult.value());
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
- (absl::StatusOr<PacketMap>)process:(const PacketMap&)packetMap {
|
||||
return _cppTaskRunner->Process(packetMap);
|
||||
}
|
||||
|
||||
- (void)close {
|
||||
_cppTaskRunner->Close();
|
||||
}
|
||||
|
||||
@end
|
|
@ -28,7 +28,7 @@
|
|||
- (instancetype)initWithModelPath:(NSString *)modelPath {
|
||||
self = [self init];
|
||||
if (self) {
|
||||
_baseOptions.modelAssetFile.filePath = modelPath;
|
||||
_baseOptions.modelAssetPath = modelPath;
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
|
|
@ -21,8 +21,8 @@ using BaseOptionsProto = ::mediapipe::tasks::core::proto::BaseOptions;
|
|||
@implementation MPPBaseOptions (Helpers)
|
||||
|
||||
- (void)copyToProto:(BaseOptionsProto *)baseOptionsProto {
|
||||
if (self.modelAssetFile.filePath) {
|
||||
baseOptionsProto->mutable_model_asset()->set_file_name(self.modelAssetFile.filePath.UTF8String);
|
||||
if (self.modelAssetPath) {
|
||||
baseOptionsProto->mutable_model_asset()->set_file_name(self.modelAssetPath.UTF8String);
|
||||
}
|
||||
|
||||
switch (self.delegate) {
|
||||
|
|
|
@ -27,7 +27,7 @@ objc_library(
|
|||
deps = [
|
||||
"//mediapipe/tasks/ios/core:MPPTaskOptions",
|
||||
"//mediapipe/tasks/ios/core:MPPTaskInfo",
|
||||
"//mediapipe/tasks/ios/text/core:MPPBaseTextTaskApi",
|
||||
"//mediapipe/tasks/ios/core:MPPTaskManager",
|
||||
"//mediapipe/tasks/ios/core:MPPPacketCreator",
|
||||
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers",
|
||||
"//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers",
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
|
||||
#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
|
||||
#import "mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h"
|
||||
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
@ -25,7 +24,7 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
* A Mediapipe iOS Text Classifier.
|
||||
*/
|
||||
NS_SWIFT_NAME(TextClassifier)
|
||||
@interface MPPTextClassifier : MPPBaseTextTaskApi
|
||||
@interface MPPTextClassifier : NSObject
|
||||
|
||||
/**
|
||||
* Creates a new instance of `MPPTextClassifier` from an absolute path to a TensorFlow Lite model
|
||||
|
@ -53,7 +52,7 @@ NS_SWIFT_NAME(TextClassifier)
|
|||
*/
|
||||
- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error;
|
||||
|
||||
- (MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error;
|
||||
- (nullable MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error;
|
||||
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user