From 96247ccce484afad99d01a0434f818313ac7102d Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 14 Dec 2022 18:59:26 +0530 Subject: [PATCH] Added iOS task manager --- mediapipe/tasks/ios/core/BUILD | 12 ++++ .../tasks/ios/core/sources/MPPTaskInfo.h | 2 + .../tasks/ios/core/sources/MPPTaskInfo.mm | 3 +- .../tasks/ios/core/sources/MPPTaskManager.h | 47 ++++++++++++++++ .../tasks/ios/core/sources/MPPTaskManager.mm | 56 +++++++++++++++++++ .../tasks/ios/core/sources/MPPTaskOptions.m | 2 +- .../utils/sources/MPPBaseOptions+Helpers.mm | 4 +- .../tasks/ios/text/text_classifier/BUILD | 2 +- .../sources/MPPTextClassifier.h | 5 +- 9 files changed, 125 insertions(+), 8 deletions(-) create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskManager.h create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskManager.mm diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD index 73fcacc37..666b0e6e1 100644 --- a/mediapipe/tasks/ios/core/BUILD +++ b/mediapipe/tasks/ios/core/BUILD @@ -42,6 +42,7 @@ objc_library( "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", ":MPPTaskOptions", + ":MPPTaskOptionsProtocol", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", "//mediapipe/tasks/ios/common/utils:NSStringHelpers", "//mediapipe/tasks/ios/common:MPPCommon", @@ -81,3 +82,14 @@ objc_library( "//mediapipe/framework:calculator_options_cc_proto", ], ) + +objc_library( + name = "MPPTaskManager", + srcs = ["sources/MPPTaskManager.mm"], + hdrs = ["sources/MPPTaskManager.h"], + deps = [ + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + ], +) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h index 620184518..a6ba4c4bd 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h @@ -12,6 +12,8 @@ #import #include "mediapipe/framework/calculator.pb.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" + NS_ASSUME_NONNULL_BEGIN diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm index 7e42d6eae..ed8e814d2 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm @@ -37,7 +37,6 @@ using FlowLimiterCalculatorOptions = ::mediapipe::FlowLimiterCalculatorOptions; taskOptions:(id)taskOptions enableFlowLimiting:(BOOL)enableFlowLimiting error:(NSError **)error { - self = [super init]; if (!taskGraphName || !inputStreams.count || !outputStreams.count) { [MPPCommonUtils createCustomError:error @@ -46,6 +45,8 @@ using FlowLimiterCalculatorOptions = ::mediapipe::FlowLimiterCalculatorOptions; @"Task graph's name, input streams, and output streams should be non-empty."]; } + self = [super init]; + if (self) { _taskGraphName = taskGraphName; _inputStreams = inputStreams; diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskManager.h b/mediapipe/tasks/ios/core/sources/MPPTaskManager.h new file mode 100644 index 000000000..b4ba02edd --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskManager.h @@ -0,0 +1,47 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import + +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" + + +NS_ASSUME_NONNULL_BEGIN + +/** + * The base class of the user-facing iOS mediapipe text task api classes. + */ +@interface MPPTaskManager : NSObject +/** + * Initializes a new `MPPBaseTextTaskApi` with the mediapipe text task graph config proto. + * + * @param graphConfig A mediapipe text task graph config proto. + * + * @return An instance of `MPPBaseTextTaskApi` initialized to the given graph config proto. + */ +- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig + error:(NSError **)error; + +- (absl::StatusOr)process:(const mediapipe::tasks::core::PacketMap&)packetMap error:(NSError **)error; + +- (void)close; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskManager.mm b/mediapipe/tasks/ios/core/sources/MPPTaskManager.mm new file mode 100644 index 000000000..2bf23d428 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskManager.mm @@ -0,0 +1,56 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "mediapipe/tasks/ios/core/sources/MPPTaskManager.h" +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" + +namespace { +using ::mediapipe::CalculatorGraphConfig; +using ::mediapipe::Packet; +using ::mediapipe::tasks::core::PacketMap; +using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; +} // namespace + +@interface MPPTaskManager () { + /** TextSearcher backed by C++ API */ + std::unique_ptr _cppTaskRunner; +} +@end + +@implementation MPPTaskManager + +- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig + error:(NSError **)error { + self = [super init]; + if (self) { + auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig)); + + if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) { + return nil; + } + + _cppTaskRunner = std::move(taskRunnerResult.value()); + } + return self; +} + +- (absl::StatusOr)process:(const PacketMap&)packetMap { + return _cppTaskRunner->Process(packetMap); +} + +- (void)close { + _cppTaskRunner->Close(); +} + +@end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m index ec1adbaf1..f71d275be 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m @@ -28,7 +28,7 @@ - (instancetype)initWithModelPath:(NSString *)modelPath { self = [self init]; if (self) { - _baseOptions.modelAssetFile.filePath = modelPath; + _baseOptions.modelAssetPath = modelPath; } return self; } diff --git a/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm index f20f8602a..9fce15dfa 100644 --- a/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm +++ b/mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.mm @@ -21,8 +21,8 @@ using BaseOptionsProto = ::mediapipe::tasks::core::proto::BaseOptions; @implementation MPPBaseOptions (Helpers) - (void)copyToProto:(BaseOptionsProto *)baseOptionsProto { - if (self.modelAssetFile.filePath) { - baseOptionsProto->mutable_model_asset()->set_file_name(self.modelAssetFile.filePath.UTF8String); + if (self.modelAssetPath) { + baseOptionsProto->mutable_model_asset()->set_file_name(self.modelAssetPath.UTF8String); } switch (self.delegate) { diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD index eb0800fcd..3427e3a6f 100644 --- a/mediapipe/tasks/ios/text/text_classifier/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/BUILD @@ -27,7 +27,7 @@ objc_library( deps = [ "//mediapipe/tasks/ios/core:MPPTaskOptions", "//mediapipe/tasks/ios/core:MPPTaskInfo", - "//mediapipe/tasks/ios/text/core:MPPBaseTextTaskApi", + "//mediapipe/tasks/ios/core:MPPTaskManager", "//mediapipe/tasks/ios/core:MPPPacketCreator", "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers", "//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers", diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h index 96d5887ff..0c33a5288 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h @@ -16,7 +16,6 @@ #import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" -#import "mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h" #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" NS_ASSUME_NONNULL_BEGIN @@ -25,7 +24,7 @@ NS_ASSUME_NONNULL_BEGIN * A Mediapipe iOS Text Classifier. */ NS_SWIFT_NAME(TextClassifier) -@interface MPPTextClassifier : MPPBaseTextTaskApi +@interface MPPTextClassifier : NSObject /** * Creates a new instance of `MPPTextClassifier` from an absolute path to a TensorFlow Lite model @@ -53,7 +52,7 @@ NS_SWIFT_NAME(TextClassifier) */ - (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error; -- (MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error; +- (nullable MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error; - (instancetype)init NS_UNAVAILABLE;