From c8ebd21bd5698a9384d79084c2e90bf655aee9b1 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 5 Jan 2023 18:09:29 +0530 Subject: [PATCH] Updated implementation of iOS Text Classifier --- .../tasks/ios/components/containers/BUILD | 2 +- .../containers/sources/MPPCategory.h | 35 +-- .../containers/sources/MPPCategory.m | 6 +- .../sources/MPPClassificationResult.h | 68 +++-- .../sources/MPPClassificationResult.m | 18 +- .../ios/components/containers/utils/BUILD | 2 +- .../utils/sources/MPPCategory+Helpers.h | 26 +- .../utils/sources/MPPCategory+Helpers.mm | 32 +-- .../sources/MPPClassificationResult+Helpers.h | 26 +- .../MPPClassificationResult+Helpers.mm | 37 ++- mediapipe/tasks/ios/core/BUILD | 7 + .../tasks/ios/core/sources/MPPBaseOptions.h | 3 +- .../tasks/ios/core/sources/MPPBaseOptions.m | 4 +- .../ios/core/sources/MPPResultCallback.h | 21 ++ .../tasks/ios/core/sources/MPPTaskResult.h | 4 +- .../tasks/ios/core/sources/MPPTaskResult.m | 6 +- .../tasks/ios/core/sources/MPPTaskRunner.h | 46 ++- .../tasks/ios/core/sources/MPPTaskRunner.mm | 10 +- .../ios/test/{text/text_classifier => }/BUILD | 36 ++- .../tasks/ios/test/MPPTextClassifierTests.m | 110 +++++++ .../tasks/ios/test/TextClassifierTests.swift | 272 ++++++++++++++++++ .../text_classifier/MPPTextClassifierTests.m | 89 ------ mediapipe/tasks/ios/text/core/BUILD | 10 +- .../text/core/sources/MPPBaseTextTaskApi.h | 48 ---- .../text/core/sources/MPPBaseTextTaskApi.mm | 52 ---- .../ios/text/core/sources/MPPTextTaskRunner.h | 37 +++ .../text/core/sources/MPPTextTaskRunner.mm | 29 ++ mediapipe/tasks/ios/text/core/utils/BUILD | 33 --- .../tasks/ios/text/text_classifier/BUILD | 9 +- .../sources/MPPTextClassifier.h | 75 +++-- .../sources/MPPTextClassifier.mm | 61 ++-- .../sources/MPPTextClassifierOptions.h | 46 +-- .../sources/MPPTextClassifierOptions.m | 40 +-- .../sources/MPPTextClassifierResult.h | 20 +- .../sources/MPPTextClassifierResult.m | 6 +- .../ios/text/text_classifier/utils/BUILD | 3 +- 
.../MPPTextClassifierOptions+Helpers.h | 4 +- .../MPPTextClassifierOptions+Helpers.mm | 2 +- .../sources/MPPTextClassifierResult+Helpers.h | 10 +- .../MPPTextClassifierResult+Helpers.mm | 29 +- 40 files changed, 885 insertions(+), 489 deletions(-) create mode 100644 mediapipe/tasks/ios/core/sources/MPPResultCallback.h rename mediapipe/tasks/ios/test/{text/text_classifier => }/BUILD (56%) create mode 100644 mediapipe/tasks/ios/test/MPPTextClassifierTests.m create mode 100644 mediapipe/tasks/ios/test/TextClassifierTests.swift delete mode 100644 mediapipe/tasks/ios/test/text/text_classifier/MPPTextClassifierTests.m delete mode 100644 mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h delete mode 100644 mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.mm create mode 100644 mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h create mode 100644 mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.mm delete mode 100644 mediapipe/tasks/ios/text/core/utils/BUILD diff --git a/mediapipe/tasks/ios/components/containers/BUILD b/mediapipe/tasks/ios/components/containers/BUILD index ce80571e9..9d82fc55a 100644 --- a/mediapipe/tasks/ios/components/containers/BUILD +++ b/mediapipe/tasks/ios/components/containers/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h index 431b8a705..035cde09d 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h +++ b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.h @@ -1,4 +1,4 @@ -// Copyright 2022 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,41 +16,44 @@ NS_ASSUME_NONNULL_BEGIN -/** Encapsulates information about a class in the classification results. */ +/** Category is a util class, contains a label, its display name, a float value as score, and the + * index of the label in the corresponding label file. Typically it's used as the result of + * classification tasks. */ NS_SWIFT_NAME(ClassificationCategory) @interface MPPCategory : NSObject -/** Index of the class in the corresponding label map, usually packed in the TFLite Model - * Metadata. */ +/** The index of the label in the corresponding label file. It takes the value -1 if the index is + * not set. */ @property(nonatomic, readonly) NSInteger index; /** Confidence score for this class . */ @property(nonatomic, readonly) float score; -/** Class name of the class. */ -@property(nonatomic, readonly, nullable) NSString *label; +/** The label of this category object. */ +@property(nonatomic, readonly, nullable) NSString *categoryName; -/** Display name of the class. */ +/** The display name of the label, which may be translated for different locales. For example, a + * label, "apple", may be translated into Spanish for display purpose, so that the display name is + * "manzana". */ @property(nonatomic, readonly, nullable) NSString *displayName; /** - * Initializes a new `TFLCategory` with the given index, score, label and display name. + * Initializes a new `MPPCategory` with the given index, score, category name and display name. * - * @param index Index of the class in the corresponding label map, usually packed in the TFLite - * Model Metadata. + * @param index The index of the label in the corresponding label file. * - * @param score Confidence score for this class. + * @param score The probability score of this label category. * - * @param label Class name of the class. 
+ * @param categoryName The label of this category object. * - * @param displayName Display name of the class. + * @param displayName The display name of the label. * - * @return An instance of `TFLCategory` initialized with the given index, score, label and display - * name. + * @return An instance of `MPPCategory` initialized with the given index, score, category name and + * display name. */ - (instancetype)initWithIndex:(NSInteger)index score:(float)score - label:(nullable NSString *)label + categoryName:(nullable NSString *)categoryName displayName:(nullable NSString *)displayName; - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m index 20f745582..824fae65e 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m +++ b/mediapipe/tasks/ios/components/containers/sources/MPPCategory.m @@ -1,4 +1,4 @@ -// Copyright 2022 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,13 +18,13 @@ - (instancetype)initWithIndex:(NSInteger)index score:(float)score - label:(nullable NSString *)label + categoryName:(nullable NSString *)categoryName displayName:(nullable NSString *)displayName { self = [super init]; if (self) { _index = index; _score = score; - _label = label; + _categoryName = categoryName; _displayName = displayName; } return self; diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h index 24f99bfde..732d1f899 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h @@ -1,4 +1,4 @@ -// Copyright 2022 The MediaPipe Authors. 
+// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,32 +17,27 @@ NS_ASSUME_NONNULL_BEGIN -/** Encapsulates list of predicted classes (aka labels) for a given image classifier head. */ +/** Represents the list of classification for a given classifier head. Typically used as a result + * for classification tasks. */ NS_SWIFT_NAME(Classifications) @interface MPPClassifications : NSObject -/** - * The index of the classifier head these classes refer to. This is useful for multi-head - * models. +/** The index of the classifier head these entries refer to. This is useful for multi-head models. */ @property(nonatomic, readonly) NSInteger headIndex; -/** The name of the classifier head, which is the corresponding tensor metadata - * name. - */ -@property(nonatomic, readonly) NSString *headName; +/** The optional name of the classifier head, which is the corresponding tensor metadata name. */ +@property(nonatomic, readonly, nullable) NSString *headName; -/** The array of predicted classes, usually sorted by descending scores (e.g.from high to low - * probability). */ +/** An array of `MPPCategory` objects containing the predicted categories. */ @property(nonatomic, readonly) NSArray *categories; /** - * Initializes a new `MPPClassifications` with the given head index and array of categories. - * head name is initialized to `nil`. + * Initializes a new `MPPClassifications` object with the given head index and array of categories. + * Head name is initialized to `nil`. * - * @param headIndex The index of the image classifier head these classes refer to. - * @param categories An array of `MPPCategory` objects encapsulating a list of - * predictions usually sorted by descending scores (e.g. from high to low probability). + * @param headIndex The index of the classifier head. 
+ * @param categories An array of `MPPCategory` objects containing the predicted categories. * * @return An instance of `MPPClassifications` initialized with the given head index and * array of categories. @@ -54,11 +49,10 @@ NS_SWIFT_NAME(Classifications) * Initializes a new `MPPClassifications` with the given head index, head name and array of * categories. * - * @param headIndex The index of the classifier head these classes refer to. + * @param headIndex The index of the classifier head. * @param headName The name of the classifier head, which is the corresponding tensor metadata * name. - * @param categories An array of `MPPCategory` objects encapsulating a list of - * predictions usually sorted by descending scores (e.g. from high to low probability). + * @param categories An array of `MPPCategory` objects containing the predicted categories. * * @return An object of `MPPClassifications` initialized with the given head index, head name and * array of categories. @@ -69,17 +63,27 @@ NS_SWIFT_NAME(Classifications) @end -/** Encapsulates results of any classification task. */ +/** + * Represents the classification results of a model. Typically used as a result for classification + * tasks. + */ NS_SWIFT_NAME(ClassificationResult) @interface MPPClassificationResult : NSObject -/** Array of MPPClassifications objects containing classifier predictions per image classifier - * head. - */ +/** An Array of `MPPClassifications` objects containing the predicted categories for each head of + * the model. */ @property(nonatomic, readonly) NSArray *classifications; +/** The optional timestamp (in milliseconds) of the start of the chunk of data corresponding to + * these results. If it is set to the value -1, it signifies the absence of a time stamp. This is + * only used for classification on time series (e.g. audio classification). 
In these use cases, the + * amount of data to process might exceed the maximum size that the model can process: to solve + * this, the input data is split into multiple chunks starting at different timestamps. */ +@property(nonatomic, readonly) NSInteger timestampMs; + /** - * Initializes a new `MPPClassificationResult` with the given array of classifications. + * Initializes a new `MPPClassificationResult` with the given array of classifications. This method + * must be used when no time stamp needs to be specified. It sets the property `timestampMs` to -1. * * @param classifications An Aaray of `MPPClassifications` objects containing classifier * predictions per classifier head. @@ -89,6 +93,22 @@ NS_SWIFT_NAME(ClassificationResult) */ - (instancetype)initWithClassifications:(NSArray *)classifications; +/** + * Initializes a new `MPPClassificationResult` with the given array of classifications and time + * stamp (in milliseconds). + * + * @param classifications An Array of `MPPClassifications` objects containing the predicted + * categories for each head of the model. + * + * @param timestampMs The timestamp (in milliseconds) of the start of the chunk of data + * corresponding to these results. + * + * @return An instance of `MPPClassificationResult` initialized with the given array of + * classifications and timestampMs. + */ +- (instancetype)initWithClassifications:(NSArray *)classifications + timestampMs:(NSInteger)timestampMs; + @end NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m index dd9c4e024..6cf75234e 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m @@ -1,4 +1,4 @@ -// Copyright 2022 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -35,16 +35,22 @@ @end -@implementation MPPClassificationResult { - NSArray *_classifications; -} +@implementation MPPClassificationResult -- (instancetype)initWithClassifications:(NSArray *)classifications { - +- (instancetype)initWithClassifications:(NSArray *)classifications + timestampMs:(NSInteger)timestampMs { self = [super init]; if (self) { _classifications = classifications; + _timestampMs = timestampMs; } + + return self; +} + +- (instancetype)initWithClassifications:(NSArray *)classifications { + return [self initWithClassifications:classifications timestampMs:-1]; + return self; } diff --git a/mediapipe/tasks/ios/components/containers/utils/BUILD b/mediapipe/tasks/ios/components/containers/utils/BUILD index a61dd6ca0..e4c76ac4b 100644 --- a/mediapipe/tasks/ios/components/containers/utils/BUILD +++ b/mediapipe/tasks/ios/components/containers/utils/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h index 874c751ac..7580cfeeb 100644 --- a/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ #include "mediapipe/framework/formats/classification.pb.h" #import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h" diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm index 24d250795..f729d9720 100644 --- a/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" #import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h" @@ -22,11 +22,11 @@ using ClassificationProto = ::mediapipe::Classification; @implementation MPPCategory (Helpers) + (MPPCategory *)categoryWithProto:(const ClassificationProto &)clasificationProto { - NSString *label; + NSString *categoryName; NSString *displayName; if (clasificationProto.has_label()) { - label = [NSString stringWithCppString:clasificationProto.label()]; + categoryName = [NSString stringWithCppString:clasificationProto.label()]; } if (clasificationProto.has_display_name()) { @@ -35,7 +35,7 @@ using ClassificationProto = ::mediapipe::Classification; return [[MPPCategory alloc] initWithIndex:clasificationProto.index() score:clasificationProto.score() - label:label + categoryName:categoryName displayName:displayName]; } diff --git 
a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h index 5b19447ac..fde436feb 100644 --- a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- ==============================================================================*/ #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h" diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm index 84d5872d7..9ad284790 100644 --- a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- ==============================================================================*/ #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" #import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h" #import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h" @@ -53,7 +53,16 @@ using ClassificationResultProto = [classifications addObject:[MPPClassifications classificationsWithProto:classifications_proto]]; } - return [[MPPClassificationResult alloc] initWithClassifications:classifications]; + MPPClassificationResult *classificationResult; + + if (classificationResultProto.has_timestamp_ms()) { + classificationResult = [[MPPClassificationResult alloc] initWithClassifications:classifications timestampMs:(NSInteger)classificationResultProto.timestamp_ms()]; + } + else { + classificationResult = [[MPPClassificationResult alloc] initWithClassifications:classifications]; + } + + return classificationResult; } @end diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD index 434d20085..efe481ea8 100644 --- a/mediapipe/tasks/ios/core/BUILD +++ b/mediapipe/tasks/ios/core/BUILD @@ -90,6 +90,13 @@ objc_library( deps = [ "//mediapipe/framework:calculator_cc_proto", "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", ], ) + +objc_library( + name = "MPPResultCallback", + hdrs = ["sources/MPPResultCallback.h"], +) diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h index 9c6595cfc..088d2d5da 100644 --- a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h +++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.h @@ -22,7 +22,7 @@ NS_ASSUME_NONNULL_BEGIN typedef NS_ENUM(NSUInteger, MPPDelegate) { /** CPU. */ MPPDelegateCPU, - + /** GPU. 
*/ MPPDelegateGPU } NS_SWIFT_NAME(Delegate); @@ -46,4 +46,3 @@ NS_SWIFT_NAME(BaseOptions) @end NS_ASSUME_NONNULL_END - diff --git a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m index b2b027da7..eaf2aa895 100644 --- a/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m +++ b/mediapipe/tasks/ios/core/sources/MPPBaseOptions.m @@ -26,10 +26,10 @@ - (id)copyWithZone:(NSZone *)zone { MPPBaseOptions *baseOptions = [[MPPBaseOptions alloc] init]; - + baseOptions.modelAssetPath = self.modelAssetPath; baseOptions.delegate = self.delegate; - + return baseOptions; } diff --git a/mediapipe/tasks/ios/core/sources/MPPResultCallback.h b/mediapipe/tasks/ios/core/sources/MPPResultCallback.h new file mode 100644 index 000000000..908f9edd9 --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPResultCallback.h @@ -0,0 +1,21 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import + +NS_ASSUME_NONNULL_BEGIN + +typedef void (^MPPResultCallback)(id oputput, id input, NSError *error); + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h index d15d4f258..4ee7b2fc6 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h @@ -26,11 +26,11 @@ NS_SWIFT_NAME(TaskResult) /** * Timestamp that is associated with the task result object. */ -@property(nonatomic, assign, readonly) long timestamp; +@property(nonatomic, assign, readonly) NSInteger timestampMs; - (instancetype)init NS_UNAVAILABLE; -- (instancetype)initWithTimestamp:(long)timestamp NS_DESIGNATED_INITIALIZER; +- (instancetype)initWithTimestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER; @end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.m b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m index 7088eb246..6c08014ff 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskResult.m +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m @@ -16,16 +16,16 @@ @implementation MPPTaskResult -- (instancetype)initWithTimestamp:(long)timestamp { +- (instancetype)initWithTimestampMs:(NSInteger)timestampMs { self = [super init]; if (self) { - _timestamp = timestamp; + _timestampMs = timestampMs; } return self; } - (id)copyWithZone:(NSZone *)zone { - return [[MPPTaskResult alloc] initWithTimestamp:self.timestamp]; + return [[MPPTaskResult alloc] initWithTimestampMs:self.timestampMs]; } @end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h index 2b9f2ecdb..97255234f 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h @@ -20,23 +20,63 @@ NS_ASSUME_NONNULL_BEGIN /** - * This class is used to create and call appropriate methods on the C++ Task Runner. 
+ * This class is used to create and call appropriate methods on the C++ Task Runner to initialize, + * execute and terminate any Mediapipe task. + * + * An instance of the newly created C++ task runner will + * be stored until this class is destroyed. When methods are called for processing (performing + * inference), closing etc., on this class, internally the appropriate methods will be called on the + * C++ task runner instance to execute the appropriate actions. For each type of task, a subclass of + * this class must be defined to add any additional functionality. For example, vision tasks must create + * an `MPPVisionTaskRunner` and provide additional functionality. An instance of + * `MPPVisionTaskRunner` can in turn be used by each vision task for creation and execution of + * the task. Please see the documentation for the C++ Task Runner for more details on how the task + * runner operates. */ @interface MPPTaskRunner : NSObject /** - * Initializes a new `MPPTaskRunner` with the mediapipe task graph config proto. + * Initializes a new `MPPTaskRunner` with the mediapipe task graph config proto and an optional C++ + * packets callback. + * + * You can pass `nullptr` for `packetsCallback` in case the mode of operation + * requested by the user is synchronous. + * + * If the task is operating in asynchronous mode, any iOS Mediapipe task that uses the `MPPTaskRunner` + * must define a C++ callback function to obtain the results of inference asynchronously and deliver + * the results to the user. To accomplish this, the callback function will in turn invoke the block + * provided by the user in the task options supplied to create the task. + * Please see the documentation of the C++ Task Runner for more information on the synchronous and + * asynchronous modes of operation. * * @param graphConfig A mediapipe task graph config proto. * - * @return An instance of `MPPTaskRunner` initialized to the given graph config proto. 
+ * @param packetsCallback An optional C++ callback function that takes a list of output packets as + * the input argument. If provided, the callback must in turn call the block provided by the user in + * the appropriate task options. + * + * @return An instance of `MPPTaskRunner` initialized to the given graph config proto and optional + * packetsCallback. */ - (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig + packetsCallback: + (mediapipe::tasks::core::PacketsCallback)packetsCallback error:(NSError **)error NS_DESIGNATED_INITIALIZER; +/** A synchronous method for processing batch data or offline streaming data. This method is +designed for processing either batch data such as unrelated images and texts or offline streaming +data such as the decoded frames from a video file and an audio file. The call blocks the current +thread until a failure status or a successful result is returned. If the input packets have no +timestamp, an internal timestamp will be assigned per invocation. Otherwise, when the timestamp is +set in the input packets, the caller must ensure that the input packet timestamps are greater than +the timestamps of the previous invocation. This method is thread-unsafe and it is the caller's +responsibility to synchronize access to this method across multiple threads and to ensure that the +input packet timestamps are in order.*/ - (absl::StatusOr)process: (const mediapipe::tasks::core::PacketMap &)packetMap; +/** Shuts down the C++ task runner. After the runner is closed, any calls that send input data to + * the runner are illegal and will receive errors. 
*/ - (absl::Status)close; - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm index c5c307fd5..fd3f780fa 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm @@ -13,11 +13,15 @@ // limitations under the License. #import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#include "tensorflow/lite/core/api/op_resolver.h" namespace { using ::mediapipe::CalculatorGraphConfig; +using ::mediapipe::tasks::core::MediaPipeBuiltinOpResolver; using ::mediapipe::tasks::core::PacketMap; +using ::mediapipe::tasks::core::PacketsCallback; using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; } // namespace @@ -30,15 +34,17 @@ using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; @implementation MPPTaskRunner - (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig + packetsCallback:(PacketsCallback)packetsCallback error:(NSError **)error { self = [super init]; if (self) { - auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig)); + auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig), + absl::make_unique(), + std::move(packetsCallback)); if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) { return nil; } - _cppTaskRunner = std::move(taskRunnerResult.value()); } return self; diff --git a/mediapipe/tasks/ios/test/text/text_classifier/BUILD b/mediapipe/tasks/ios/test/BUILD similarity index 56% rename from mediapipe/tasks/ios/test/text/text_classifier/BUILD rename to mediapipe/tasks/ios/test/BUILD index 2202ff1a6..9df178a19 100644 --- a/mediapipe/tasks/ios/test/text/text_classifier/BUILD +++ b/mediapipe/tasks/ios/test/BUILD @@ -1,18 +1,15 @@ load( - "//mediapipe/tasks:ios/ios.bzl", - 
"MPP_TASK_MINIMUM_OS_VERSION", - "MPP_TASK_DEFAULT_TAGS", - "MPP_TASK_DISABLED_SANITIZER_TAGS", -) -load( - "@build_bazel_rules_apple//apple:ios.bzl", + "@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test", ) load( - "@org_tensorflow//tensorflow/lite:special_rules.bzl", + "@org_tensorflow//tensorflow/lite:special_rules.bzl", "tflite_ios_lab_runner" ) +load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library") + + package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) @@ -25,7 +22,7 @@ objc_library( "//mediapipe/tasks/testdata/text:bert_text_classifier_models", "//mediapipe/tasks/testdata/text:text_classifier_models", ], - tags = MPP_TASK_DEFAULT_TAGS, + tags = [], copts = [ "-ObjC++", "-std=c++17", @@ -38,10 +35,27 @@ objc_library( ios_unit_test( name = "MPPTextClassifierObjcTest", - minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION, + minimum_os_version = "11.0", runner = tflite_ios_lab_runner("IOS_LATEST"), - tags = MPP_TASK_DEFAULT_TAGS + MPP_TASK_DISABLED_SANITIZER_TAGS, + tags =[], deps = [ ":MPPTextClassifierObjcTestLibrary", ], ) + +swift_library( + name = "MPPTextClassifierSwiftTestLibrary", + testonly = 1, + srcs = ["TextClassifierTests.swift"], + tags = [], +) + +ios_unit_test( + name = "MPPTextClassifierSwiftTest", + minimum_os_version = "11.0", + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags = [], + deps = [ + ":MPPTextClassifierSwiftTestLibrary", + ], +) diff --git a/mediapipe/tasks/ios/test/MPPTextClassifierTests.m b/mediapipe/tasks/ios/test/MPPTextClassifierTests.m new file mode 100644 index 000000000..d213fd97c --- /dev/null +++ b/mediapipe/tasks/ios/test/MPPTextClassifierTests.m @@ -0,0 +1,110 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +#import + +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" + +static NSString *const kBertTextClassifierModelName = @"bert_text_classifier"; +static NSString *const kNegativeText = @"unflinchingly bleak and desperate"; +static NSString *const kPositiveText = @"it's a charming and often affecting journey"; + +#define AssertCategoriesAre(categories, expectedCategories) \ + XCTAssertEqual(categories.count, expectedCategories.count); \ + for (int i = 0; i < categories.count; i++) { \ + XCTAssertEqual(categories[i].index, expectedCategories[i].index); \ + XCTAssertEqualWithAccuracy(categories[i].score, expectedCategories[i].score, 1e-6); \ + XCTAssertEqualObjects(categories[i].categoryName, expectedCategories[i].categoryName); \ + XCTAssertEqualObjects(categories[i].displayName, expectedCategories[i].displayName); \ + } + +#define AssertHasOneHead(textClassifierResult) \ + XCTAssertNotNil(textClassifierResult); \ + XCTAssertNotNil(textClassifierResult.classificationResult); \ + XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1); \ + XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0); + +@interface MPPTextClassifierTests : XCTestCase +@end + +@implementation MPPTextClassifierTests + +- (void)setUp { +} + +- (void)tearDown { + // Put teardown code here. This method is called after the invocation of each test method in the class. 
+} + +- (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension { + NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName + ofType:extension]; + XCTAssertNotNil(filePath); + + return filePath; +} + +- (MPPTextClassifierOptions *)textClassifierOptionsWithModelName:(NSString *)modelName { + NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"]; + MPPTextClassifierOptions *textClassifierOptions = + [[MPPTextClassifierOptions alloc] init]; + textClassifierOptions.baseOptions.modelAssetPath = modelPath; + + return textClassifierOptions; +} + +- (MPPTextClassifier *)createTextClassifierFromOptionsWithModelName:(NSString *)modelName { + MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:modelName]; + MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil]; + XCTAssertNotNil(textClassifier); + + return textClassifier; +} + +- (void)testClassifyWithBertSucceeds { + MPPTextClassifier *textClassifier = [self createTextClassifierFromOptionsWithModelName:kBertTextClassifierModelName]; + + MPPTextClassifierResult *negativeResult = [textClassifier classifyWithText:kNegativeText error:nil]; + AssertHasOneHead(negativeResult); + + NSArray *expectedNegativeCategories = @[[[MPPCategory alloc] initWithIndex:0 + score:0.956187f + categoryName:@"negative" + displayName:nil], + [[MPPCategory alloc] initWithIndex:1 + score:0.043812f + categoryName:@"positive" + displayName:nil]]; + + AssertCategoriesAre(negativeResult.classificationResult.classifications[0].categories, + expectedNegativeCategories + ); + + // MPPTextClassifierResult *positiveResult = [textClassifier classifyWithText:kPositiveText error:nil]; + // AssertHasOneHead(positiveResult); + // NSArray *expectedPositiveCategories = @[[[MPPCategory alloc] initWithIndex:0 + // score:0.99997187f + // label:@"positive" + // displayName:nil], + // [[MPPCategory alloc] initWithIndex:1 + // 
score:2.8132641E-5f + // label:@"negative" + // displayName:nil]]; + // AssertCategoriesAre(negativeResult.classificationResult.classifications[0].categories, + // expectedPositiveCategories + // ); + +} +@end diff --git a/mediapipe/tasks/ios/test/TextClassifierTests.swift b/mediapipe/tasks/ios/test/TextClassifierTests.swift new file mode 100644 index 000000000..2dca9c4b5 --- /dev/null +++ b/mediapipe/tasks/ios/test/TextClassifierTests.swift @@ -0,0 +1,272 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +// import GMLImageUtils +import XCTest + +// @testable import TFLImageSegmenter + +class TextClassifierTests: XCTestCase { + + func testExample() throws { + XCTAssertEqual(1, 1) + } + + // static let bundle = Bundle(for: TextClassifierTests.self) + // static let modelPath = bundle.path( + // forResource: "deeplabv3", + // ofType: "tflite") + + // // The maximum fraction of pixels in the candidate mask that can have a + // // different class than the golden mask for the test to pass. + // let kGoldenMaskTolerance: Float = 1e-2 + + // // Magnification factor used when creating the golden category masks to make + // // them more human-friendly. Each pixel in the golden masks has its value + // // multiplied by this factor, i.e. a value of 10 means class index 1, a value of + // // 20 means class index 2, etc. 
+ // let kGoldenMaskMagnificationFactor: UInt8 = 10 + + // let deepLabV3SegmentationWidth = 257 + + // let deepLabV3SegmentationHeight = 257 + + // func verifyDeeplabV3PartialSegmentationResult(_ coloredLabels: [ColoredLabel]) { + + // self.verifyColoredLabel( + // coloredLabels[0], + // expectedR: 0, + // expectedG: 0, + // expectedB: 0, + // expectedLabel: "background") + + // self.verifyColoredLabel( + // coloredLabels[1], + // expectedR: 128, + // expectedG: 0, + // expectedB: 0, + // expectedLabel: "aeroplane") + + // self.verifyColoredLabel( + // coloredLabels[2], + // expectedR: 0, + // expectedG: 128, + // expectedB: 0, + // expectedLabel: "bicycle") + + // self.verifyColoredLabel( + // coloredLabels[3], + // expectedR: 128, + // expectedG: 128, + // expectedB: 0, + // expectedLabel: "bird") + + // self.verifyColoredLabel( + // coloredLabels[4], + // expectedR: 0, + // expectedG: 0, + // expectedB: 128, + // expectedLabel: "boat") + + // self.verifyColoredLabel( + // coloredLabels[5], + // expectedR: 128, + // expectedG: 0, + // expectedB: 128, + // expectedLabel: "bottle") + + // self.verifyColoredLabel( + // coloredLabels[6], + // expectedR: 0, + // expectedG: 128, + // expectedB: 128, + // expectedLabel: "bus") + + // self.verifyColoredLabel( + // coloredLabels[7], + // expectedR: 128, + // expectedG: 128, + // expectedB: 128, + // expectedLabel: "car") + + // self.verifyColoredLabel( + // coloredLabels[8], + // expectedR: 64, + // expectedG: 0, + // expectedB: 0, + // expectedLabel: "cat") + + // self.verifyColoredLabel( + // coloredLabels[9], + // expectedR: 192, + // expectedG: 0, + // expectedB: 0, + // expectedLabel: "chair") + + // self.verifyColoredLabel( + // coloredLabels[10], + // expectedR: 64, + // expectedG: 128, + // expectedB: 0, + // expectedLabel: "cow") + + // self.verifyColoredLabel( + // coloredLabels[11], + // expectedR: 192, + // expectedG: 128, + // expectedB: 0, + // expectedLabel: "dining table") + + // self.verifyColoredLabel( + 
// coloredLabels[12], + // expectedR: 64, + // expectedG: 0, + // expectedB: 128, + // expectedLabel: "dog") + + // self.verifyColoredLabel( + // coloredLabels[13], + // expectedR: 192, + // expectedG: 0, + // expectedB: 128, + // expectedLabel: "horse") + + // self.verifyColoredLabel( + // coloredLabels[14], + // expectedR: 64, + // expectedG: 128, + // expectedB: 128, + // expectedLabel: "motorbike") + + // self.verifyColoredLabel( + // coloredLabels[15], + // expectedR: 192, + // expectedG: 128, + // expectedB: 128, + // expectedLabel: "person") + + // self.verifyColoredLabel( + // coloredLabels[16], + // expectedR: 0, + // expectedG: 64, + // expectedB: 0, + // expectedLabel: "potted plant") + + // self.verifyColoredLabel( + // coloredLabels[17], + // expectedR: 128, + // expectedG: 64, + // expectedB: 0, + // expectedLabel: "sheep") + + // self.verifyColoredLabel( + // coloredLabels[18], + // expectedR: 0, + // expectedG: 192, + // expectedB: 0, + // expectedLabel: "sofa") + + // self.verifyColoredLabel( + // coloredLabels[19], + // expectedR: 128, + // expectedG: 192, + // expectedB: 0, + // expectedLabel: "train") + + // self.verifyColoredLabel( + // coloredLabels[20], + // expectedR: 0, + // expectedG: 64, + // expectedB: 128, + // expectedLabel: "tv") + // } + + // func verifyColoredLabel( + // _ coloredLabel: ColoredLabel, + // expectedR: UInt, + // expectedG: UInt, + // expectedB: UInt, + // expectedLabel: String + // ) { + // XCTAssertEqual( + // coloredLabel.r, + // expectedR) + // XCTAssertEqual( + // coloredLabel.g, + // expectedG) + // XCTAssertEqual( + // coloredLabel.b, + // expectedB) + // XCTAssertEqual( + // coloredLabel.label, + // expectedLabel) + // } + + // func testSuccessfullInferenceOnMLImageWithUIImage() throws { + + // let modelPath = try XCTUnwrap(ImageSegmenterTests.modelPath) + + // let imageSegmenterOptions = ImageSegmenterOptions(modelPath: modelPath) + + // let imageSegmenter = + // try ImageSegmenter.segmenter(options: 
imageSegmenterOptions) + + // let gmlImage = try XCTUnwrap( + // MLImage.imageFromBundle( + // class: type(of: self), + // filename: "segmentation_input_rotation0", + // type: "jpg")) + // let segmentationResult: SegmentationResult = + // try XCTUnwrap(imageSegmenter.segment(mlImage: gmlImage)) + + // XCTAssertEqual(segmentationResult.segmentations.count, 1) + + // let coloredLabels = try XCTUnwrap(segmentationResult.segmentations[0].coloredLabels) + // verifyDeeplabV3PartialSegmentationResult(coloredLabels) + + // let categoryMask = try XCTUnwrap(segmentationResult.segmentations[0].categoryMask) + // XCTAssertEqual(deepLabV3SegmentationWidth, categoryMask.width) + // XCTAssertEqual(deepLabV3SegmentationHeight, categoryMask.height) + + // let goldenMaskImage = try XCTUnwrap( + // MLImage.imageFromBundle( + // class: type(of: self), + // filename: "segmentation_golden_rotation0", + // type: "png")) + + // let pixelBuffer = goldenMaskImage.grayScalePixelBuffer().takeRetainedValue() + + // CVPixelBufferLockBaseAddress(pixelBuffer, CVPixelBufferLockFlags.readOnly) + + // let pixelBufferBaseAddress = (try XCTUnwrap(CVPixelBufferGetBaseAddress(pixelBuffer))) + // .assumingMemoryBound(to: UInt8.self) + + // let numPixels = deepLabV3SegmentationWidth * deepLabV3SegmentationHeight + + // let mask = try XCTUnwrap(categoryMask.mask) + + // var inconsistentPixels: Float = 0.0 + + // for i in 0.. 
- -#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" - -NS_ASSUME_NONNULL_BEGIN - -static NSString *const kBertTextClassifierModelName = @"bert_text_classifier"; -static NSString *const kNegativeText = @"unflinchingly bleak and desperate"; - -#define VerifyCategory(category, expectedIndex, expectedScore, expectedLabel, expectedDisplayName) \ - XCTAssertEqual(category.index, expectedIndex); \ - XCTAssertEqualWithAccuracy(category.score, expectedScore, 1e-6); \ - XCTAssertEqualObjects(category.label, expectedLabel); \ - XCTAssertEqualObjects(category.displayName, expectedDisplayName); - -#define VerifyClassifications(classifications, expectedHeadIndex, expectedCategoryCount) \ - XCTAssertEqual(classifications.categories.count, expectedCategoryCount); - -#define VerifyClassificationResult(classificationResult, expectedClassificationsCount) \ - XCTAssertNotNil(classificationResult); \ - XCTAssertEqual(classificationResult.classifications.count, expectedClassificationsCount) - -#define AssertClassificationResultHasOneHead(classificationResult) \ - XCTAssertNotNil(classificationResult); \ - XCTAssertEqual(classificationResult.classifications.count, 1); - XCTAssertEqual(classificationResult.classifications[0].headIndex, 1); - -#define AssertTextClassifierResultIsNotNil(textClassifierResult) \ - XCTAssertNotNil(textClassifierResult); - -@interface MPPTextClassifierTests : XCTestCase -@end - -@implementation MPPTextClassifierTests - -- (void)setUp { - [super setUp]; - -} - -- (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension { - NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName - ofType:extension]; - XCTAssertNotNil(filePath); - - return filePath; -} - -- (MPPTextClassifierOptions *)textClassifierOptionsWithModelName:(NSString *)modelName { - NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"]; - MPPTextClassifierOptions *textClassifierOptions = - 
[[MPPTextClassifierOptions alloc] init]; - textClassifierOptions.baseOptions.modelAssetPath = modelPath; - - return textClassifierOptions; -} - -kBertTextClassifierModelName - -- (MPPTextClassifier *)createTextClassifierFromOptionsWithModelName:(NSString *)modelName { - MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:modelName]; - MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil]; - XCTAssertNotNil(textClassifier); - - return textClassifier -} - -- (void)classifyWithBertSucceeds { - MPPTextClassifier *textClassifier = [self createTextClassifierWithModelName:kBertTextClassifierModelName]; - MPPTextClassifierResult *textClassifierResult = [textClassifier classifyWithText:kNegativeText]; -} - -@end - -NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/core/BUILD b/mediapipe/tasks/ios/text/core/BUILD index abb8edc71..6d558b22b 100644 --- a/mediapipe/tasks/ios/text/core/BUILD +++ b/mediapipe/tasks/ios/text/core/BUILD @@ -17,17 +17,15 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) objc_library( - name = "MPPBaseTextTaskApi", - srcs = ["sources/MPPBaseTextTaskApi.mm"], - hdrs = ["sources/MPPBaseTextTaskApi.h"], + name = "MPPTextTaskRunner", + srcs = ["sources/MPPTextTaskRunner.mm"], + hdrs = ["sources/MPPTextTaskRunner.h"], copts = [ "-ObjC++", "-std=c++17", ], deps = [ - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + "//mediapipe/tasks/ios/core:MPPTaskRunner", ], ) diff --git a/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h b/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h deleted file mode 100644 index 405d25a81..000000000 --- a/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
- - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ -#import - -#include "mediapipe/framework/calculator.pb.h" -#include "mediapipe/tasks/cc/core/task_runner.h" - -NS_ASSUME_NONNULL_BEGIN - -/** - * The base class of the user-facing iOS mediapipe text task api classes. - */ -NS_SWIFT_NAME(BaseTextTaskApi) -@interface MPPBaseTextTaskApi : NSObject { - @protected - std::unique_ptr cppTaskRunner; -} - -/** - * Initializes a new `MPPBaseTextTaskApi` with the mediapipe text task graph config proto. - * - * @param graphConfig A mediapipe text task graph config proto. - * - * @return An instance of `MPPBaseTextTaskApi` initialized to the given graph config proto. - */ -- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig - error:(NSError **)error; -- (void)close; - -- (instancetype)init NS_UNAVAILABLE; - -+ (instancetype)new NS_UNAVAILABLE; - -@end - -NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.mm b/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.mm deleted file mode 100644 index 5c05797da..000000000 --- a/mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.mm +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ -#import "mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h" -#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" - -namespace { -using ::mediapipe::CalculatorGraphConfig; -using ::mediapipe::Packet; -using ::mediapipe::tasks::core::PacketMap; -using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; -} // namespace - -@interface MPPBaseTextTaskApi () { - /** TextSearcher backed by C++ API */ - std::unique_ptr _cppTaskRunner; -} -@end - -@implementation MPPBaseTextTaskApi - -- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig - error:(NSError **)error { - self = [super init]; - if (self) { - auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig)); - - if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) { - return nil; - } - - _cppTaskRunner = std::move(taskRunnerResult.value()); - } - return self; -} - -- (void)close { - _cppTaskRunner->Close(); -} - -@end diff --git a/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h b/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h new file mode 100644 index 000000000..dd5d96ce6 --- /dev/null +++ b/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h @@ -0,0 +1,37 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * This class is used to create and call appropriate methods on the C++ Task Runner to initialize, execute and terminate any Mediapipe text task. + */ +@interface MPPTextTaskRunner : MPPTaskRunner + +/** + * Initializes a new `MPPTextTaskRunner` with the mediapipe task graph config proto. + * + * @param graphConfig A mediapipe task graph config proto. + * + * @return An instance of `MPPTextTaskRunner` initialized to the given graph config proto. + */ +- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig + error:(NSError **)error NS_DESIGNATED_INITIALIZER; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.mm b/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.mm new file mode 100644 index 000000000..956448c17 --- /dev/null +++ b/mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.mm @@ -0,0 +1,29 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h" + +namespace { +using ::mediapipe::CalculatorGraphConfig; +} // namespace + +@implementation MPPTextTaskRunner + +- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig + error:(NSError **)error { + self = [super initWithCalculatorGraphConfig:graphConfig packetsCallback:nullptr error:error]; + return self; +} + +@end diff --git a/mediapipe/tasks/ios/text/core/utils/BUILD b/mediapipe/tasks/ios/text/core/utils/BUILD deleted file mode 100644 index abb8edc71..000000000 --- a/mediapipe/tasks/ios/text/core/utils/BUILD +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -package(default_visibility = ["//mediapipe/tasks:internal"]) - -licenses(["notice"]) - -objc_library( - name = "MPPBaseTextTaskApi", - srcs = ["sources/MPPBaseTextTaskApi.mm"], - hdrs = ["sources/MPPBaseTextTaskApi.h"], - copts = [ - "-ObjC++", - "-std=c++17", - ], - deps = [ - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", - ], -) - diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD index 61eecb9cd..ce118c718 100644 --- a/mediapipe/tasks/ios/text/text_classifier/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,9 +25,11 @@ objc_library( "-std=c++17", ], deps = [ + "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", "//mediapipe/tasks/ios/core:MPPTaskOptions", "//mediapipe/tasks/ios/core:MPPTaskInfo", - "//mediapipe/tasks/ios/core:MPPTaskRunner", + "//mediapipe/tasks/ios/text/core:MPPTextTaskRunner", "//mediapipe/tasks/ios/core:MPPTextPacketCreator", "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers", "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers", @@ -35,6 +37,9 @@ objc_library( "//mediapipe/tasks/ios/common/utils:NSStringHelpers", ":MPPTextClassifierOptions", ], + sdk_frameworks = [ + "MetalKit", + ], ) objc_library( diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h index 19e10e35f..ee6f25100 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h +++ 
b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h @@ -1,34 +1,61 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ #import -#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h" NS_ASSUME_NONNULL_BEGIN /** - * A Mediapipe iOS Text Classifier. 
+ * This API expects a TFLite model with (optional) [TFLite Model + * Metadata](https://www.tensorflow.org/lite/convert/metadata) that contains the mandatory + * (described below) input tensors, output tensor, and the optional (but recommended) label items as + * AssociatedFiles with type TENSOR_AXIS_LABELS per output classification tensor. + * + * Metadata is required for models with int32 input tensors because it contains the input process + * unit for the model's Tokenizer. No metadata is required for models with string input tensors. + * + * Input tensors + * - Three input tensors `kTfLiteInt32` of shape `[batch_size x bert_max_seq_len]` + * representing the input ids, mask ids, and segment ids. This input signature requires a + * Bert Tokenizer process unit in the model metadata. + * - Or one input tensor `kTfLiteInt32` of shape `[batch_size x max_seq_len]` representing + * the input ids. This input signature requires a Regex Tokenizer process unit in the + * model metadata. + * - Or one input tensor (`kTfLiteString`) that is shapeless or has shape `[1]` containing + * the input string. + * + * At least one output tensor (`kTfLiteFloat32`/`kBool`) with: + * - `N` classes and shape `[1 x N]` + * - optional (but recommended) label map(s) as AssociatedFile-s with type TENSOR_AXIS_LABELS, + * containing one label per line. The first such AssociatedFile (if any) is used to fill the + * `class_name` field of the results. The `display_name` field is filled from the AssociatedFile + * (if any) whose locale matches the `display_names_locale` field of the + * `MPPTextClassifierOptions` used at creation time ("en" by default, i.e. English). If none of + * these are available, only the `index` field of the results will be filled. + * + * @brief Performs classification on text.
*/ NS_SWIFT_NAME(TextClassifier) @interface MPPTextClassifier : NSObject /** * Creates a new instance of `MPPTextClassifier` from an absolute path to a TensorFlow Lite model - * file stored locally on the device. + * file stored locally on the device and the default `MPPTextClassifierOptions`. * * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. * @@ -41,9 +68,11 @@ NS_SWIFT_NAME(TextClassifier) - (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error; /** - * Creates a new instance of `MPPTextClassifier` from the given text classifier options. + * Creates a new instance of `MPPTextClassifier` from the given `MPPTextClassifierOptions`. + * + * @param options The options of type `MPPTextClassifierOptions` to use for configuring the + * `MPPTextClassifier`. * - * @param options The options to use for configuring the `MPPTextClassifier`. * @param error An optional error parameter populated when there is an error in initializing * the text classifier. * @@ -52,6 +81,16 @@ NS_SWIFT_NAME(TextClassifier) */ - (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error; +/** + * Performs classification on the input text. + * + * @param text The `NSString` on which classification is to be performed. + * + * @param error An optional error parameter populated when there is an error in performing + * classification on the input text. + * + * @return A `MPPTextClassifierResult` object that contains a list of text classifications. 
+ */ - (nullable MPPTextClassifierResult *)classifyWithText:(NSString *)text error:(NSError **)error; - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm index b9e76fc69..487cdad42 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm @@ -1,25 +1,27 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- ==============================================================================*/ #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h" #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" -#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h" -#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h" -#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" +#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h" +#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h" #import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h" +#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h" + +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "absl/status/statusor.h" @@ -37,14 +39,13 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T @interface MPPTextClassifier () { - /** TextSearcher backed by C++ API */ - MPPTaskRunner *_taskRunner; + /** Text task runner that executes the text classifier graph. */ + MPPTextTaskRunner *_taskRunner; } @end @implementation MPPTextClassifier - (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error { - MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] initWithTaskGraphName:kTaskGraphName inputStreams:@[ [NSString stringWithFormat:@"%@:%@", kTextTag, kTextInStreamName] ] @@ -58,10 +59,11 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T return nil; } - _taskRunner = [[MPPTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error]; - + _taskRunner = + [[MPPTextTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] + error:error]; self = [super init]; - + return self; } @@ -76,14 +78,18 @@ static
NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T - (nullable MPPTextClassifierResult *)classifyWithText:(NSString *)text error:(NSError **)error { Packet packet = [MPPTextPacketCreator createWithText:text]; - absl::StatusOr output_packet_map = [_taskRunner process:{{kTextInStreamName.cppString, packet}} error:error]; - if (![MPPCommonUtils checkCppError:output_packet_map.status() toError:error]) { + std::map packet_map = {{kTextInStreamName.cppString, packet}}; + absl::StatusOr status_or_output_packet_map = [_taskRunner process:packet_map]; + + if (![MPPCommonUtils checkCppError:status_or_output_packet_map.status() toError:error]) { return nil; } + Packet classifications_packet = + status_or_output_packet_map.value()[kClassificationsStreamName.cppString]; + return [MPPTextClassifierResult - textClassifierResultWithProto:output_packet_map.value()[kClassificationsStreamName.cppString] - .Get()]; + textClassifierResultWithClassificationsPacket:classifications_packet]; } @end diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h index 374226998..25189578b 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ #import #import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h" @@ -20,32 +20,16 @@ NS_ASSUME_NONNULL_BEGIN /** - * Options to configure MPPTextClassifierOptions. + * Options for setting up a `MPPTextClassifier`. */ NS_SWIFT_NAME(TextClassifierOptions) @interface MPPTextClassifierOptions : MPPTaskOptions /** - * Options controlling the behavior of the embedding model specified in the - * base options. + * Options for configuring the classifier behavior, such as score threshold, number of results, etc. */ @property(nonatomic, copy) MPPClassifierOptions *classifierOptions; -// /** -// * Initializes a new `MPPTextClassifierOptions` with the absolute path to the model file -// * stored locally on the device, set to the given the model path. -// * -// * @discussion The external model file must be a single standalone TFLite file.
It could be packed -// * with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the -// * necessary metadata and associated files might result in errors. Check the [documentation] -// * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement. -// * -// * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. -// * -// * @return An instance of `MPPTextClassifierOptions` initialized to the given model path. -// */ -// - (instancetype)initWithModelPath:(NSString *)modelPath; - @end NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m index 82e9bed64..8d4ffd36f 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.m @@ -1,27 +1,27 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" @implementation MPPTextClassifierOptions -// - (instancetype)initWithModelPath:(NSString *)modelPath { -// self = [super initWithModelPath:modelPath]; -// if (self) { -// _classifierOptions = [[MPPClassifierOptions alloc] init]; -// } -// return self; -// } +- (instancetype)init { + self = [super init]; + if (self) { + _classifierOptions = [[MPPClassifierOptions alloc] init]; + } + return self; +} @end \ No newline at end of file diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h index 414e6d9c6..6926757e4 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h @@ -1,4 +1,4 @@ -// Copyright 2022 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,23 +18,27 @@ NS_ASSUME_NONNULL_BEGIN -/** Encapsulates list of predicted classes (aka labels) for a given image classifier head. */ +/** Represents the classification results generated by `MPPTextClassifier`. 
*/ NS_SWIFT_NAME(TextClassifierResult) @interface MPPTextClassifierResult : MPPTaskResult +/** The `MPPClassificationResult` instance containing one set of results per classifier head. */ @property(nonatomic, readonly) MPPClassificationResult *classificationResult; /** - * Initializes a new `MPPClassificationResult` with the given array of classifications. + * Initializes a new `MPPTextClassifierResult` with the given `MPPClassificationResult` and time + * stamp (in milliseconds). * - * @param classifications An Aaray of `MPPClassifications` objects containing classifier - * predictions per classifier head. + * @param classificationResult The `MPPClassificationResult` instance containing one set of results + * per classifier head. * - * @return An instance of MPPClassificationResult initialized with the given array of - * classifications. + * @param timestampMs The time stamp for this result, in milliseconds. + * + * @return An instance of `MPPTextClassifierResult` initialized with the given + * `MPPClassificationResult` and time stamp (in milliseconds). */ - (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult - timeStamp:(long)timeStamp; + timestampMs:(NSInteger)timestampMs; @end diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m index b99ee3b19..4d5c1104a 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m @@ -1,4 +1,4 @@ -// Copyright 2022 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.
@@ -17,8 +17,8 @@ @implementation MPPTextClassifierResult - (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult - timeStamp:(long)timeStamp { - self = [super initWithTimestamp:timeStamp]; + timestampMs:(NSInteger)timestampMs { + self = [super initWithTimestampMs:timestampMs]; if (self) { _classificationResult = classificationResult; } diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/BUILD b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD index d6a371137..abc1fc23b 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/utils/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,5 +36,6 @@ objc_library( deps = [ "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierResult", "//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers", + "//mediapipe/framework:packet", ], ) diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h index 0771eafce..1e52e5c87 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h @@ -1,4 +1,4 @@ -// Copyright 2022 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" +#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h" NS_ASSUME_NONNULL_BEGIN diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm index aa11384d2..728000b44 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.mm @@ -1,4 +1,4 @@ -// Copyright 2022 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h index d3fb04d69..f1b728b0a 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h @@ -1,4 +1,4 @@ -// Copyright 2022 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,16 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h" +#include "mediapipe/framework/packet.h" + NS_ASSUME_NONNULL_BEGIN @interface MPPTextClassifierResult (Helpers) -+ (MPPTextClassifierResult *)textClassifierResultWithProto: - (const mediapipe::tasks::components::containers::proto::ClassificationResult &) - classificationResultProto; ++ (MPPTextClassifierResult *)textClassifierResultWithClassificationsPacket: + (const mediapipe::Packet &)packet; @end diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm index 2fc2d751d..f5d6aa1d3 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm @@ -1,4 +1,4 @@ -// Copyright 2022 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,28 +12,29 @@ // See the License for the specific language governing permissions and // limitations under the License.
-#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h" #import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h" +#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h" + +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" + +static const int kMicroSecondsPerMilliSecond = 1000; namespace { using ClassificationResultProto = ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::mediapipe::Packet; } // namespace @implementation MPPTextClassifierResult (Helpers) -+ (MPPTextClassifierResult *)textClassifierResultWithProto: - (const ClassificationResultProto &)classificationResultProto { - long timeStamp; ++ (MPPTextClassifierResult *)textClassifierResultWithClassificationsPacket:(const Packet &)packet { + MPPClassificationResult *classificationResult = [MPPClassificationResult + classificationResultWithProto:packet.Get<ClassificationResultProto>()]; - if (classificationResultProto.has_timestamp_ms()) { - timeStamp = classificationResultProto.timestamp_ms(); - } - - MPPClassificationResult *classificationResult = [MPPClassificationResult classificationResultWithProto:classificationResultProto]; - - return [[MPPTextClassifierResult alloc] initWithClassificationResult:classificationResult - timeStamp:timeStamp]; + return [[MPPTextClassifierResult alloc] + initWithClassificationResult:classificationResult + timestampMs:(NSInteger)(packet.Timestamp().Value() / + kMicroSecondsPerMilliSecond)]; } @end