Updated implementation of iOS Text Classifier

This commit is contained in:
Prianka Liz Kariat 2023-01-05 18:09:29 +05:30
parent 7ce21038bb
commit c8ebd21bd5
40 changed files with 885 additions and 489 deletions

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
// Copyright 2022 The MediaPipe Authors.
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -16,41 +16,44 @@
NS_ASSUME_NONNULL_BEGIN
/** Encapsulates information about a class in the classification results. */
/** Category is a util class, contains a label, its display name, a float value as score, and the
* index of the label in the corresponding label file. Typically it's used as the result of
* classification tasks. */
NS_SWIFT_NAME(ClassificationCategory)
@interface MPPCategory : NSObject
/** Index of the class in the corresponding label map, usually packed in the TFLite Model
* Metadata. */
/** The index of the label in the corresponding label file. It takes the value -1 if the index is
* not set. */
@property(nonatomic, readonly) NSInteger index;
/** Confidence score for this class . */
@property(nonatomic, readonly) float score;
/** Class name of the class. */
@property(nonatomic, readonly, nullable) NSString *label;
/** The label of this category object. */
@property(nonatomic, readonly, nullable) NSString *categoryName;
/** Display name of the class. */
/** The display name of the label, which may be translated for different locales. For example, a
* label, "apple", may be translated into Spanish for display purpose, so that the display name is
* "manzana". */
@property(nonatomic, readonly, nullable) NSString *displayName;
/**
* Initializes a new `TFLCategory` with the given index, score, label and display name.
* Initializes a new `MPPCategory` with the given index, score, category name and display name.
*
* @param index Index of the class in the corresponding label map, usually packed in the TFLite
* Model Metadata.
* @param index The index of the label in the corresponding label file.
*
* @param score Confidence score for this class.
* @param score The probability score of this label category.
*
* @param label Class name of the class.
* @param categoryName The label of this category object..
*
* @param displayName Display name of the class.
* @param displayName The display name of the label.
*
* @return An instance of `TFLCategory` initialized with the given index, score, label and display
* name.
* @return An instance of `MPPCategory` initialized with the given index, score, category name and
* display name.
*/
- (instancetype)initWithIndex:(NSInteger)index
score:(float)score
label:(nullable NSString *)label
categoryName:(nullable NSString *)categoryName
displayName:(nullable NSString *)displayName;
- (instancetype)init NS_UNAVAILABLE;

View File

@ -1,4 +1,4 @@
// Copyright 2022 The MediaPipe Authors.
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -18,13 +18,13 @@
- (instancetype)initWithIndex:(NSInteger)index
score:(float)score
label:(nullable NSString *)label
categoryName:(nullable NSString *)categoryName
displayName:(nullable NSString *)displayName {
self = [super init];
if (self) {
_index = index;
_score = score;
_label = label;
_categoryName = categoryName;
_displayName = displayName;
}
return self;

View File

@ -1,4 +1,4 @@
// Copyright 2022 The MediaPipe Authors.
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -17,32 +17,27 @@
NS_ASSUME_NONNULL_BEGIN
/** Encapsulates list of predicted classes (aka labels) for a given image classifier head. */
/** Represents the list of classification for a given classifier head. Typically used as a result
* for classification tasks. */
NS_SWIFT_NAME(Classifications)
@interface MPPClassifications : NSObject
/**
* The index of the classifier head these classes refer to. This is useful for multi-head
* models.
/** The index of the classifier head these entries refer to. This is useful for multi-head models.
*/
@property(nonatomic, readonly) NSInteger headIndex;
/** The name of the classifier head, which is the corresponding tensor metadata
* name.
*/
@property(nonatomic, readonly) NSString *headName;
/** The optional name of the classifier head, which is the corresponding tensor metadata name. */
@property(nonatomic, readonly, nullable) NSString *headName;
/** The array of predicted classes, usually sorted by descending scores (e.g.from high to low
* probability). */
/** An array of `MPPCategory` objects containing the predicted categories. */
@property(nonatomic, readonly) NSArray<MPPCategory *> *categories;
/**
* Initializes a new `MPPClassifications` with the given head index and array of categories.
* head name is initialized to `nil`.
* Initializes a new `MPPClassifications` object with the given head index and array of categories.
* Head name is initialized to `nil`.
*
* @param headIndex The index of the image classifier head these classes refer to.
* @param categories An array of `MPPCategory` objects encapsulating a list of
* predictions usually sorted by descending scores (e.g. from high to low probability).
* @param headIndex The index of the classifier head.
* @param categories An array of `MPPCategory` objects containing the predicted categories.
*
* @return An instance of `MPPClassifications` initialized with the given head index and
* array of categories.
@ -54,11 +49,10 @@ NS_SWIFT_NAME(Classifications)
* Initializes a new `MPPClassifications` with the given head index, head name and array of
* categories.
*
* @param headIndex The index of the classifier head these classes refer to.
* @param headIndex The index of the classifier head.
* @param headName The name of the classifier head, which is the corresponding tensor metadata
* name.
* @param categories An array of `MPPCategory` objects encapsulating a list of
* predictions usually sorted by descending scores (e.g. from high to low probability).
* @param categories An array of `MPPCategory` objects containing the predicted categories.
*
* @return An object of `MPPClassifications` initialized with the given head index, head name and
* array of categories.
@ -69,17 +63,27 @@ NS_SWIFT_NAME(Classifications)
@end
/** Encapsulates results of any classification task. */
/**
* Represents the classification results of a model. Typically used as a result for classification
* tasks.
*/
NS_SWIFT_NAME(ClassificationResult)
@interface MPPClassificationResult : NSObject
/** Array of MPPClassifications objects containing classifier predictions per image classifier
* head.
*/
/** An Array of `MPPClassifications` objects containing the predicted categories for each head of
* the model. */
@property(nonatomic, readonly) NSArray<MPPClassifications *> *classifications;
/** The optional timestamp (in milliseconds) of the start of the chunk of data corresponding to
* these results. If it is set to the value -1, it signifies the absence of a time stamp. This is
* only used for classification on time series (e.g. audio classification). In these use cases, the
* amount of data to process might exceed the maximum size that the model can process: to solve
* this, the input data is split into multiple chunks starting at different timestamps. */
@property(nonatomic, readonly) NSInteger timestampMs;
/**
* Initializes a new `MPPClassificationResult` with the given array of classifications.
* Initializes a new `MPPClassificationResult` with the given array of classifications. This method
* must be used when no time stamp needs to be specified. It sets the property `timestampMs` to -1.
*
* @param classifications An Aaray of `MPPClassifications` objects containing classifier
* predictions per classifier head.
@ -89,6 +93,22 @@ NS_SWIFT_NAME(ClassificationResult)
*/
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications;
/**
* Initializes a new `MPPClassificationResult` with the given array of classifications and time
* stamp (in milliseconds).
*
* @param classifications An Array of `MPPClassifications` objects containing the predicted
* categories for each head of the model.
*
* @param timeStampMs The timestamp (in milliseconds) of the start of the chunk of data
* corresponding to these results.
*
* @return An instance of `MPPClassificationResult` initialized with the given array of
* classifications and timestampMs.
*/
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
timestampMs:(NSInteger)timestampMs;
@end
NS_ASSUME_NONNULL_END

View File

@ -1,4 +1,4 @@
// Copyright 2022 The MediaPipe Authors.
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -35,16 +35,22 @@
@end
@implementation MPPClassificationResult {
NSArray<MPPClassifications *> *_classifications;
}
@implementation MPPClassificationResult
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications {
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
timestampMs:(NSInteger)timestampMs {
self = [super init];
if (self) {
_classifications = classifications;
_timestampMs = timestampMs;
}
return self;
}
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications {
return [self initWithClassifications:classifications timestampMs:-1];
return self;
}

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,17 +1,17 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/framework/formats/classification.pb.h"
#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h"

View File

@ -1,17 +1,17 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h"
@ -22,11 +22,11 @@ using ClassificationProto = ::mediapipe::Classification;
@implementation MPPCategory (Helpers)
+ (MPPCategory *)categoryWithProto:(const ClassificationProto &)clasificationProto {
NSString *label;
NSString *categoryName;
NSString *displayName;
if (clasificationProto.has_label()) {
label = [NSString stringWithCppString:clasificationProto.label()];
categoryName = [NSString stringWithCppString:clasificationProto.label()];
}
if (clasificationProto.has_display_name()) {
@ -35,7 +35,7 @@ using ClassificationProto = ::mediapipe::Classification;
return [[MPPCategory alloc] initWithIndex:clasificationProto.index()
score:clasificationProto.score()
label:label
categoryName:categoryName
displayName:displayName];
}

View File

@ -1,17 +1,17 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h"

View File

@ -1,17 +1,17 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h"
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
@ -53,7 +53,16 @@ using ClassificationResultProto =
[classifications addObject:[MPPClassifications classificationsWithProto:classifications_proto]];
}
return [[MPPClassificationResult alloc] initWithClassifications:classifications];
MPPClassificationResult *classificationResult;
if (classificationResultProto.has_timestamp_ms()) {
classificationResult = [[MPPClassificationResult alloc] initWithClassifications:classifications timestampMs:(NSInteger)classificationResultProto.timestamp_ms()];
}
else {
classificationResult = [[MPPClassificationResult alloc] initWithClassifications:classifications];
}
return classificationResult;
}
@end

View File

@ -90,6 +90,13 @@ objc_library(
deps = [
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
],
)
objc_library(
name = "MPPResultCallback",
hdrs = ["sources/MPPResultCallback.h"],
)

View File

@ -22,7 +22,7 @@ NS_ASSUME_NONNULL_BEGIN
typedef NS_ENUM(NSUInteger, MPPDelegate) {
/** CPU. */
MPPDelegateCPU,
/** GPU. */
MPPDelegateGPU
} NS_SWIFT_NAME(Delegate);
@ -46,4 +46,3 @@ NS_SWIFT_NAME(BaseOptions)
@end
NS_ASSUME_NONNULL_END

View File

@ -26,10 +26,10 @@
- (id)copyWithZone:(NSZone *)zone {
MPPBaseOptions *baseOptions = [[MPPBaseOptions alloc] init];
baseOptions.modelAssetPath = self.modelAssetPath;
baseOptions.delegate = self.delegate;
return baseOptions;
}

View File

@ -0,0 +1,21 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
NS_ASSUME_NONNULL_BEGIN
typedef void (^MPPResultCallback)(id oputput, id input, NSError *error);
NS_ASSUME_NONNULL_END

View File

@ -26,11 +26,11 @@ NS_SWIFT_NAME(TaskResult)
/**
* Timestamp that is associated with the task result object.
*/
@property(nonatomic, assign, readonly) long timestamp;
@property(nonatomic, assign, readonly) NSInteger timestampMs;
- (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithTimestamp:(long)timestamp NS_DESIGNATED_INITIALIZER;
- (instancetype)initWithTimestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER;
@end

View File

@ -16,16 +16,16 @@
@implementation MPPTaskResult
- (instancetype)initWithTimestamp:(long)timestamp {
- (instancetype)initWithTimestampMs:(NSInteger)timestampMs {
self = [super init];
if (self) {
_timestamp = timestamp;
_timestampMs = timestampMs;
}
return self;
}
- (id)copyWithZone:(NSZone *)zone {
return [[MPPTaskResult alloc] initWithTimestamp:self.timestamp];
return [[MPPTaskResult alloc] initWithTimestampMs:self.timestampMs];
}
@end

View File

@ -20,23 +20,63 @@
NS_ASSUME_NONNULL_BEGIN
/**
* This class is used to create and call appropriate methods on the C++ Task Runner.
* This class is used to create and call appropriate methods on the C++ Task Runner to initialize,
* execute and terminate any Mediapipe task.
*
* An instance of the newly created C++ task runner will
* be stored until this class is destroyed. When methods are called for processing (performing
* inference), closing etc., on this class, internally the appropriate methods will be called on the
* C++ task runner instance to execute the appropriate actions. For each type of task, a subclass of
* this class must be defined to add any additional functionality. For eg:, vision tasks must create
* an `MPPVisionTaskRunner` and provide additional functionality. An instance of
* `MPPVisionTaskRunner` can in turn be used by the each vision task for creation and execution of
* the task. Please see the documentation for the C++ Task Runner for more details on how the taks
* runner operates.
*/
@interface MPPTaskRunner : NSObject
/**
* Initializes a new `MPPTaskRunner` with the mediapipe task graph config proto.
* Initializes a new `MPPTaskRunner` with the mediapipe task graph config proto and an optional C++
* packets callback.
*
* You can pass `nullptr` for `packetsCallback` in case the mode of operation
* requested by the user is synchronous.
*
* If the task is operating in asynchronous mode, any iOS Mediapipe task that uses the `MPPTaskRunner`
* must define a C++ callback function to obtain the results of inference asynchronously and deliver
* the results to the user. To accomplish this, callback function will in turn invoke the block
* provided by the user in the task options supplied to create the task.
* Please see the documentation of the C++ Task Runner for more information on the synchronous and
* asynchronous modes of operation.
*
* @param graphConfig A mediapipe task graph config proto.
*
* @return An instance of `MPPTaskRunner` initialized to the given graph config proto.
* @param packetsCallback An optional C++ callback function that takes a list of output packets as
* the input argument. If provided, the callback must in turn call the block provided by the user in
* the appropriate task options.
*
* @return An instance of `MPPTaskRunner` initialized to the given graph config proto and optional
* packetsCallback.
*/
- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
packetsCallback:
(mediapipe::tasks::core::PacketsCallback)packetsCallback
error:(NSError **)error NS_DESIGNATED_INITIALIZER;
/** A synchronous method for processing batch data or offline streaming data. This method is
designed for processing either batch data such as unrelated images and texts or offline streaming
data such as the decoded frames from a video file and an audio file. The call blocks the current
thread until a failure status or a successful result is returned. If the input packets have no
timestamp, an internal timestamp will be assigend per invocation. Otherwise, when the timestamp is
set in the input packets, the caller must ensure that the input packet timestamps are greater than
the timestamps of the previous invocation. This method is thread-unsafe and it is the caller's
responsibility to synchronize access to this method across multiple threads and to ensure that the
input packet timestamps are in order.*/
- (absl::StatusOr<mediapipe::tasks::core::PacketMap>)process:
(const mediapipe::tasks::core::PacketMap &)packetMap;
/** Shuts down the C++ task runner. After the runner is closed, any calls that send input data to
* the runner are illegal and will receive errors. */
- (absl::Status)close;
- (instancetype)init NS_UNAVAILABLE;

View File

@ -13,11 +13,15 @@
// limitations under the License.
#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
#include "tensorflow/lite/core/api/op_resolver.h"
namespace {
using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::tasks::core::MediaPipeBuiltinOpResolver;
using ::mediapipe::tasks::core::PacketMap;
using ::mediapipe::tasks::core::PacketsCallback;
using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
} // namespace
@ -30,15 +34,17 @@ using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
@implementation MPPTaskRunner
- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig
packetsCallback:(PacketsCallback)packetsCallback
error:(NSError **)error {
self = [super init];
if (self) {
auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig));
auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig),
absl::make_unique<MediaPipeBuiltinOpResolver>(),
std::move(packetsCallback));
if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) {
return nil;
}
_cppTaskRunner = std::move(taskRunnerResult.value());
}
return self;

View File

@ -1,18 +1,15 @@
load(
"//mediapipe/tasks:ios/ios.bzl",
"MPP_TASK_MINIMUM_OS_VERSION",
"MPP_TASK_DEFAULT_TAGS",
"MPP_TASK_DISABLED_SANITIZER_TAGS",
)
load(
"@build_bazel_rules_apple//apple:ios.bzl",
"@build_bazel_rules_apple//apple:ios.bzl",
"ios_unit_test",
)
load(
"@org_tensorflow//tensorflow/lite:special_rules.bzl",
"@org_tensorflow//tensorflow/lite:special_rules.bzl",
"tflite_ios_lab_runner"
)
load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
@ -25,7 +22,7 @@ objc_library(
"//mediapipe/tasks/testdata/text:bert_text_classifier_models",
"//mediapipe/tasks/testdata/text:text_classifier_models",
],
tags = MPP_TASK_DEFAULT_TAGS,
tags = [],
copts = [
"-ObjC++",
"-std=c++17",
@ -38,10 +35,27 @@ objc_library(
ios_unit_test(
name = "MPPTextClassifierObjcTest",
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = MPP_TASK_DEFAULT_TAGS + MPP_TASK_DISABLED_SANITIZER_TAGS,
tags =[],
deps = [
":MPPTextClassifierObjcTestLibrary",
],
)
swift_library(
name = "MPPTextClassifierSwiftTestLibrary",
testonly = 1,
srcs = ["TextClassifierTests.swift"],
tags = [],
)
ios_unit_test(
name = "MPPTextClassifierSwiftTest",
minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = [],
deps = [
":MPPTextClassifierSwiftTestLibrary",
],
)

View File

@ -0,0 +1,110 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import <XCTest/XCTest.h>
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h"
static NSString *const kBertTextClassifierModelName = @"bert_text_classifier";
static NSString *const kNegativeText = @"unflinchingly bleak and desperate";
static NSString *const kPositiveText = @"it's a charming and often affecting journey";
#define AssertCategoriesAre(categories, expectedCategories) \
XCTAssertEqual(categories.count, expectedCategories.count); \
for (int i = 0; i < categories.count; i++) { \
XCTAssertEqual(categories[i].index, expectedCategories[i].index); \
XCTAssertEqualWithAccuracy(categories[i].score, expectedCategories[i].score, 1e-6); \
XCTAssertEqualObjects(categories[i].categoryName, expectedCategories[i].categoryName); \
XCTAssertEqualObjects(categories[i].displayName, expectedCategories[i].displayName); \
}
#define AssertHasOneHead(textClassifierResult) \
XCTAssertNotNil(textClassifierResult); \
XCTAssertNotNil(textClassifierResult.classificationResult); \
XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1); \
XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0);
@interface MPPTextClassifierTests : XCTestCase
@end
@implementation MPPTextClassifierTests
- (void)setUp {
}
- (void)tearDown {
// Put teardown code here. This method is called after the invocation of each test method in the class.
}
- (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension {
NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName
ofType:extension];
XCTAssertNotNil(filePath);
return filePath;
}
- (MPPTextClassifierOptions *)textClassifierOptionsWithModelName:(NSString *)modelName {
NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
MPPTextClassifierOptions *textClassifierOptions =
[[MPPTextClassifierOptions alloc] init];
textClassifierOptions.baseOptions.modelAssetPath = modelPath;
return textClassifierOptions;
}
- (MPPTextClassifier *)createTextClassifierFromOptionsWithModelName:(NSString *)modelName {
MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:modelName];
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil];
XCTAssertNotNil(textClassifier);
return textClassifier;
}
- (void)testClassifyWithBertSucceeds {
MPPTextClassifier *textClassifier = [self createTextClassifierFromOptionsWithModelName:kBertTextClassifierModelName];
MPPTextClassifierResult *negativeResult = [textClassifier classifyWithText:kNegativeText error:nil];
AssertHasOneHead(negativeResult);
NSArray<MPPCategory *> *expectedNegativeCategories = @[[[MPPCategory alloc] initWithIndex:0
score:0.956187f
categoryName:@"negative"
displayName:nil],
[[MPPCategory alloc] initWithIndex:1
score:0.043812f
categoryName:@"positive"
displayName:nil]];
AssertCategoriesAre(negativeResult.classificationResult.classifications[0].categories,
expectedNegativeCategories
);
// MPPTextClassifierResult *positiveResult = [textClassifier classifyWithText:kPositiveText error:nil];
// AssertHasOneHead(positiveResult);
// NSArray<MPPCategory *> *expectedPositiveCategories = @[[[MPPCategory alloc] initWithIndex:0
// score:0.99997187f
// label:@"positive"
// displayName:nil],
// [[MPPCategory alloc] initWithIndex:1
// score:2.8132641E-5f
// label:@"negative"
// displayName:nil]];
// AssertCategoriesAre(negativeResult.classificationResult.classifications[0].categories,
// expectedPositiveCategories
// );
}
@end

View File

@ -0,0 +1,272 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// import GMLImageUtils
import XCTest
// @testable import TFLImageSegmenter
class TextClassifierTests: XCTestCase {
func testExample() throws {
XCTAssertEqual(1, 1)
}
// static let bundle = Bundle(for: TextClassifierTests.self)
// static let modelPath = bundle.path(
// forResource: "deeplabv3",
// ofType: "tflite")
// // The maximum fraction of pixels in the candidate mask that can have a
// // different class than the golden mask for the test to pass.
// let kGoldenMaskTolerance: Float = 1e-2
// // Magnification factor used when creating the golden category masks to make
// // them more human-friendly. Each pixel in the golden masks has its value
// // multiplied by this factor, i.e. a value of 10 means class index 1, a value of
// // 20 means class index 2, etc.
// let kGoldenMaskMagnificationFactor: UInt8 = 10
// let deepLabV3SegmentationWidth = 257
// let deepLabV3SegmentationHeight = 257
// func verifyDeeplabV3PartialSegmentationResult(_ coloredLabels: [ColoredLabel]) {
// self.verifyColoredLabel(
// coloredLabels[0],
// expectedR: 0,
// expectedG: 0,
// expectedB: 0,
// expectedLabel: "background")
// self.verifyColoredLabel(
// coloredLabels[1],
// expectedR: 128,
// expectedG: 0,
// expectedB: 0,
// expectedLabel: "aeroplane")
// self.verifyColoredLabel(
// coloredLabels[2],
// expectedR: 0,
// expectedG: 128,
// expectedB: 0,
// expectedLabel: "bicycle")
// self.verifyColoredLabel(
// coloredLabels[3],
// expectedR: 128,
// expectedG: 128,
// expectedB: 0,
// expectedLabel: "bird")
// self.verifyColoredLabel(
// coloredLabels[4],
// expectedR: 0,
// expectedG: 0,
// expectedB: 128,
// expectedLabel: "boat")
// self.verifyColoredLabel(
// coloredLabels[5],
// expectedR: 128,
// expectedG: 0,
// expectedB: 128,
// expectedLabel: "bottle")
// self.verifyColoredLabel(
// coloredLabels[6],
// expectedR: 0,
// expectedG: 128,
// expectedB: 128,
// expectedLabel: "bus")
// self.verifyColoredLabel(
// coloredLabels[7],
// expectedR: 128,
// expectedG: 128,
// expectedB: 128,
// expectedLabel: "car")
// self.verifyColoredLabel(
// coloredLabels[8],
// expectedR: 64,
// expectedG: 0,
// expectedB: 0,
// expectedLabel: "cat")
// self.verifyColoredLabel(
// coloredLabels[9],
// expectedR: 192,
// expectedG: 0,
// expectedB: 0,
// expectedLabel: "chair")
// self.verifyColoredLabel(
// coloredLabels[10],
// expectedR: 64,
// expectedG: 128,
// expectedB: 0,
// expectedLabel: "cow")
// self.verifyColoredLabel(
// coloredLabels[11],
// expectedR: 192,
// expectedG: 128,
// expectedB: 0,
// expectedLabel: "dining table")
// self.verifyColoredLabel(
// coloredLabels[12],
// expectedR: 64,
// expectedG: 0,
// expectedB: 128,
// expectedLabel: "dog")
// self.verifyColoredLabel(
// coloredLabels[13],
// expectedR: 192,
// expectedG: 0,
// expectedB: 128,
// expectedLabel: "horse")
// self.verifyColoredLabel(
// coloredLabels[14],
// expectedR: 64,
// expectedG: 128,
// expectedB: 128,
// expectedLabel: "motorbike")
// self.verifyColoredLabel(
// coloredLabels[15],
// expectedR: 192,
// expectedG: 128,
// expectedB: 128,
// expectedLabel: "person")
// self.verifyColoredLabel(
// coloredLabels[16],
// expectedR: 0,
// expectedG: 64,
// expectedB: 0,
// expectedLabel: "potted plant")
// self.verifyColoredLabel(
// coloredLabels[17],
// expectedR: 128,
// expectedG: 64,
// expectedB: 0,
// expectedLabel: "sheep")
// self.verifyColoredLabel(
// coloredLabels[18],
// expectedR: 0,
// expectedG: 192,
// expectedB: 0,
// expectedLabel: "sofa")
// self.verifyColoredLabel(
// coloredLabels[19],
// expectedR: 128,
// expectedG: 192,
// expectedB: 0,
// expectedLabel: "train")
// self.verifyColoredLabel(
// coloredLabels[20],
// expectedR: 0,
// expectedG: 64,
// expectedB: 128,
// expectedLabel: "tv")
// }
// func verifyColoredLabel(
// _ coloredLabel: ColoredLabel,
// expectedR: UInt,
// expectedG: UInt,
// expectedB: UInt,
// expectedLabel: String
// ) {
// XCTAssertEqual(
// coloredLabel.r,
// expectedR)
// XCTAssertEqual(
// coloredLabel.g,
// expectedG)
// XCTAssertEqual(
// coloredLabel.b,
// expectedB)
// XCTAssertEqual(
// coloredLabel.label,
// expectedLabel)
// }
// func testSuccessfullInferenceOnMLImageWithUIImage() throws {
// let modelPath = try XCTUnwrap(ImageSegmenterTests.modelPath)
// let imageSegmenterOptions = ImageSegmenterOptions(modelPath: modelPath)
// let imageSegmenter =
// try ImageSegmenter.segmenter(options: imageSegmenterOptions)
// let gmlImage = try XCTUnwrap(
// MLImage.imageFromBundle(
// class: type(of: self),
// filename: "segmentation_input_rotation0",
// type: "jpg"))
// let segmentationResult: SegmentationResult =
// try XCTUnwrap(imageSegmenter.segment(mlImage: gmlImage))
// XCTAssertEqual(segmentationResult.segmentations.count, 1)
// let coloredLabels = try XCTUnwrap(segmentationResult.segmentations[0].coloredLabels)
// verifyDeeplabV3PartialSegmentationResult(coloredLabels)
// let categoryMask = try XCTUnwrap(segmentationResult.segmentations[0].categoryMask)
// XCTAssertEqual(deepLabV3SegmentationWidth, categoryMask.width)
// XCTAssertEqual(deepLabV3SegmentationHeight, categoryMask.height)
// let goldenMaskImage = try XCTUnwrap(
// MLImage.imageFromBundle(
// class: type(of: self),
// filename: "segmentation_golden_rotation0",
// type: "png"))
// let pixelBuffer = goldenMaskImage.grayScalePixelBuffer().takeRetainedValue()
// CVPixelBufferLockBaseAddress(pixelBuffer, CVPixelBufferLockFlags.readOnly)
// let pixelBufferBaseAddress = (try XCTUnwrap(CVPixelBufferGetBaseAddress(pixelBuffer)))
// .assumingMemoryBound(to: UInt8.self)
// let numPixels = deepLabV3SegmentationWidth * deepLabV3SegmentationHeight
// let mask = try XCTUnwrap(categoryMask.mask)
// var inconsistentPixels: Float = 0.0
// for i in 0..<numPixels {
// if mask[i] * kGoldenMaskMagnificationFactor != pixelBufferBaseAddress[i] {
// inconsistentPixels += 1
// }
// }
// CVPixelBufferUnlockBaseAddress(pixelBuffer, CVPixelBufferLockFlags.readOnly)
// XCTAssertLessThan(inconsistentPixels / Float(numPixels), kGoldenMaskTolerance)
// }
}

View File

@ -1,89 +0,0 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import <XCTest/XCTest.h>
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h"
NS_ASSUME_NONNULL_BEGIN
static NSString *const kBertTextClassifierModelName = @"bert_text_classifier";
static NSString *const kNegativeText = @"unflinchingly bleak and desperate";
#define VerifyCategory(category, expectedIndex, expectedScore, expectedLabel, expectedDisplayName) \
XCTAssertEqual(category.index, expectedIndex); \
XCTAssertEqualWithAccuracy(category.score, expectedScore, 1e-6); \
XCTAssertEqualObjects(category.label, expectedLabel); \
XCTAssertEqualObjects(category.displayName, expectedDisplayName);
#define VerifyClassifications(classifications, expectedHeadIndex, expectedCategoryCount) \
XCTAssertEqual(classifications.categories.count, expectedCategoryCount);
#define VerifyClassificationResult(classificationResult, expectedClassificationsCount) \
XCTAssertNotNil(classificationResult); \
XCTAssertEqual(classificationResult.classifications.count, expectedClassificationsCount)
#define AssertClassificationResultHasOneHead(classificationResult) \
XCTAssertNotNil(classificationResult); \
XCTAssertEqual(classificationResult.classifications.count, 1);
XCTAssertEqual(classificationResult.classifications[0].headIndex, 1);
#define AssertTextClassifierResultIsNotNil(textClassifierResult) \
XCTAssertNotNil(textClassifierResult);
@interface MPPTextClassifierTests : XCTestCase
@end
@implementation MPPTextClassifierTests
- (void)setUp {
[super setUp];
}
- (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension {
NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName
ofType:extension];
XCTAssertNotNil(filePath);
return filePath;
}
- (MPPTextClassifierOptions *)textClassifierOptionsWithModelName:(NSString *)modelName {
NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
MPPTextClassifierOptions *textClassifierOptions =
[[MPPTextClassifierOptions alloc] init];
textClassifierOptions.baseOptions.modelAssetPath = modelPath;
return textClassifierOptions;
}
kBertTextClassifierModelName
- (MPPTextClassifier *)createTextClassifierFromOptionsWithModelName:(NSString *)modelName {
MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:modelName];
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil];
XCTAssertNotNil(textClassifier);
return textClassifier
}
- (void)classifyWithBertSucceeds {
MPPTextClassifier *textClassifier = [self createTextClassifierWithModelName:kBertTextClassifierModelName];
MPPTextClassifierResult *textClassifierResult = [textClassifier classifyWithText:kNegativeText];
}
@end
NS_ASSUME_NONNULL_END

View File

@ -17,17 +17,15 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
objc_library(
name = "MPPBaseTextTaskApi",
srcs = ["sources/MPPBaseTextTaskApi.mm"],
hdrs = ["sources/MPPBaseTextTaskApi.h"],
name = "MPPTextTaskRunner",
srcs = ["sources/MPPTextTaskRunner.mm"],
hdrs = ["sources/MPPTextTaskRunner.h"],
copts = [
"-ObjC++",
"-std=c++17",
],
deps = [
"//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
"//mediapipe/tasks/ios/core:MPPTaskRunner",
],
)

View File

@ -1,48 +0,0 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import <Foundation/Foundation.h>
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
NS_ASSUME_NONNULL_BEGIN
/**
* The base class of the user-facing iOS mediapipe text task api classes.
*/
NS_SWIFT_NAME(BaseTextTaskApi)
@interface MPPBaseTextTaskApi : NSObject {
@protected
std::unique_ptr<mediapipe::tasks::core::TaskRunner> cppTaskRunner;
}
/**
* Initializes a new `MPPBaseTextTaskApi` with the mediapipe text task graph config proto.
*
* @param graphConfig A mediapipe text task graph config proto.
*
* @return An instance of `MPPBaseTextTaskApi` initialized to the given graph config proto.
*/
- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
error:(NSError **)error;
- (void)close;
- (instancetype)init NS_UNAVAILABLE;
+ (instancetype)new NS_UNAVAILABLE;
@end
NS_ASSUME_NONNULL_END

View File

@ -1,52 +0,0 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import "mediapipe/tasks/ios/text/core/sources/MPPBaseTextTaskApi.h"
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
namespace {
using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::Packet;
using ::mediapipe::tasks::core::PacketMap;
using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
} // namespace
@interface MPPBaseTextTaskApi () {
/** TextSearcher backed by C++ API */
std::unique_ptr<TaskRunnerCpp> _cppTaskRunner;
}
@end
@implementation MPPBaseTextTaskApi
- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig
error:(NSError **)error {
self = [super init];
if (self) {
auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig));
if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) {
return nil;
}
_cppTaskRunner = std::move(taskRunnerResult.value());
}
return self;
}
- (void)close {
_cppTaskRunner->Close();
}
@end

View File

@ -0,0 +1,37 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h"
NS_ASSUME_NONNULL_BEGIN
/**
* This class is used to create and call appropriate methods on the C++ Task Runner to initialize, execute and terminate any Mediapipe text task.
*/
@interface MPPTextTaskRunner : MPPTaskRunner
/**
* Initializes a new `MPPTextTaskRunner` with the mediapipe task graph config proto.
*
* @param graphConfig A mediapipe task graph config proto.
*
* @return An instance of `MPPTextTaskRunner` initialized to the given graph config proto.
*/
- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
error:(NSError **)error NS_DESIGNATED_INITIALIZER;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,29 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h"
namespace {
using ::mediapipe::CalculatorGraphConfig;
} // namespace
@implementation MPPTextTaskRunner
- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig
error:(NSError **)error {
self = [super initWithCalculatorGraphConfig:graphConfig packetsCallback:nullptr error:error];
return self;
}
@end

View File

@ -1,33 +0,0 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
objc_library(
name = "MPPBaseTextTaskApi",
srcs = ["sources/MPPBaseTextTaskApi.mm"],
hdrs = ["sources/MPPBaseTextTaskApi.h"],
copts = [
"-ObjC++",
"-std=c++17",
],
deps = [
"//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
],
)

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -25,9 +25,11 @@ objc_library(
"-std=c++17",
],
deps = [
"//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
"//mediapipe/tasks/ios/core:MPPTaskOptions",
"//mediapipe/tasks/ios/core:MPPTaskInfo",
"//mediapipe/tasks/ios/core:MPPTaskRunner",
"//mediapipe/tasks/ios/text/core:MPPTextTaskRunner",
"//mediapipe/tasks/ios/core:MPPTextPacketCreator",
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers",
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers",
@ -35,6 +37,9 @@ objc_library(
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
":MPPTextClassifierOptions",
],
sdk_frameworks = [
"MetalKit",
],
)
objc_library(

View File

@ -1,34 +1,61 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h"
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h"
NS_ASSUME_NONNULL_BEGIN
/**
* A Mediapipe iOS Text Classifier.
* This API expects a TFLite model with (optional) [TFLite Model
* Metadata](https://www.tensorflow.org/lite/convert/metadata")that contains the mandatory
* (described below) input tensors, output tensor, and the optional (but recommended) label items as
* AssociatedFiles with type TENSOR_AXIS_LABELS per output classification tensor.
*
* Metadata is required for models with int32 input tensors because it contains the input process
* unit for the model's Tokenizer. No metadata is required for models with string input tensors.
*
* Input tensors
* - Three input tensors `kTfLiteInt32` of shape `[batch_size xbert_max_seq_len]`
* representing the input ids, mask ids, and segment ids. This input signature requires a
* Bert Tokenizer process unit in the model metadata.
* - Or one input tensor `kTfLiteInt32` of shape `[batch_size xmax_seq_len]` representing
* the input ids. This input signature requires a Regex Tokenizer process unit in the
* model metadata.
* - Or one input tensor ({@code kTfLiteString}) that is shapeless or has shape `[1]` containing
* the input string.
*
* At least one output tensor `(kTfLiteFloat32}/kBool)` with:
* - `N` classes and shape `[1 x N]`
* - optional (but recommended) label map(s) as AssociatedFile-s with type TENSOR_AXIS_LABELS,
* containing one label per line. The first such AssociatedFile (if any) is used to fill the
* `class_name` field of the results. The `display_name` field is filled from the AssociatedFile
* (if any) whose locale matches the `display_names_locale` field of the
* `MPPTextClassifierOptions` used at creation time ("en" by default, i.e. English). If none of
* these are available, only the `index` field of the results will be filled.
*
* @brief Performs classification on text.
*/
NS_SWIFT_NAME(TextClassifier)
@interface MPPTextClassifier : NSObject
/**
* Creates a new instance of `MPPTextClassifier` from an absolute path to a TensorFlow Lite model
* file stored locally on the device.
* file stored locally on the device and the default `MPPTextClassifierOptions`.
*
* @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
*
@ -41,9 +68,11 @@ NS_SWIFT_NAME(TextClassifier)
- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error;
/**
* Creates a new instance of `MPPTextClassifier` from the given text classifier options.
* Creates a new instance of `MPPTextClassifier` from the given `MPPTextClassifierOptions`.
*
* @param options The options of type `MPPTextClassifierOptions` to use for configuring the
* `MPPTextClassifier`.
*
* @param options The options to use for configuring the `MPPTextClassifier`.
* @param error An optional error parameter populated when there is an error in initializing
* the text classifier.
*
@ -52,6 +81,16 @@ NS_SWIFT_NAME(TextClassifier)
*/
- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error;
/**
* Performs classification on the input text.
*
* @param text The `NSString` on which classification is to be performed.
*
* @param error An optional error parameter populated when there is an error in performing
* classification on the input text.
*
* @return A `MPPTextClassifierResult` object that contains a list of text classifications.
*/
- (nullable MPPTextClassifierResult *)classifyWithText:(NSString *)text error:(NSError **)error;
- (instancetype)init NS_UNAVAILABLE;

View File

@ -1,25 +1,27 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h"
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h"
#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h"
#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h"
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h"
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "absl/status/statusor.h"
@ -37,14 +39,13 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T
@interface MPPTextClassifier () {
/** TextSearcher backed by C++ API */
MPPTaskRunner *_taskRunner;
MPPTextTaskRunner *_taskRunner;
}
@end
@implementation MPPTextClassifier
- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error {
MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc]
initWithTaskGraphName:kTaskGraphName
inputStreams:@[ [NSString stringWithFormat:@"%@:%@", kTextTag, kTextInStreamName] ]
@ -58,10 +59,11 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T
return nil;
}
_taskRunner = [[MPPTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error];
_taskRunner =
[[MPPTextTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig]
error:error];
self = [super init];
return self;
}
@ -76,14 +78,23 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T
- (nullable MPPTextClassifierResult *)classifyWithText:(NSString *)text error:(NSError **)error {
Packet packet = [MPPTextPacketCreator createWithText:text];
absl::StatusOr<PacketMap> output_packet_map = [_taskRunner process:{{kTextInStreamName.cppString, packet}} error:error];
if (![MPPCommonUtils checkCppError:output_packet_map.status() toError:error]) {
std::map<std::string, Packet> packet_map = {{kTextInStreamName.cppString, packet}};
absl::StatusOr<PacketMap> status_or_output_packet_map = [_taskRunner process:packet_map];
if (![MPPCommonUtils checkCppError:status_or_output_packet_map.status() toError:error]) {
return nil;
}
Packet classifications_packet =
status_or_output_packet_map.value()[kClassificationsStreamName.cppString];
return [MPPTextClassifierResult
textClassifierResultWithProto:output_packet_map.value()[kClassificationsStreamName.cppString]
.Get<ClassificationResult>()];
textClassifierResultWithClassificationsPacket:status_or_output_packet_map.value()
[kClassificationsStreamName.cppString]];
// return [MPPTextClassifierResult
// textClassifierResultWithClassificationsPacket:output_packet_map.value()[kClassificationsStreamName.cppString]
// .Get<ClassificationResult>()];
}
@end

View File

@ -1,17 +1,17 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h"
@ -20,32 +20,16 @@
NS_ASSUME_NONNULL_BEGIN
/**
* Options to configure MPPTextClassifierOptions.
* Options for setting up a `MPPTextClassifierOptions`.
*/
NS_SWIFT_NAME(TextClassifierOptions)
@interface MPPTextClassifierOptions : MPPTaskOptions
/**
* Options controlling the behavior of the embedding model specified in the
* base options.
* Options for configuring the classifier behavior, such as score threshold, number of results, etc.
*/
@property(nonatomic, copy) MPPClassifierOptions *classifierOptions;
// /**
// * Initializes a new `MPPTextClassifierOptions` with the absolute path to the model file
// * stored locally on the device, set to the given the model path.
// *
// * @discussion The external model file must be a single standalone TFLite file. It could be packed
// * with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the
// * necessary metadata and associated files might result in errors. Check the [documentation]
// * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement.
// *
// * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
// *
// * @return An instance of `MPPTextClassifierOptions` initialized to the given model path.
// */
// - (instancetype)initWithModelPath:(NSString *)modelPath;
@end
NS_ASSUME_NONNULL_END

View File

@ -1,27 +1,27 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h"
@implementation MPPTextClassifierOptions
// - (instancetype)initWithModelPath:(NSString *)modelPath {
// self = [super initWithModelPath:modelPath];
// if (self) {
// _classifierOptions = [[MPPClassifierOptions alloc] init];
// }
// return self;
// }
- (instancetype)init {
self = [super init];
if (self) {
_classifierOptions = [[MPPClassifierOptions alloc] init];
}
return self;
}
@end

View File

@ -1,4 +1,4 @@
// Copyright 2022 The MediaPipe Authors.
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -18,23 +18,27 @@
NS_ASSUME_NONNULL_BEGIN
/** Encapsulates list of predicted classes (aka labels) for a given image classifier head. */
/** Represents the classification results generated by `MPPTextClassifier`. */
NS_SWIFT_NAME(TextClassifierResult)
@interface MPPTextClassifierResult : MPPTaskResult
/** The `MPPClassificationResult` instance containing one set of results per classifier head. */
@property(nonatomic, readonly) MPPClassificationResult *classificationResult;
/**
* Initializes a new `MPPClassificationResult` with the given array of classifications.
* Initializes a new `MPPTextClassifierResult` with the given `MPPClassificationResult` and time
* stamp (in milliseconds).
*
* @param classifications An Aaray of `MPPClassifications` objects containing classifier
* predictions per classifier head.
* @param classificationResult The `MPPClassificationResult` instance containing one set of results
* per classifier head.
*
* @return An instance of MPPClassificationResult initialized with the given array of
* classifications.
* @param timeStampMs The time stamp for this result.
*
* @return An instance of `MPPTextClassifierResult` initialized with the given
* `MPPClassificationResult` and time stamp (in milliseconds).
*/
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
timeStamp:(long)timeStamp;
timestampMs:(NSInteger)timestampMs;
@end

View File

@ -1,4 +1,4 @@
// Copyright 2022 The MediaPipe Authors.
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -17,8 +17,8 @@
@implementation MPPTextClassifierResult
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
timeStamp:(long)timeStamp {
self = [super initWithTimestamp:timeStamp];
timestampMs:(NSInteger)timestampMs {
self = [super initWithTimestampMs:timestampMs];
if (self) {
_classificationResult = classificationResult;
}

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -36,5 +36,6 @@ objc_library(
deps = [
"//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierResult",
"//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers",
"//mediapipe/framework:packet",
],
)

View File

@ -1,4 +1,4 @@
// Copyright 2022 The MediaPipe Authors.
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h"
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h"
NS_ASSUME_NONNULL_BEGIN

View File

@ -1,4 +1,4 @@
// Copyright 2022 The MediaPipe Authors.
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
// Copyright 2022 The MediaPipe Authors.
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -12,16 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h"
#include "mediapipe/framework/packet.h"
NS_ASSUME_NONNULL_BEGIN
@interface MPPTextClassifierResult (Helpers)
+ (MPPTextClassifierResult *)textClassifierResultWithProto:
(const mediapipe::tasks::components::containers::proto::ClassificationResult &)
classificationResultProto;
+ (MPPTextClassifierResult *)textClassifierResultWithClassificationsPacket:
(const mediapipe::Packet &)packet;
@end

View File

@ -1,4 +1,4 @@
// Copyright 2022 The MediaPipe Authors.
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -12,28 +12,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h"
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
static const int kMicroSecondsPerMilliSecond = 1000;
namespace {
using ClassificationResultProto =
::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::Packet;
} // namespace
#define int kMicroSecondsPerMilliSecond = 1000;
@implementation MPPTextClassifierResult (Helpers)
+ (MPPTextClassifierResult *)textClassifierResultWithProto:
(const ClassificationResultProto &)classificationResultProto {
long timeStamp;
+ (MPPTextClassifierResult *)textClassifierResultWithClassificationsPacket:(const Packet &)packet {
MPPClassificationResult *classificationResult = [MPPClassificationResult
classificationResultWithProto:packet.Get<ClassificationResultProto>()];
if (classificationResultProto.has_timestamp_ms()) {
timeStamp = classificationResultProto.timestamp_ms();
}
MPPClassificationResult *classificationResult = [MPPClassificationResult classificationResultWithProto:classificationResultProto];
return [[MPPTextClassifierResult alloc] initWithClassificationResult:classificationResult
timeStamp:timeStamp];
return [[MPPTextClassifierResult alloc]
initWithClassificationResult:classificationResult
timestampMs:(NSInteger)(packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond)];
}
@end