From 8d9c1b8a0f12a29c04b3d0001a3999a340f28327 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 1 Dec 2022 09:13:05 +0530 Subject: [PATCH] Added APIs for text classification --- mediapipe/tasks/ios/common/utils/BUILD | 3 + .../common/utils/sources/MPPCommonUtils.mm | 12 +++- .../common/utils/sources/NSString+Helpers.h | 2 + .../common/utils/sources/NSString+Helpers.mm | 4 ++ .../tasks/ios/components/containers/BUILD | 1 + .../containers/sources/MPPCategory.h | 2 +- .../containers/sources/MPPCategory.m | 4 +- .../sources/MPPClassificationResult.h | 8 ++- .../sources/MPPClassificationResult.m | 5 +- .../ios/components/containers/utils/BUILD | 40 +++++++++++ .../utils/sources/MPPCategory+Helpers.h | 26 ++++++++ .../utils/sources/MPPCategory+Helpers.mm | 42 ++++++++++++ .../sources/MPPClassificationResult+Helpers.h | 35 ++++++++++ .../MPPClassificationResult+Helpers.mm | 66 +++++++++++++++++++ mediapipe/tasks/ios/core/BUILD | 20 ++++++ .../tasks/ios/core/sources/MPPPacketCreator.h | 29 ++++++++ .../ios/core/sources/MPPPacketCreator.mm | 29 ++++++++ .../tasks/ios/core/sources/MPPTaskResult.h | 31 +++++++++ .../tasks/ios/core/sources/MPPTaskResult.m | 27 ++++++++ .../text/core/sources/MPPBaseTextTaskApi.h | 6 +- .../text/core/sources/MPPBaseTextTaskApi.mm | 10 +-- mediapipe/tasks/ios/text/core/utils/BUILD | 33 ++++++++++ .../tasks/ios/text/text_classifier/BUILD | 3 + .../sources/MPPTextClassifier.h | 3 + .../sources/MPPTextClassifier.mm | 25 +++++++ 25 files changed, 450 insertions(+), 16 deletions(-) create mode 100644 mediapipe/tasks/ios/components/containers/utils/BUILD create mode 100644 mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h create mode 100644 mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm create mode 100644 mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h create mode 100644 mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm create mode 100644 mediapipe/tasks/ios/core/sources/MPPPacketCreator.h create mode 100644 mediapipe/tasks/ios/core/sources/MPPPacketCreator.mm create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskResult.h create mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskResult.m create mode 100644 mediapipe/tasks/ios/text/core/utils/BUILD diff --git a/mediapipe/tasks/ios/common/utils/BUILD b/mediapipe/tasks/ios/common/utils/BUILD index 9c8c75586..f2ffda39e 100644 --- a/mediapipe/tasks/ios/common/utils/BUILD +++ b/mediapipe/tasks/ios/common/utils/BUILD @@ -23,6 +23,9 @@ objc_library( deps = [ "//mediapipe/tasks/cc:common", "//mediapipe/tasks/ios/common:MPPCommon", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", ], ) diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm index 6b277ae0c..141270b6d 100644 --- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm @@ -13,8 +13,16 @@ // limitations under the License. #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" + #import "mediapipe/tasks/ios/common/sources/MPPCommon.h" +#include + +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/cord.h" // from @com_google_absl + +#include "mediapipe/tasks/cc/common.h" + /** Error domain of TensorFlow Lite Support related errors. */ NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks"; @@ -60,8 +68,8 @@ NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks"; if (status.ok()) { return YES; } - // Payload of absl::Status created by the Media Pipe task library stores an appropriate value of the - // enum MPPiteSupportStatus. The integer value corresponding to the MPPiteSupportStatus enum + // Payload of absl::Status created by the Media Pipe task library stores an appropriate value of + // the enum MPPiteSupportStatus. The integer value corresponding to the MPPiteSupportStatus enum // stored in the payload is extracted here to later map to the appropriate error code to be // returned. In cases where the enum is not stored in (payload is NULL or the payload string // cannot be converted to an integer), we set the error code value to be 1 diff --git a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h index 31828a367..4c1e3af01 100644 --- a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h +++ b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h @@ -22,6 +22,8 @@ NS_ASSUME_NONNULL_BEGIN @property(readonly) std::string cppString; ++ (NSString *)stringWithCppString:(std::string)text; + @end NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm index 540903c6c..0720223dc 100644 --- a/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm +++ b/mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.mm @@ -20,4 +20,8 @@ return std::string(self.UTF8String, [self lengthOfBytesUsingEncoding:NSUTF8StringEncoding]); } ++ (NSString *)stringWithCppString:(std::string)text { + return [NSString stringWithCString:text.c_str() encoding:[NSString defaultCStringEncoding]]; +} + @end diff --git a/mediapipe/tasks/ios/components/containers/BUILD b/mediapipe/tasks/ios/components/containers/BUILD index ce80571e9..5d6bae220 100644 --- a/mediapipe/tasks/ios/components/containers/BUILD +++ b/mediapipe/tasks/ios/components/containers/BUILD @@ -28,5 +28,6 @@ objc_library( hdrs = ["sources/MPPClassificationResult.h"], deps = [ ":MPPCategory", + "//mediapipe/tasks/ios/core:MPPTaskResult", ], ) diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h index e2f8a0729..b74a5edb5 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h +++ b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h @@ -18,7 +18,7 @@ NS_ASSUME_NONNULL_BEGIN /** Encapsulates information about a class in the classification results. */ NS_SWIFT_NAME(ClassificationCategory) -@interface TFLCategory : NSObject +@interface MPPCategory : NSObject /** Index of the class in the corresponding label map, usually packed in the TFLite Model * Metadata. */ diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m index 30efa239a..a14c38b2c 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m +++ b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m @@ -12,9 +12,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#import "mediapipe/tasks/ios/components/containers/sources/TFLCategory.h" +#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h" -@implementation TFLCategory +@implementation MPPCategory - (instancetype)initWithIndex:(NSInteger)index score:(float)score diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h index 120780a7b..986a8cc1f 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #import -#import "mediapipe/tasks/ios/task/components/containers/sources/MPPCategory.h" +#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h" NS_ASSUME_NONNULL_BEGIN @@ -71,7 +72,7 @@ NS_SWIFT_NAME(Classifications) /** Encapsulates results of any classification task. */ NS_SWIFT_NAME(ClassificationResult) -@interface MPPClassificationResult : NSObject +@interface MPPClassificationResult : MPPTaskResult /** Array of MPPClassifications objects containing classifier predictions per image classifier * head. @@ -87,7 +88,8 @@ NS_SWIFT_NAME(ClassificationResult) * @return An instance of MPPClassificationResult initialized with the given array of * classifications. */ -- (instancetype)initWithClassifications:(NSArray *)classifications; +- (instancetype)initWithClassifications:(NSArray *)classifications + timeStamp:(long)timeStamp; @end diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m index 266b401e0..df3e6a52d 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m @@ -39,8 +39,9 @@ limitations under the License. NSArray *_classifications; } -- (instancetype)initWithClassifications:(NSArray *)classifications { - self = [super init]; +- (instancetype)initWithClassifications:(NSArray *)classifications + timeStamp:(long)timeStamp { + self = [super initWithTimeStamp:timeStamp]; if (self) { _classifications = classifications; } diff --git a/mediapipe/tasks/ios/components/containers/utils/BUILD b/mediapipe/tasks/ios/components/containers/utils/BUILD new file mode 100644 index 000000000..a61dd6ca0 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/BUILD @@ -0,0 +1,40 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPCategoryHelpers", + srcs = ["sources/MPPCategory+Helpers.mm"], + hdrs = ["sources/MPPCategory+Helpers.h"], + deps = [ + "//mediapipe/tasks/ios/components/containers:MPPCategory", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + ], +) + +objc_library( + name = "MPPClassificationResultHelpers", + srcs = ["sources/MPPClassificationResult+Helpers.mm"], + hdrs = ["sources/MPPClassificationResult+Helpers.h"], + deps = [ + "//mediapipe/tasks/ios/components/containers:MPPClassificationResult", + ":MPPCategoryHelpers", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + ], +) diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h new file mode 100644 index 000000000..874c751ac --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h @@ -0,0 +1,26 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#include "mediapipe/framework/formats/classification.pb.h" +#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPCategory (Helpers) + ++ (MPPCategory *)categoryWithProto:(const mediapipe::Classification &)classificationProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm new file mode 100644 index 000000000..24d250795 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm @@ -0,0 +1,42 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h" + +namespace { +using ClassificationProto = ::mediapipe::Classification; +} + +@implementation MPPCategory (Helpers) + ++ (MPPCategory *)categoryWithProto:(const ClassificationProto &)clasificationProto { + NSString *label; + NSString *displayName; + + if (clasificationProto.has_label()) { + label = [NSString stringWithCppString:clasificationProto.label()]; + } + + if (clasificationProto.has_display_name()) { + displayName = [NSString stringWithCppString:clasificationProto.display_name()]; + } + + return [[MPPCategory alloc] initWithIndex:clasificationProto.index() + score:clasificationProto.score() + label:label + displayName:displayName]; +} + +@end diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h new file mode 100644 index 000000000..5b19447ac --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h @@ -0,0 +1,35 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPClassifications (Helpers) + ++ (MPPClassifications *)classificationsWithProto: + (const mediapipe::tasks::components::containers::proto::Classifications &)classificationsProto; + +@end + +@interface MPPClassificationResult (Helpers) + ++ (MPPClassificationResult *)classificationResultWithProto: + (const mediapipe::tasks::components::containers::proto::ClassificationResult &) + classificationResultProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm new file mode 100644 index 000000000..0e9e599d7 --- /dev/null +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm @@ -0,0 +1,66 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h" +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h" + +namespace { +using ClassificationsProto = ::mediapipe::tasks::components::containers::proto::Classifications; +using ClassificationResultProto = + ::mediapipe::tasks::components::containers::proto::ClassificationResult; +} // namespace + +@implementation MPPClassifications (Helpers) + ++ (MPPClassifications *)classificationsWithProto: + (const ClassificationsProto &)classificationsProto { + NSMutableArray *categories = [[NSMutableArray alloc] init]; + for (const auto &classification : classificationsProto.classification_list().classification()) { + [categories addObject:[MPPCategory categoryWithProto:classification]]; + } + + NSString *headName; + + if (classificationsProto.has_head_name()) { + headName = [NSString stringWithCppString:classificationsProto.head_name()]; + } + + return [[MPPClassifications alloc] initWithHeadIndex:(NSInteger)classificationsProto.head_index() + headName:headName + categories:categories]; +} + +@end + +@implementation MPPClassificationResult (Helpers) + ++ (MPPClassificationResult *)classificationResultWithProto: + (const ClassificationResultProto &)classificationResultProto { + NSMutableArray *classifications = [[NSMutableArray alloc] init]; + for (const auto &classifications_proto : classificationResultProto.classifications()) { + [classifications addObject:[MPPClassifications classificationsWithProto:classifications_proto]]; + } + + long timeStamp; + + if (classificationResultProto.has_timestamp_ms()) { + timeStamp = classificationResultProto.timestamp_ms(); + } + + return [[MPPClassificationResult alloc] initWithClassifications:classifications + timeStamp:timeStamp]; +} + +@end diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD index 91257b4de..c6def1685 100644 --- a/mediapipe/tasks/ios/core/BUILD +++ b/mediapipe/tasks/ios/core/BUILD @@ -65,3 +65,23 @@ objc_library( ], ) +objc_library( + name = "MPPPacketCreator", + srcs = ["sources/MPPPacketCreator.mm"], + hdrs = ["sources/MPPPacketCreator.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], + deps = [ + "//mediapipe/framework:packet", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + ], +) + +objc_library( + name = "MPPTaskResult", + srcs = ["sources/MPPTaskResult.m"], + hdrs = ["sources/MPPTaskResult.h"], +) + diff --git a/mediapipe/tasks/ios/core/sources/MPPPacketCreator.h b/mediapipe/tasks/ios/core/sources/MPPPacketCreator.h new file mode 100644 index 000000000..ecd0c5bfd --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPPacketCreator.h @@ -0,0 +1,29 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#ifndef __cplusplus +#error This header can only be included by an Objective-C++ file. +#endif + +#include "mediapipe/framework/packet.h" + +/// This class is an Objective-C wrapper around a MediaPipe graph object, and +/// helps interface it with iOS technologies such as AVFoundation. +@interface MPPPacketCreator : NSObject + ++ (mediapipe::Packet)createWithText:(NSString *)text; + +@end diff --git a/mediapipe/tasks/ios/core/sources/MPPPacketCreator.mm b/mediapipe/tasks/ios/core/sources/MPPPacketCreator.mm new file mode 100644 index 000000000..6ce5a5139 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPPacketCreator.mm @@ -0,0 +1,29 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h" +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" + +namespace { +using ::mediapipe::MakePacket; +using ::mediapipe::Packet; +} // namespace + +@implementation MPPPacketCreator + ++ (Packet)createWithText:(NSString *)text { + return MakePacket(text.cppString); +} + +@end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h new file mode 100644 index 000000000..e4845c26d --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h @@ -0,0 +1,31 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import + +NS_ASSUME_NONNULL_BEGIN + +/** + * MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend + * this class. + */ +NS_SWIFT_NAME(TaskResult) +@interface MPPTaskResult : NSObject +/** + * Base options for configuring the Mediapipe task. + */ +@property(nonatomic, assign, readonly) long timeStamp; + +- (instancetype)initWithTimeStamp:(long)timeStamp; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.m b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m new file mode 100644 index 000000000..6a79ea7a9 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m @@ -0,0 +1,27 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h" + +@implementation MPPTaskResult + +- (instancetype)initWithTimeStamp:(long)timeStamp { + self = [self init]; + if (self) { + _timeStamp = timeStamp; + } + return self; +} + +@end diff --git a/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h b/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h index 3d2fd4f43..405d25a81 100644 --- a/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h +++ b/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h @@ -15,6 +15,7 @@ #import #include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" NS_ASSUME_NONNULL_BEGIN @@ -22,7 +23,10 @@ NS_ASSUME_NONNULL_BEGIN * The base class of the user-facing iOS mediapipe text task api classes. */ NS_SWIFT_NAME(BaseTextTaskApi) -@interface MPPBaseTextTaskApi : NSObject +@interface MPPBaseTextTaskApi : NSObject { + @protected + std::unique_ptr cppTaskRunner; +} /** * Initializes a new `MPPBaseTextTaskApi` with the mediapipe text task graph config proto. diff --git a/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.mm b/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.mm index 9d7142fe5..5c05797da 100644 --- a/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.mm +++ b/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.mm @@ -13,18 +13,18 @@ limitations under the License. ==============================================================================*/ #import "mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h" - -#include "mediapipe/tasks/cc/core/task_runner.h" #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" namespace { using ::mediapipe::CalculatorGraphConfig; +using ::mediapipe::Packet; +using ::mediapipe::tasks::core::PacketMap; using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; } // namespace @interface MPPBaseTextTaskApi () { /** TextSearcher backed by C++ API */ - std::unique_ptr _taskRunner; + std::unique_ptr _cppTaskRunner; } @end @@ -40,13 +40,13 @@ using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; return nil; } - _taskRunner = std::move(taskRunnerResult.value()); + _cppTaskRunner = std::move(taskRunnerResult.value()); } return self; } - (void)close { - _taskRunner->Close(); + _cppTaskRunner->Close(); } @end diff --git a/mediapipe/tasks/ios/text/core/utils/BUILD b/mediapipe/tasks/ios/text/core/utils/BUILD new file mode 100644 index 000000000..abb8edc71 --- /dev/null +++ b/mediapipe/tasks/ios/text/core/utils/BUILD @@ -0,0 +1,33 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPBaseTextTaskApi", + srcs = ["sources/MPPBaseTextTaskApi.mm"], + hdrs = ["sources/MPPBaseTextTaskApi.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], + deps = [ + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + ], +) + diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD index e1f6eaab8..eb0800fcd 100644 --- a/mediapipe/tasks/ios/text/text_classifier/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/BUILD @@ -28,8 +28,11 @@ objc_library( "//mediapipe/tasks/ios/core:MPPTaskOptions", "//mediapipe/tasks/ios/core:MPPTaskInfo", "//mediapipe/tasks/ios/text/core:MPPBaseTextTaskApi", + "//mediapipe/tasks/ios/core:MPPPacketCreator", "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers", + "//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", ":MPPTextClassifierOptions", ], ) diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h index d6b2f770f..96d5887ff 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h @@ -14,6 +14,7 @@ ==============================================================================*/ #import +#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" #import "mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h" #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" @@ -52,6 +53,8 @@ NS_SWIFT_NAME(TextClassifier) */ - (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error; +- (MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error; + - (instancetype)init NS_UNAVAILABLE; + (instancetype)new NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm index 08557d94e..e61e6998d 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm @@ -15,9 +15,20 @@ #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h" +#import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" #import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h" +#include "absl/status/statusor.h" + +namespace { +using ::mediapipe::Packet; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::mediapipe::tasks::core::PacketMap; +} // namespace + static NSString *const kClassificationsStreamName = @"classifications_out"; static NSString *const kClassificationsTag = @"classifications"; static NSString *const kTextInStreamName = @"text_in"; @@ -50,4 +61,18 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T return [self initWithOptions:options error:error]; } +- (MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error { + Packet packet = [MPPPacketCreator createWithText:text]; + absl::StatusOr output_packet_map = + cppTaskRunner->Process({{kTextInStreamName.cppString, packet}}); + + if (![MPPCommonUtils checkCppError:output_packet_map.status() toError:error]) { + return nil; + } + + return [MPPClassificationResult + classificationResultWithProto:output_packet_map.value()[kClassificationsStreamName.cppString] + .Get()]; +} + @end