From 7e0fec7c28eb25eb69793c5a33194b96ef8d1734 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 23 Dec 2022 17:52:00 +0530 Subject: [PATCH] Updated implementation of text classifier --- .../tasks/ios/components/containers/BUILD | 1 - .../sources/MPPClassificationResult.h | 6 +- .../sources/MPPClassificationResult.m | 6 +- .../MPPClassificationResult+Helpers.mm | 9 +-- .../tasks/ios/core/sources/MPPTaskManager.h | 47 ---------------- .../tasks/ios/core/sources/MPPTaskManager.mm | 56 ------------------- .../text_classifier/MPPTextClassifierTests.m | 38 ++++++++++++- .../tasks/ios/text/text_classifier/BUILD | 15 ++++- .../sources/MPPTextClassifier.h | 4 +- .../sources/MPPTextClassifier.mm | 31 +++++----- .../sources/MPPTextClassifierOptions.h | 28 +++++----- .../sources/MPPTextClassifierOptions.m | 14 ++--- .../sources/MPPTextClassifierResult.h | 41 ++++++++++++++ .../sources/MPPTextClassifierResult.m | 28 ++++++++++ .../ios/text/text_classifier/utils/BUILD | 10 ++++ .../MPPTextClassifierOptions+Helpers.h | 26 ++++----- .../MPPTextClassifierOptions+Helpers.mm | 26 ++++----- .../sources/MPPTextClassifierResult+Helpers.h | 28 ++++++++++ .../MPPTextClassifierResult+Helpers.mm | 39 +++++++++++++ 19 files changed, 264 insertions(+), 189 deletions(-) delete mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskManager.h delete mode 100644 mediapipe/tasks/ios/core/sources/MPPTaskManager.mm create mode 100644 mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h create mode 100644 mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m create mode 100644 mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h create mode 100644 mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm diff --git a/mediapipe/tasks/ios/components/containers/BUILD b/mediapipe/tasks/ios/components/containers/BUILD index 5d6bae220..ce80571e9 100644 --- a/mediapipe/tasks/ios/components/containers/BUILD +++ b/mediapipe/tasks/ios/components/containers/BUILD @@ -28,6 +28,5 @@ objc_library( hdrs = ["sources/MPPClassificationResult.h"], deps = [ ":MPPCategory", - "//mediapipe/tasks/ios/core:MPPTaskResult", ], ) diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h index b0e0c4073..24f99bfde 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h @@ -14,7 +14,6 @@ #import #import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h" -#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h" NS_ASSUME_NONNULL_BEGIN @@ -72,7 +71,7 @@ NS_SWIFT_NAME(Classifications) /** Encapsulates results of any classification task. */ NS_SWIFT_NAME(ClassificationResult) -@interface MPPClassificationResult : MPPTaskResult +@interface MPPClassificationResult : NSObject /** Array of MPPClassifications objects containing classifier predictions per image classifier * head. @@ -88,8 +87,7 @@ NS_SWIFT_NAME(ClassificationResult) * @return An instance of MPPClassificationResult initialized with the given array of * classifications. */ -- (instancetype)initWithClassifications:(NSArray *)classifications - timeStamp:(long)timeStamp; +- (instancetype)initWithClassifications:(NSArray *)classifications; @end diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m index e4e5eaac5..dd9c4e024 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m @@ -39,9 +39,9 @@ NSArray *_classifications; } -- (instancetype)initWithClassifications:(NSArray *)classifications - timeStamp:(long)timeStamp { - self = [super initWithTimeStamp:timeStamp]; +- (instancetype)initWithClassifications:(NSArray *)classifications { + + self = [super init]; if (self) { _classifications = classifications; } diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm index 0e9e599d7..84d5872d7 100644 --- a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm @@ -53,14 +53,7 @@ using ClassificationResultProto = [classifications addObject:[MPPClassifications classificationsWithProto:classifications_proto]]; } - long timeStamp; - - if (classificationResultProto.has_timestamp_ms()) { - timeStamp = classificationResultProto.timestamp_ms(); - } - - return [[MPPClassificationResult alloc] initWithClassifications:classifications - timeStamp:timeStamp]; + return [[MPPClassificationResult alloc] initWithClassifications:classifications]; } @end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskManager.h b/mediapipe/tasks/ios/core/sources/MPPTaskManager.h deleted file mode 100644 index f6dea201a..000000000 --- a/mediapipe/tasks/ios/core/sources/MPPTaskManager.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2022 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import - -#include "mediapipe/framework/calculator.pb.h" -#include "mediapipe/tasks/cc/core/task_runner.h" - - -NS_ASSUME_NONNULL_BEGIN - -/** - * The base class of the user-facing iOS mediapipe task api classes. - */ -@interface MPPTaskManager : NSObject -/** - * Initializes a new `MPPTaskManager` with the mediapipe task graph config proto. - * - * @param graphConfig A mediapipe task graph config proto. - * - * @return An instance of `MPPTaskManager` initialized to the given graph config proto. - */ -- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig - error:(NSError **)error; - -- (absl::StatusOr)process:(const mediapipe::tasks::core::PacketMap&)packetMap error:(NSError **)error; - -- (void)close; - -- (instancetype)init NS_UNAVAILABLE; - -+ (instancetype)new NS_UNAVAILABLE; - -@end - -NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskManager.mm b/mediapipe/tasks/ios/core/sources/MPPTaskManager.mm deleted file mode 100644 index 492ed8cf6..000000000 --- a/mediapipe/tasks/ios/core/sources/MPPTaskManager.mm +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2022 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import "mediapipe/tasks/ios/core/sources/MPPTaskManager.h" -#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" - -namespace { -using ::mediapipe::CalculatorGraphConfig; -using ::mediapipe::Packet; -using ::mediapipe::tasks::core::PacketMap; -using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; -} // namespace - -@interface MPPTaskManager () { - /** TextSearcher backed by C++ API */ - std::unique_ptr _cppTaskRunner; -} -@end - -@implementation MPPTaskManager - -- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig - error:(NSError **)error { - self = [super init]; - if (self) { - auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig)); - - if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) { - return nil; - } - - _cppTaskRunner = std::move(taskRunnerResult.value()); - } - return self; -} - -- (absl::StatusOr)process:(const PacketMap&)packetMap { - return _cppTaskRunner->Process(packetMap); -} - -- (void)close { - _cppTaskRunner->Close(); -} - -@end diff --git a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m index 3808009f3..fa04c3e65 100644 --- a/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m +++ b/mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m @@ -19,6 +19,28 @@ NS_ASSUME_NONNULL_BEGIN static NSString *const kBertTextClassifierModelName = @"bert_text_classifier"; +static NSString *const kNegativeText = @"unflinchingly bleak and desperate"; + +#define VerifyCategory(category, expectedIndex, expectedScore, expectedLabel, expectedDisplayName) \ + XCTAssertEqual(category.index, expectedIndex); \ + XCTAssertEqualWithAccuracy(category.score, expectedScore, 1e-6); \ + XCTAssertEqualObjects(category.label, expectedLabel); \ + XCTAssertEqualObjects(category.displayName, expectedDisplayName); + +#define VerifyClassifications(classifications, expectedHeadIndex, expectedCategoryCount) \ + XCTAssertEqual(classifications.categories.count, expectedCategoryCount); + +#define VerifyClassificationResult(classificationResult, expectedClassificationsCount) \ + XCTAssertNotNil(classificationResult); \ + XCTAssertEqual(classificationResult.classifications.count, expectedClassificationsCount) + +#define AssertClassificationResultHasOneHead(classificationResult) \ + XCTAssertNotNil(classificationResult); \ + XCTAssertEqual(classificationResult.classifications.count, 1); + XCTAssertEqual(classificationResult.classifications[0].headIndex, 1); + +#define AssertTextClassifierResultIsNotNil(textClassifierResult) \ + XCTAssertNotNil(textClassifierResult); @interface MPPTextClassifierTests : XCTestCase @end @@ -41,15 +63,25 @@ static NSString *const kBertTextClassifierModelName = @"bert_text_classifier"; - (MPPTextClassifierOptions *)textClassifierOptionsWithModelName:(NSString *)modelName { NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"]; MPPTextClassifierOptions *textClassifierOptions = - [[MPPTextClassifierOptions alloc] initWithModelPath:modelPath]; + [[MPPTextClassifierOptions alloc] init]; + textClassifierOptions.baseOptions.modelAssetPath = modelPath; return textClassifierOptions; } -- (void)testCreateTextClassifierOptionsSucceeds { - MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:kBertTextClassifierModelName]; +kBertTextClassifierModelName + +- (MPPTextClassifier *)createTextClassifierFromOptionsWithModelName:(NSString *)modelName { + MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:modelName]; MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil]; XCTAssertNotNil(textClassifier); + + return textClassifier +} + +- (void)classifyWithBertSucceeds { + MPPTextClassifier *textClassifier = [self createTextClassifierWithModelName:kBertTextClassifierModelName]; + MPPTextClassifierResult *textClassifierResult = [textClassifier classifyWithText:kNegativeText]; } @end diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD index 3427e3a6f..61eecb9cd 100644 --- a/mediapipe/tasks/ios/text/text_classifier/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/BUILD @@ -27,10 +27,10 @@ objc_library( deps = [ "//mediapipe/tasks/ios/core:MPPTaskOptions", "//mediapipe/tasks/ios/core:MPPTaskInfo", - "//mediapipe/tasks/ios/core:MPPTaskManager", - "//mediapipe/tasks/ios/core:MPPPacketCreator", + "//mediapipe/tasks/ios/core:MPPTaskRunner", + "//mediapipe/tasks/ios/core:MPPTextPacketCreator", "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers", - "//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers", + "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", "//mediapipe/tasks/ios/common/utils:NSStringHelpers", ":MPPTextClassifierOptions", @@ -51,3 +51,12 @@ objc_library( ], ) +objc_library( + name = "MPPTextClassifierResult", + srcs = ["sources/MPPTextClassifierResult.m"], + hdrs = ["sources/MPPTextClassifierResult.h"], + deps = [ + "//mediapipe/tasks/ios/core:MPPTaskResult", + "//mediapipe/tasks/ios/components/containers:MPPClassificationResult", + ], +) diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h index 0c33a5288..19e10e35f 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h @@ -14,7 +14,7 @@ ==============================================================================*/ #import -#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h" +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" @@ -52,7 +52,7 @@ NS_SWIFT_NAME(TextClassifier) */ - (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error; -- (nullable MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error; +- (nullable MPPTextClassifierResult *)classifyWithText:(NSString *)text error:(NSError **)error; - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm index b4cd66f70..b9e76fc69 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm @@ -15,9 +15,9 @@ #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" -#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h" -#import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h" -#import "mediapipe/tasks/ios/core/sources/MPPTaskManager.h" +#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h" +#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" #import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h" @@ -30,14 +30,14 @@ using ::mediapipe::tasks::core::PacketMap; } // namespace static NSString *const kClassificationsStreamName = @"classifications_out"; -static NSString *const kClassificationsTag = @"classifications"; +static NSString *const kClassificationsTag = @"CLASSIFICATIONS"; static NSString *const kTextInStreamName = @"text_in"; static NSString *const kTextTag = @"TEXT"; static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.TextClassifierGraph"; @interface MPPTextClassifier () { /** TextSearcher backed by C++ API */ - MPPTaskManager *_taskManager; + MPPTaskRunner *_taskRunner; } @end @@ -47,8 +47,8 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] initWithTaskGraphName:kTaskGraphName - inputStreams:@[ [NSString stringWithFormat:@"@:@", kTextTag, kTextInStreamName] ] - outputStreams:@[ [NSString stringWithFormat:@"@:@", kClassificationsTag, + inputStreams:@[ [NSString stringWithFormat:@"%@:%@", kTextTag, kTextInStreamName] ] + outputStreams:@[ [NSString stringWithFormat:@"%@:%@", kClassificationsTag, kClassificationsStreamName] ] taskOptions:options enableFlowLimiting:NO @@ -58,7 +58,7 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T return nil; } - _taskManager = [[MPPTaskManager alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error]; + _taskRunner = [[MPPTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error]; self = [super init]; @@ -66,22 +66,23 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T } - (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error { - MPPTextClassifierOptions *options = - [[MPPTextClassifierOptions alloc] initWithModelPath:modelPath]; + MPPTextClassifierOptions *options = [[MPPTextClassifierOptions alloc] init]; + + options.baseOptions.modelAssetPath = modelPath; return [self initWithOptions:options error:error]; } -- (nullable MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error { - Packet packet = [MPPPacketCreator createWithText:text]; +- (nullable MPPTextClassifierResult *)classifyWithText:(NSString *)text error:(NSError **)error { + Packet packet = [MPPTextPacketCreator createWithText:text]; - absl::StatusOr output_packet_map = [_taskManager process:{{kTextInStreamName.cppString, packet}} error:error]; + absl::StatusOr output_packet_map = [_taskRunner process:{{kTextInStreamName.cppString, packet}} error:error]; if (![MPPCommonUtils checkCppError:output_packet_map.status() toError:error]) { return nil; } - return [MPPClassificationResult - classificationResultWithProto:output_packet_map.value()[kClassificationsStreamName.cppString] + return [MPPTextClassifierResult + textClassifierResultWithProto:output_packet_map.value()[kClassificationsStreamName.cppString] .Get()]; } diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h index 47c44dd0d..374226998 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h @@ -31,20 +31,20 @@ NS_SWIFT_NAME(TextClassifierOptions) */ @property(nonatomic, copy) MPPClassifierOptions *classifierOptions; -/** - * Initializes a new `MPPTextClassifierOptions` with the absolute path to the model file - * stored locally on the device, set to the given the model path. - * - * @discussion The external model file must be a single standalone TFLite file. It could be packed - * with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the - * necessary metadata and associated files might result in errors. Check the [documentation] - * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement. - * - * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. - * - * @return An instance of `MPPTextClassifierOptions` initialized to the given model path. - */ -- (instancetype)initWithModelPath:(NSString *)modelPath; +// /** +// * Initializes a new `MPPTextClassifierOptions` with the absolute path to the model file +// * stored locally on the device, set to the given the model path. +// * +// * @discussion The external model file must be a single standalone TFLite file. It could be packed +// * with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the +// * necessary metadata and associated files might result in errors. Check the [documentation] +// * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement. +// * +// * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. +// * +// * @return An instance of `MPPTextClassifierOptions` initialized to the given model path. +// */ +// - (instancetype)initWithModelPath:(NSString *)modelPath; @end diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m index 8cab693cd..82e9bed64 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m @@ -16,12 +16,12 @@ @implementation MPPTextClassifierOptions -- (instancetype)initWithModelPath:(NSString *)modelPath { - self = [super initWithModelPath:modelPath]; - if (self) { - _classifierOptions = [[MPPClassifierOptions alloc] init]; - } - return self; -} +// - (instancetype)initWithModelPath:(NSString *)modelPath { +// self = [super initWithModelPath:modelPath]; +// if (self) { +// _classifierOptions = [[MPPClassifierOptions alloc] init]; +// } +// return self; +// } @end \ No newline at end of file diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h new file mode 100644 index 000000000..414e6d9c6 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h @@ -0,0 +1,41 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h" + +NS_ASSUME_NONNULL_BEGIN + +/** Encapsulates list of predicted classes (aka labels) for a given image classifier head. */ +NS_SWIFT_NAME(TextClassifierResult) +@interface MPPTextClassifierResult : MPPTaskResult + +@property(nonatomic, readonly) MPPClassificationResult *classificationResult; + +/** + * Initializes a new `MPPClassificationResult` with the given array of classifications. + * + * @param classifications An Aaray of `MPPClassifications` objects containing classifier + * predictions per classifier head. + * + * @return An instance of MPPClassificationResult initialized with the given array of + * classifications. + */ +- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult + timeStamp:(long)timeStamp; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m new file mode 100644 index 000000000..b99ee3b19 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m @@ -0,0 +1,28 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h" + +@implementation MPPTextClassifierResult + +- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult + timeStamp:(long)timeStamp { + self = [super initWithTimestamp:timeStamp]; + if (self) { + _classificationResult = classificationResult; + } + return self; +} + +@end diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/BUILD b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD index 662e76c2a..d6a371137 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD @@ -28,3 +28,13 @@ objc_library( "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", ], ) + +objc_library( + name = "MPPTextClassifierResultHelpers", + srcs = ["sources/MPPTextClassifierResult+Helpers.mm"], + hdrs = ["sources/MPPTextClassifierResult+Helpers.h"], + deps = [ + "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierResult", + "//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers", + ], +) diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h index 71076da26..0771eafce 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm index 3576cb8d2..aa11384d2 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ #include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h" #import "mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h" #import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h" diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h new file mode 100644 index 000000000..d3fb04d69 --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h @@ -0,0 +1,28 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPTextClassifierResult (Helpers) + ++ (MPPTextClassifierResult *)textClassifierResultWithProto: + (const mediapipe::tasks::components::containers::proto::ClassificationResult &) + classificationResultProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm new file mode 100644 index 000000000..2fc2d751d --- /dev/null +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm @@ -0,0 +1,39 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h" +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h" + +namespace { +using ClassificationResultProto = + ::mediapipe::tasks::components::containers::proto::ClassificationResult; +} // namespace + +@implementation MPPTextClassifierResult (Helpers) + ++ (MPPTextClassifierResult *)textClassifierResultWithProto: + (const ClassificationResultProto &)classificationResultProto { + long timeStamp; + + if (classificationResultProto.has_timestamp_ms()) { + timeStamp = classificationResultProto.timestamp_ms(); + } + + MPPClassificationResult *classificationResult = [MPPClassificationResult classificationResultWithProto:classificationResultProto]; + + return [[MPPTextClassifierResult alloc] initWithClassificationResult:classificationResult + timeStamp:timeStamp]; +} + +@end