This commit is contained in:
priankakariatyml 2023-01-08 18:02:13 +05:30 committed by GitHub
commit d74e41b7e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 1864 additions and 5 deletions

View File

@ -0,0 +1,32 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
objc_library(
name = "MPPCategory",
srcs = ["sources/MPPCategory.m"],
hdrs = ["sources/MPPCategory.h"],
)
objc_library(
name = "MPPClassificationResult",
srcs = ["sources/MPPClassificationResult.m"],
hdrs = ["sources/MPPClassificationResult.h"],
deps = [
":MPPCategory",
],
)

View File

@ -0,0 +1,65 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
NS_ASSUME_NONNULL_BEGIN
/** Category is a util class, contains a label, its display name, a float value as score, and the
* index of the label in the corresponding label file. Typically it's used as the result of
* classification tasks. */
NS_SWIFT_NAME(ClassificationCategory)
@interface MPPCategory : NSObject
/** The index of the label in the corresponding label file. It takes the value -1 if the index is
* not set. */
@property(nonatomic, readonly) NSInteger index;
/** Confidence score for this class . */
@property(nonatomic, readonly) float score;
/** The label of this category object. */
@property(nonatomic, readonly, nullable) NSString *categoryName;
/** The display name of the label, which may be translated for different locales. For example, a
* label, "apple", may be translated into Spanish for display purpose, so that the display name is
* "manzana". */
@property(nonatomic, readonly, nullable) NSString *displayName;
/**
* Initializes a new `MPPCategory` with the given index, score, category name and display name.
*
* @param index The index of the label in the corresponding label file.
*
* @param score The probability score of this label category.
*
* @param categoryName The label of this category object..
*
* @param displayName The display name of the label.
*
* @return An instance of `MPPCategory` initialized with the given index, score, category name and
* display name.
*/
- (instancetype)initWithIndex:(NSInteger)index
score:(float)score
categoryName:(nullable NSString *)categoryName
displayName:(nullable NSString *)displayName;
- (instancetype)init NS_UNAVAILABLE;
+ (instancetype)new NS_UNAVAILABLE;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,33 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h"
@implementation MPPCategory
- (instancetype)initWithIndex:(NSInteger)index
score:(float)score
categoryName:(nullable NSString *)categoryName
displayName:(nullable NSString *)displayName {
self = [super init];
if (self) {
_index = index;
_score = score;
_categoryName = categoryName;
_displayName = displayName;
}
return self;
}
@end

View File

@ -0,0 +1,114 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h"
NS_ASSUME_NONNULL_BEGIN
/** Represents the list of classification for a given classifier head. Typically used as a result
* for classification tasks. */
NS_SWIFT_NAME(Classifications)
@interface MPPClassifications : NSObject
/** The index of the classifier head these entries refer to. This is useful for multi-head models.
*/
@property(nonatomic, readonly) NSInteger headIndex;
/** The optional name of the classifier head, which is the corresponding tensor metadata name. */
@property(nonatomic, readonly, nullable) NSString *headName;
/** An array of `MPPCategory` objects containing the predicted categories. */
@property(nonatomic, readonly) NSArray<MPPCategory *> *categories;
/**
* Initializes a new `MPPClassifications` object with the given head index and array of categories.
* Head name is initialized to `nil`.
*
* @param headIndex The index of the classifier head.
* @param categories An array of `MPPCategory` objects containing the predicted categories.
*
* @return An instance of `MPPClassifications` initialized with the given head index and
* array of categories.
*/
- (instancetype)initWithHeadIndex:(NSInteger)headIndex
categories:(NSArray<MPPCategory *> *)categories;
/**
* Initializes a new `MPPClassifications` with the given head index, head name and array of
* categories.
*
* @param headIndex The index of the classifier head.
* @param headName The name of the classifier head, which is the corresponding tensor metadata
* name.
* @param categories An array of `MPPCategory` objects containing the predicted categories.
*
* @return An object of `MPPClassifications` initialized with the given head index, head name and
* array of categories.
*/
- (instancetype)initWithHeadIndex:(NSInteger)headIndex
headName:(nullable NSString *)headName
categories:(NSArray<MPPCategory *> *)categories;
@end
/**
* Represents the classification results of a model. Typically used as a result for classification
* tasks.
*/
NS_SWIFT_NAME(ClassificationResult)
@interface MPPClassificationResult : NSObject
/** An Array of `MPPClassifications` objects containing the predicted categories for each head of
* the model. */
@property(nonatomic, readonly) NSArray<MPPClassifications *> *classifications;
/** The optional timestamp (in milliseconds) of the start of the chunk of data corresponding to
* these results. If it is set to the value -1, it signifies the absence of a time stamp. This is
* only used for classification on time series (e.g. audio classification). In these use cases, the
* amount of data to process might exceed the maximum size that the model can process: to solve
* this, the input data is split into multiple chunks starting at different timestamps. */
@property(nonatomic, readonly) NSInteger timestampMs;
/**
* Initializes a new `MPPClassificationResult` with the given array of classifications. This method
* must be used when no time stamp needs to be specified. It sets the property `timestampMs` to -1.
*
* @param classifications An Aaray of `MPPClassifications` objects containing classifier
* predictions per classifier head.
*
* @return An instance of MPPClassificationResult initialized with the given array of
* classifications.
*/
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications;
/**
* Initializes a new `MPPClassificationResult` with the given array of classifications and time
* stamp (in milliseconds).
*
* @param classifications An Array of `MPPClassifications` objects containing the predicted
* categories for each head of the model.
*
* @param timeStampMs The timestamp (in milliseconds) of the start of the chunk of data
* corresponding to these results.
*
* @return An instance of `MPPClassificationResult` initialized with the given array of
* classifications and timestampMs.
*/
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
timestampMs:(NSInteger)timestampMs;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,57 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h"
@implementation MPPClassifications
- (instancetype)initWithHeadIndex:(NSInteger)headIndex
headName:(nullable NSString *)headName
categories:(NSArray<MPPCategory *> *)categories {
self = [super init];
if (self) {
_headIndex = headIndex;
_headName = headName;
_categories = categories;
}
return self;
}
- (instancetype)initWithHeadIndex:(NSInteger)headIndex
categories:(NSArray<MPPCategory *> *)categories {
return [self initWithHeadIndex:headIndex headName:nil categories:categories];
}
@end
@implementation MPPClassificationResult
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
timestampMs:(NSInteger)timestampMs {
self = [super init];
if (self) {
_classifications = classifications;
_timestampMs = timestampMs;
}
return self;
}
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications {
return [self initWithClassifications:classifications timestampMs:-1];
return self;
}
@end

View File

@ -0,0 +1,40 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
objc_library(
name = "MPPCategoryHelpers",
srcs = ["sources/MPPCategory+Helpers.mm"],
hdrs = ["sources/MPPCategory+Helpers.h"],
deps = [
"//mediapipe/tasks/ios/components/containers:MPPCategory",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
],
)
objc_library(
name = "MPPClassificationResultHelpers",
srcs = ["sources/MPPClassificationResult+Helpers.mm"],
hdrs = ["sources/MPPClassificationResult+Helpers.h"],
deps = [
"//mediapipe/tasks/ios/components/containers:MPPClassificationResult",
":MPPCategoryHelpers",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
],
)

View File

@ -0,0 +1,26 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/formats/classification.pb.h"
#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h"
NS_ASSUME_NONNULL_BEGIN
@interface MPPCategory (Helpers)
+ (MPPCategory *)categoryWithProto:(const mediapipe::Classification &)classificationProto;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,42 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h"
namespace {
using ClassificationProto = ::mediapipe::Classification;
}
@implementation MPPCategory (Helpers)
+ (MPPCategory *)categoryWithProto:(const ClassificationProto &)clasificationProto {
NSString *categoryName;
NSString *displayName;
if (clasificationProto.has_label()) {
categoryName = [NSString stringWithCppString:clasificationProto.label()];
}
if (clasificationProto.has_display_name()) {
displayName = [NSString stringWithCppString:clasificationProto.display_name()];
}
return [[MPPCategory alloc] initWithIndex:clasificationProto.index()
score:clasificationProto.score()
categoryName:categoryName
displayName:displayName];
}
@end

View File

@ -0,0 +1,35 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h"
NS_ASSUME_NONNULL_BEGIN
@interface MPPClassifications (Helpers)
+ (MPPClassifications *)classificationsWithProto:
(const mediapipe::tasks::components::containers::proto::Classifications &)classificationsProto;
@end
@interface MPPClassificationResult (Helpers)
+ (MPPClassificationResult *)classificationResultWithProto:
(const mediapipe::tasks::components::containers::proto::ClassificationResult &)
classificationResultProto;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,68 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.h"
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
namespace {
using ClassificationsProto = ::mediapipe::tasks::components::containers::proto::Classifications;
using ClassificationResultProto =
::mediapipe::tasks::components::containers::proto::ClassificationResult;
} // namespace
@implementation MPPClassifications (Helpers)
+ (MPPClassifications *)classificationsWithProto:
(const ClassificationsProto &)classificationsProto {
NSMutableArray *categories = [[NSMutableArray alloc] init];
for (const auto &classification : classificationsProto.classification_list().classification()) {
[categories addObject:[MPPCategory categoryWithProto:classification]];
}
NSString *headName;
if (classificationsProto.has_head_name()) {
headName = [NSString stringWithCppString:classificationsProto.head_name()];
}
return [[MPPClassifications alloc] initWithHeadIndex:(NSInteger)classificationsProto.head_index()
headName:headName
categories:categories];
}
@end
@implementation MPPClassificationResult (Helpers)
+ (MPPClassificationResult *)classificationResultWithProto:
(const ClassificationResultProto &)classificationResultProto {
NSMutableArray *classifications = [[NSMutableArray alloc] init];
for (const auto &classifications_proto : classificationResultProto.classifications()) {
[classifications addObject:[MPPClassifications classificationsWithProto:classifications_proto]];
}
MPPClassificationResult *classificationResult;
if (classificationResultProto.has_timestamp_ms()) {
classificationResult = [[MPPClassificationResult alloc] initWithClassifications:classifications timestampMs:(NSInteger)classificationResultProto.timestamp_ms()];
}
else {
classificationResult = [[MPPClassificationResult alloc] initWithClassifications:classifications];
}
return classificationResult;
}
@end

View File

@ -90,6 +90,13 @@ objc_library(
deps = [
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
],
)
objc_library(
name = "MPPResultCallback",
hdrs = ["sources/MPPResultCallback.h"],
)

View File

@ -0,0 +1,28 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import <Foundation/Foundation.h>
NS_ASSUME_NONNULL_BEGIN
/**
* Holds information about an external file.
*/
NS_SWIFT_NAME(ExternalFile)
@interface MPPExternalFile : NSObject <NSCopying>
/** Path to the file in bundle. */
@property(nonatomic, copy) NSString *filePath;
/// Add provision for other sources in future.
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,27 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import "mediapipe/tasks/ios/core/sources/MPPExternalFile.h"
@implementation MPPExternalFile
- (id)copyWithZone:(NSZone *)zone {
MPPExternalFile *externalFile = [[MPPExternalFile alloc] init];
externalFile.filePath = self.filePath;
return externalFile;
}
@end

View File

@ -0,0 +1,21 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
NS_ASSUME_NONNULL_BEGIN
typedef void (^MPPResultCallback)(id oputput, id input, NSError *error);
NS_ASSUME_NONNULL_END

View File

@ -20,23 +20,63 @@
NS_ASSUME_NONNULL_BEGIN
/**
* This class is used to create and call appropriate methods on the C++ Task Runner.
* This class is used to create and call appropriate methods on the C++ Task Runner to initialize,
* execute and terminate any Mediapipe task.
*
* An instance of the newly created C++ task runner will
* be stored until this class is destroyed. When methods are called for processing (performing
* inference), closing etc., on this class, internally the appropriate methods will be called on the
* C++ task runner instance to execute the appropriate actions. For each type of task, a subclass of
* this class must be defined to add any additional functionality. For eg:, vision tasks must create
* an `MPPVisionTaskRunner` and provide additional functionality. An instance of
* `MPPVisionTaskRunner` can in turn be used by the each vision task for creation and execution of
* the task. Please see the documentation for the C++ Task Runner for more details on how the taks
* runner operates.
*/
@interface MPPTaskRunner : NSObject
/**
* Initializes a new `MPPTaskRunner` with the mediapipe task graph config proto.
* Initializes a new `MPPTaskRunner` with the mediapipe task graph config proto and an optional C++
* packets callback.
*
* You can pass `nullptr` for `packetsCallback` in case the mode of operation
* requested by the user is synchronous.
*
* If the task is operating in asynchronous mode, any iOS Mediapipe task that uses the `MPPTaskRunner`
* must define a C++ callback function to obtain the results of inference asynchronously and deliver
* the results to the user. To accomplish this, callback function will in turn invoke the block
* provided by the user in the task options supplied to create the task.
* Please see the documentation of the C++ Task Runner for more information on the synchronous and
* asynchronous modes of operation.
*
* @param graphConfig A mediapipe task graph config proto.
*
* @return An instance of `MPPTaskRunner` initialized to the given graph config proto.
* @param packetsCallback An optional C++ callback function that takes a list of output packets as
* the input argument. If provided, the callback must in turn call the block provided by the user in
* the appropriate task options.
*
* @return An instance of `MPPTaskRunner` initialized to the given graph config proto and optional
* packetsCallback.
*/
- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
packetsCallback:
(mediapipe::tasks::core::PacketsCallback)packetsCallback
error:(NSError **)error NS_DESIGNATED_INITIALIZER;
/** A synchronous method for processing batch data or offline streaming data. This method is
designed for processing either batch data such as unrelated images and texts or offline streaming
data such as the decoded frames from a video file and an audio file. The call blocks the current
thread until a failure status or a successful result is returned. If the input packets have no
timestamp, an internal timestamp will be assigend per invocation. Otherwise, when the timestamp is
set in the input packets, the caller must ensure that the input packet timestamps are greater than
the timestamps of the previous invocation. This method is thread-unsafe and it is the caller's
responsibility to synchronize access to this method across multiple threads and to ensure that the
input packet timestamps are in order.*/
- (absl::StatusOr<mediapipe::tasks::core::PacketMap>)process:
(const mediapipe::tasks::core::PacketMap &)packetMap;
/** Shuts down the C++ task runner. After the runner is closed, any calls that send input data to
* the runner are illegal and will receive errors. */
- (absl::Status)close;
- (instancetype)init NS_UNAVAILABLE;

View File

@ -13,11 +13,15 @@
// limitations under the License.
#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
#include "tensorflow/lite/core/api/op_resolver.h"
namespace {
using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::tasks::core::MediaPipeBuiltinOpResolver;
using ::mediapipe::tasks::core::PacketMap;
using ::mediapipe::tasks::core::PacketsCallback;
using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
} // namespace
@ -30,15 +34,17 @@ using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
@implementation MPPTaskRunner
- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig
packetsCallback:(PacketsCallback)packetsCallback
error:(NSError **)error {
self = [super init];
if (self) {
auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig));
auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig),
absl::make_unique<MediaPipeBuiltinOpResolver>(),
std::move(packetsCallback));
if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) {
return nil;
}
_cppTaskRunner = std::move(taskRunnerResult.value());
}
return self;

View File

@ -0,0 +1,27 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
objc_library(
name = "MPPBaseOptionsHelpers",
srcs = ["sources/MPPBaseOptions+Helpers.mm"],
hdrs = ["sources/MPPBaseOptions+Helpers.h"],
deps = [
"//mediapipe/tasks/ios/core:MPPBaseOptions",
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
],
)

View File

@ -0,0 +1,26 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h"
NS_ASSUME_NONNULL_BEGIN
@interface MPPBaseOptions (Helpers)
- (void)copyToProto:(mediapipe::tasks::core::proto::BaseOptions *)baseOptionsProto;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,40 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h"
namespace {
using BaseOptionsProto = ::mediapipe::tasks::core::proto::BaseOptions;
}
@implementation MPPBaseOptions (Helpers)
- (void)copyToProto:(BaseOptionsProto *)baseOptionsProto {
if (self.modelAssetPath) {
baseOptionsProto->mutable_model_asset()->set_file_name(self.modelAssetPath.UTF8String);
}
switch (self.delegate) {
case MPPDelegateCPU: {
baseOptionsProto->mutable_acceleration()->mutable_tflite();
break;
}
case MPPDelegateGPU:
break;
default:
break;
}
}
@end

View File

@ -0,0 +1,15 @@
"""Mediapipe Task Library Helper Rules for iOS"""
MPP_TASK_MINIMUM_OS_VERSION = "11.0"
# Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
MPP_TASK_DEFAULT_TAGS = [
"apple",
]
# Following sanitizer tests are not supported by iOS test targets.
MPP_TASK_DISABLED_SANITIZER_TAGS = [
"noasan",
"nomsan",
"notsan",
]

View File

@ -0,0 +1,61 @@
load(
"@build_bazel_rules_apple//apple:ios.bzl",
"ios_unit_test",
)
load(
"@org_tensorflow//tensorflow/lite:special_rules.bzl",
"tflite_ios_lab_runner"
)
load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
objc_library(
name = "MPPTextClassifierObjcTestLibrary",
testonly = 1,
srcs = ["MPPTextClassifierTests.m"],
data = [
"//mediapipe/tasks/testdata/text:bert_text_classifier_models",
"//mediapipe/tasks/testdata/text:text_classifier_models",
],
tags = [],
copts = [
"-ObjC++",
"-std=c++17",
],
deps = [
"//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier",
],
)
ios_unit_test(
name = "MPPTextClassifierObjcTest",
minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags =[],
deps = [
":MPPTextClassifierObjcTestLibrary",
],
)
swift_library(
name = "MPPTextClassifierSwiftTestLibrary",
testonly = 1,
srcs = ["TextClassifierTests.swift"],
tags = [],
)
ios_unit_test(
name = "MPPTextClassifierSwiftTest",
minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = [],
deps = [
":MPPTextClassifierSwiftTestLibrary",
],
)

View File

@ -0,0 +1,110 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import <XCTest/XCTest.h>
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h"
static NSString *const kBertTextClassifierModelName = @"bert_text_classifier";
static NSString *const kNegativeText = @"unflinchingly bleak and desperate";
static NSString *const kPositiveText = @"it's a charming and often affecting journey";
#define AssertCategoriesAre(categories, expectedCategories) \
XCTAssertEqual(categories.count, expectedCategories.count); \
for (int i = 0; i < categories.count; i++) { \
XCTAssertEqual(categories[i].index, expectedCategories[i].index); \
XCTAssertEqualWithAccuracy(categories[i].score, expectedCategories[i].score, 1e-6); \
XCTAssertEqualObjects(categories[i].categoryName, expectedCategories[i].categoryName); \
XCTAssertEqualObjects(categories[i].displayName, expectedCategories[i].displayName); \
}
#define AssertHasOneHead(textClassifierResult) \
XCTAssertNotNil(textClassifierResult); \
XCTAssertNotNil(textClassifierResult.classificationResult); \
XCTAssertEqual(textClassifierResult.classificationResult.classifications.count, 1); \
XCTAssertEqual(textClassifierResult.classificationResult.classifications[0].headIndex, 0);
@interface MPPTextClassifierTests : XCTestCase
@end
@implementation MPPTextClassifierTests
- (void)setUp {
}
- (void)tearDown {
// Put teardown code here. This method is called after the invocation of each test method in the class.
}
- (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension {
NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName
ofType:extension];
XCTAssertNotNil(filePath);
return filePath;
}
- (MPPTextClassifierOptions *)textClassifierOptionsWithModelName:(NSString *)modelName {
NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
MPPTextClassifierOptions *textClassifierOptions =
[[MPPTextClassifierOptions alloc] init];
textClassifierOptions.baseOptions.modelAssetPath = modelPath;
return textClassifierOptions;
}
- (MPPTextClassifier *)createTextClassifierFromOptionsWithModelName:(NSString *)modelName {
MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:modelName];
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil];
XCTAssertNotNil(textClassifier);
return textClassifier;
}
- (void)testClassifyWithBertSucceeds {
MPPTextClassifier *textClassifier = [self createTextClassifierFromOptionsWithModelName:kBertTextClassifierModelName];
MPPTextClassifierResult *negativeResult = [textClassifier classifyWithText:kNegativeText error:nil];
AssertHasOneHead(negativeResult);
NSArray<MPPCategory *> *expectedNegativeCategories = @[[[MPPCategory alloc] initWithIndex:0
score:0.956187f
categoryName:@"negative"
displayName:nil],
[[MPPCategory alloc] initWithIndex:1
score:0.043812f
categoryName:@"positive"
displayName:nil]];
AssertCategoriesAre(negativeResult.classificationResult.classifications[0].categories,
expectedNegativeCategories
);
// MPPTextClassifierResult *positiveResult = [textClassifier classifyWithText:kPositiveText error:nil];
// AssertHasOneHead(positiveResult);
// NSArray<MPPCategory *> *expectedPositiveCategories = @[[[MPPCategory alloc] initWithIndex:0
// score:0.99997187f
// label:@"positive"
// displayName:nil],
// [[MPPCategory alloc] initWithIndex:1
// score:2.8132641E-5f
// label:@"negative"
// displayName:nil]];
// AssertCategoriesAre(negativeResult.classificationResult.classifications[0].categories,
// expectedPositiveCategories
// );
}
@end

View File

@ -0,0 +1,272 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// import GMLImageUtils
import XCTest
// @testable import TFLImageSegmenter
class TextClassifierTests: XCTestCase {
func testExample() throws {
XCTAssertEqual(1, 1)
}
// static let bundle = Bundle(for: TextClassifierTests.self)
// static let modelPath = bundle.path(
// forResource: "deeplabv3",
// ofType: "tflite")
// // The maximum fraction of pixels in the candidate mask that can have a
// // different class than the golden mask for the test to pass.
// let kGoldenMaskTolerance: Float = 1e-2
// // Magnification factor used when creating the golden category masks to make
// // them more human-friendly. Each pixel in the golden masks has its value
// // multiplied by this factor, i.e. a value of 10 means class index 1, a value of
// // 20 means class index 2, etc.
// let kGoldenMaskMagnificationFactor: UInt8 = 10
// let deepLabV3SegmentationWidth = 257
// let deepLabV3SegmentationHeight = 257
// func verifyDeeplabV3PartialSegmentationResult(_ coloredLabels: [ColoredLabel]) {
// self.verifyColoredLabel(
// coloredLabels[0],
// expectedR: 0,
// expectedG: 0,
// expectedB: 0,
// expectedLabel: "background")
// self.verifyColoredLabel(
// coloredLabels[1],
// expectedR: 128,
// expectedG: 0,
// expectedB: 0,
// expectedLabel: "aeroplane")
// self.verifyColoredLabel(
// coloredLabels[2],
// expectedR: 0,
// expectedG: 128,
// expectedB: 0,
// expectedLabel: "bicycle")
// self.verifyColoredLabel(
// coloredLabels[3],
// expectedR: 128,
// expectedG: 128,
// expectedB: 0,
// expectedLabel: "bird")
// self.verifyColoredLabel(
// coloredLabels[4],
// expectedR: 0,
// expectedG: 0,
// expectedB: 128,
// expectedLabel: "boat")
// self.verifyColoredLabel(
// coloredLabels[5],
// expectedR: 128,
// expectedG: 0,
// expectedB: 128,
// expectedLabel: "bottle")
// self.verifyColoredLabel(
// coloredLabels[6],
// expectedR: 0,
// expectedG: 128,
// expectedB: 128,
// expectedLabel: "bus")
// self.verifyColoredLabel(
// coloredLabels[7],
// expectedR: 128,
// expectedG: 128,
// expectedB: 128,
// expectedLabel: "car")
// self.verifyColoredLabel(
// coloredLabels[8],
// expectedR: 64,
// expectedG: 0,
// expectedB: 0,
// expectedLabel: "cat")
// self.verifyColoredLabel(
// coloredLabels[9],
// expectedR: 192,
// expectedG: 0,
// expectedB: 0,
// expectedLabel: "chair")
// self.verifyColoredLabel(
// coloredLabels[10],
// expectedR: 64,
// expectedG: 128,
// expectedB: 0,
// expectedLabel: "cow")
// self.verifyColoredLabel(
// coloredLabels[11],
// expectedR: 192,
// expectedG: 128,
// expectedB: 0,
// expectedLabel: "dining table")
// self.verifyColoredLabel(
// coloredLabels[12],
// expectedR: 64,
// expectedG: 0,
// expectedB: 128,
// expectedLabel: "dog")
// self.verifyColoredLabel(
// coloredLabels[13],
// expectedR: 192,
// expectedG: 0,
// expectedB: 128,
// expectedLabel: "horse")
// self.verifyColoredLabel(
// coloredLabels[14],
// expectedR: 64,
// expectedG: 128,
// expectedB: 128,
// expectedLabel: "motorbike")
// self.verifyColoredLabel(
// coloredLabels[15],
// expectedR: 192,
// expectedG: 128,
// expectedB: 128,
// expectedLabel: "person")
// self.verifyColoredLabel(
// coloredLabels[16],
// expectedR: 0,
// expectedG: 64,
// expectedB: 0,
// expectedLabel: "potted plant")
// self.verifyColoredLabel(
// coloredLabels[17],
// expectedR: 128,
// expectedG: 64,
// expectedB: 0,
// expectedLabel: "sheep")
// self.verifyColoredLabel(
// coloredLabels[18],
// expectedR: 0,
// expectedG: 192,
// expectedB: 0,
// expectedLabel: "sofa")
// self.verifyColoredLabel(
// coloredLabels[19],
// expectedR: 128,
// expectedG: 192,
// expectedB: 0,
// expectedLabel: "train")
// self.verifyColoredLabel(
// coloredLabels[20],
// expectedR: 0,
// expectedG: 64,
// expectedB: 128,
// expectedLabel: "tv")
// }
// func verifyColoredLabel(
// _ coloredLabel: ColoredLabel,
// expectedR: UInt,
// expectedG: UInt,
// expectedB: UInt,
// expectedLabel: String
// ) {
// XCTAssertEqual(
// coloredLabel.r,
// expectedR)
// XCTAssertEqual(
// coloredLabel.g,
// expectedG)
// XCTAssertEqual(
// coloredLabel.b,
// expectedB)
// XCTAssertEqual(
// coloredLabel.label,
// expectedLabel)
// }
// func testSuccessfullInferenceOnMLImageWithUIImage() throws {
// let modelPath = try XCTUnwrap(ImageSegmenterTests.modelPath)
// let imageSegmenterOptions = ImageSegmenterOptions(modelPath: modelPath)
// let imageSegmenter =
// try ImageSegmenter.segmenter(options: imageSegmenterOptions)
// let gmlImage = try XCTUnwrap(
// MLImage.imageFromBundle(
// class: type(of: self),
// filename: "segmentation_input_rotation0",
// type: "jpg"))
// let segmentationResult: SegmentationResult =
// try XCTUnwrap(imageSegmenter.segment(mlImage: gmlImage))
// XCTAssertEqual(segmentationResult.segmentations.count, 1)
// let coloredLabels = try XCTUnwrap(segmentationResult.segmentations[0].coloredLabels)
// verifyDeeplabV3PartialSegmentationResult(coloredLabels)
// let categoryMask = try XCTUnwrap(segmentationResult.segmentations[0].categoryMask)
// XCTAssertEqual(deepLabV3SegmentationWidth, categoryMask.width)
// XCTAssertEqual(deepLabV3SegmentationHeight, categoryMask.height)
// let goldenMaskImage = try XCTUnwrap(
// MLImage.imageFromBundle(
// class: type(of: self),
// filename: "segmentation_golden_rotation0",
// type: "png"))
// let pixelBuffer = goldenMaskImage.grayScalePixelBuffer().takeRetainedValue()
// CVPixelBufferLockBaseAddress(pixelBuffer, CVPixelBufferLockFlags.readOnly)
// let pixelBufferBaseAddress = (try XCTUnwrap(CVPixelBufferGetBaseAddress(pixelBuffer)))
// .assumingMemoryBound(to: UInt8.self)
// let numPixels = deepLabV3SegmentationWidth * deepLabV3SegmentationHeight
// let mask = try XCTUnwrap(categoryMask.mask)
// var inconsistentPixels: Float = 0.0
// for i in 0..<numPixels {
// if mask[i] * kGoldenMaskMagnificationFactor != pixelBufferBaseAddress[i] {
// inconsistentPixels += 1
// }
// }
// CVPixelBufferUnlockBaseAddress(pixelBuffer, CVPixelBufferLockFlags.readOnly)
// XCTAssertLessThan(inconsistentPixels / Float(numPixels), kGoldenMaskTolerance)
// }
}

View File

@ -0,0 +1,31 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
objc_library(
name = "MPPTextTaskRunner",
srcs = ["sources/MPPTextTaskRunner.mm"],
hdrs = ["sources/MPPTextTaskRunner.h"],
copts = [
"-ObjC++",
"-std=c++17",
],
deps = [
"//mediapipe/tasks/ios/core:MPPTaskRunner",
],
)

View File

@ -0,0 +1,37 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h"
NS_ASSUME_NONNULL_BEGIN
/**
* This class is used to create and call appropriate methods on the C++ Task Runner to initialize, execute and terminate any Mediapipe text task.
*/
@interface MPPTextTaskRunner : MPPTaskRunner
/**
* Initializes a new `MPPTextTaskRunner` with the mediapipe task graph config proto.
*
* @param graphConfig A mediapipe task graph config proto.
*
* @return An instance of `MPPTextTaskRunner` initialized to the given graph config proto.
*/
- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
error:(NSError **)error NS_DESIGNATED_INITIALIZER;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,29 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h"
namespace {
using ::mediapipe::CalculatorGraphConfig;
} // namespace
@implementation MPPTextTaskRunner
- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig
error:(NSError **)error {
self = [super initWithCalculatorGraphConfig:graphConfig packetsCallback:nullptr error:error];
return self;
}
@end

View File

@ -0,0 +1,67 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
objc_library(
name = "MPPTextClassifier",
srcs = ["sources/MPPTextClassifier.mm"],
hdrs = ["sources/MPPTextClassifier.h"],
copts = [
"-ObjC++",
"-std=c++17",
],
deps = [
"//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
"//mediapipe/tasks/ios/core:MPPTaskOptions",
"//mediapipe/tasks/ios/core:MPPTaskInfo",
"//mediapipe/tasks/ios/text/core:MPPTextTaskRunner",
"//mediapipe/tasks/ios/core:MPPTextPacketCreator",
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers",
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
":MPPTextClassifierOptions",
],
sdk_frameworks = [
"MetalKit",
],
)
objc_library(
name = "MPPTextClassifierOptions",
srcs = ["sources/MPPTextClassifierOptions.m"],
hdrs = ["sources/MPPTextClassifierOptions.h"],
copts = [
"-ObjC++",
"-std=c++17",
],
deps = [
"//mediapipe/tasks/ios/core:MPPTaskOptions",
"//mediapipe/tasks/ios/components/processors:MPPClassifierOptions",
],
)
objc_library(
name = "MPPTextClassifierResult",
srcs = ["sources/MPPTextClassifierResult.m"],
hdrs = ["sources/MPPTextClassifierResult.h"],
deps = [
"//mediapipe/tasks/ios/core:MPPTaskResult",
"//mediapipe/tasks/ios/components/containers:MPPClassificationResult",
],
)

View File

@ -0,0 +1,102 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h"
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h"
NS_ASSUME_NONNULL_BEGIN
/**
* This API expects a TFLite model with (optional) [TFLite Model
* Metadata](https://www.tensorflow.org/lite/convert/metadata")that contains the mandatory
* (described below) input tensors, output tensor, and the optional (but recommended) label items as
* AssociatedFiles with type TENSOR_AXIS_LABELS per output classification tensor.
*
* Metadata is required for models with int32 input tensors because it contains the input process
* unit for the model's Tokenizer. No metadata is required for models with string input tensors.
*
* Input tensors
* - Three input tensors `kTfLiteInt32` of shape `[batch_size xbert_max_seq_len]`
* representing the input ids, mask ids, and segment ids. This input signature requires a
* Bert Tokenizer process unit in the model metadata.
* - Or one input tensor `kTfLiteInt32` of shape `[batch_size xmax_seq_len]` representing
* the input ids. This input signature requires a Regex Tokenizer process unit in the
* model metadata.
* - Or one input tensor ({@code kTfLiteString}) that is shapeless or has shape `[1]` containing
* the input string.
*
* At least one output tensor `(kTfLiteFloat32}/kBool)` with:
* - `N` classes and shape `[1 x N]`
* - optional (but recommended) label map(s) as AssociatedFile-s with type TENSOR_AXIS_LABELS,
* containing one label per line. The first such AssociatedFile (if any) is used to fill the
* `class_name` field of the results. The `display_name` field is filled from the AssociatedFile
* (if any) whose locale matches the `display_names_locale` field of the
* `MPPTextClassifierOptions` used at creation time ("en" by default, i.e. English). If none of
* these are available, only the `index` field of the results will be filled.
*
* @brief Performs classification on text.
*/
NS_SWIFT_NAME(TextClassifier)
@interface MPPTextClassifier : NSObject
/**
* Creates a new instance of `MPPTextClassifier` from an absolute path to a TensorFlow Lite model
* file stored locally on the device and the default `MPPTextClassifierOptions`.
*
* @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
*
* @param error An optional error parameter populated when there is an error in initializing
* the text classifier.
*
* @return A new instance of `MPPTextClassifier` with the given model path. `nil` if there is an
* error in initializing the text classifier.
*/
- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error;
/**
* Creates a new instance of `MPPTextClassifier` from the given `MPPTextClassifierOptions`.
*
* @param options The options of type `MPPTextClassifierOptions` to use for configuring the
* `MPPTextClassifier`.
*
* @param error An optional error parameter populated when there is an error in initializing
* the text classifier.
*
* @return A new instance of `MPPTextClassifier` with the given options. `nil` if there is an error
* in initializing the text classifier.
*/
- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error;
/**
* Performs classification on the input text.
*
* @param text The `NSString` on which classification is to be performed.
*
* @param error An optional error parameter populated when there is an error in performing
* classification on the input text.
*
* @return A `MPPTextClassifierResult` object that contains a list of text classifications.
*/
- (nullable MPPTextClassifierResult *)classifyWithText:(NSString *)text error:(NSError **)error;
- (instancetype)init NS_UNAVAILABLE;
+ (instancetype)new NS_UNAVAILABLE;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,93 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h"
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h"
#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h"
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h"
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "absl/status/statusor.h"
namespace {
using ::mediapipe::Packet;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::PacketMap;
} // namespace
static NSString *const kClassificationsStreamName = @"classifications_out";
static NSString *const kClassificationsTag = @"CLASSIFICATIONS";
static NSString *const kTextInStreamName = @"text_in";
static NSString *const kTextTag = @"TEXT";
static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.TextClassifierGraph";
@interface MPPTextClassifier () {
/** TextSearcher backed by C++ API */
MPPTextTaskRunner *_taskRunner;
}
@end
@implementation MPPTextClassifier
- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error {
MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc]
initWithTaskGraphName:kTaskGraphName
inputStreams:@[ [NSString stringWithFormat:@"%@:%@", kTextTag, kTextInStreamName] ]
outputStreams:@[ [NSString stringWithFormat:@"%@:%@", kClassificationsTag,
kClassificationsStreamName] ]
taskOptions:options
enableFlowLimiting:NO
error:error];
if (!taskInfo) {
return nil;
}
_taskRunner =
[[MPPTextTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig]
error:error];
self = [super init];
return self;
}
- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error {
MPPTextClassifierOptions *options = [[MPPTextClassifierOptions alloc] init];
options.baseOptions.modelAssetPath = modelPath;
return [self initWithOptions:options error:error];
}
- (nullable MPPTextClassifierResult *)classifyWithText:(NSString *)text error:(NSError **)error {
Packet packet = [MPPTextPacketCreator createWithText:text];
std::map<std::string, Packet> packet_map = {{kTextInStreamName.cppString, packet}};
absl::StatusOr<PacketMap> status_or_output_packet_map = [_taskRunner process:packet_map];
if (![MPPCommonUtils checkCppError:status_or_output_packet_map.status() toError:error]) {
return nil;
}
return [MPPTextClassifierResult
textClassifierResultWithClassificationsPacket:status_or_output_packet_map.value()
[kClassificationsStreamName.cppString]];
}
@end

View File

@ -0,0 +1,35 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
NS_ASSUME_NONNULL_BEGIN
/**
* Options for setting up a `MPPTextClassifierOptions`.
*/
NS_SWIFT_NAME(TextClassifierOptions)
@interface MPPTextClassifierOptions : MPPTaskOptions
/**
* Options for configuring the classifier behavior, such as score threshold, number of results, etc.
*/
@property(nonatomic, copy) MPPClassifierOptions *classifierOptions;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,27 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h"
@implementation MPPTextClassifierOptions
- (instancetype)init {
self = [super init];
if (self) {
_classifierOptions = [[MPPClassifierOptions alloc] init];
}
return self;
}
@end

View File

@ -0,0 +1,45 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h"
NS_ASSUME_NONNULL_BEGIN
/** Represents the classification results generated by `MPPTextClassifier`. */
NS_SWIFT_NAME(TextClassifierResult)
@interface MPPTextClassifierResult : MPPTaskResult
/** The `MPPClassificationResult` instance containing one set of results per classifier head. */
@property(nonatomic, readonly) MPPClassificationResult *classificationResult;
/**
* Initializes a new `MPPTextClassifierResult` with the given `MPPClassificationResult` and time
* stamp (in milliseconds).
*
* @param classificationResult The `MPPClassificationResult` instance containing one set of results
* per classifier head.
*
* @param timeStampMs The time stamp for this result.
*
* @return An instance of `MPPTextClassifierResult` initialized with the given
* `MPPClassificationResult` and time stamp (in milliseconds).
*/
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
timestampMs:(NSInteger)timestampMs;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,28 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h"
@implementation MPPTextClassifierResult
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
timestampMs:(NSInteger)timestampMs {
self = [super initWithTimestampMs:timestampMs];
if (self) {
_classificationResult = classificationResult;
}
return self;
}
@end

View File

@ -0,0 +1,41 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
objc_library(
name = "MPPTextClassifierOptionsHelpers",
srcs = ["sources/MPPTextClassifierOptions+Helpers.mm"],
hdrs = ["sources/MPPTextClassifierOptions+Helpers.h"],
deps = [
"//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierOptions",
"//mediapipe/tasks/ios/core/utils:MPPBaseOptionsHelpers",
"//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol",
"//mediapipe/tasks/ios/components/processors/utils:MPPClassifierOptionsHelpers",
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto",
],
)
objc_library(
name = "MPPTextClassifierResultHelpers",
srcs = ["sources/MPPTextClassifierResult+Helpers.mm"],
hdrs = ["sources/MPPTextClassifierResult+Helpers.h"],
deps = [
"//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierResult",
"//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers",
"//mediapipe/framework:packet",
],
)

View File

@ -0,0 +1,26 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h"
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h"
NS_ASSUME_NONNULL_BEGIN
@interface MPPTextClassifierOptions (Helpers) <MPPTaskOptionsProtocol>
- (void)copyToProto:(mediapipe::CalculatorOptions *)optionsProto;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,36 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h"
#import "mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h"
#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h"
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h"
namespace {
using CalculatorOptionsProto = ::mediapipe::CalculatorOptions;
using TextClassifierGraphOptionsProto =
::mediapipe::tasks::text::text_classifier::proto::TextClassifierGraphOptions;
} // namespace
@implementation MPPTextClassifierOptions (Helpers)
- (void)copyToProto:(CalculatorOptionsProto *)optionsProto {
TextClassifierGraphOptionsProto *graph_options =
optionsProto->MutableExtension(TextClassifierGraphOptionsProto::ext);
[self.baseOptions copyToProto:graph_options->mutable_base_options()];
[self.classifierOptions copyToProto:graph_options->mutable_classifier_options()];
}
@end

View File

@ -0,0 +1,28 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h"
#include "mediapipe/framework/packet.h"
NS_ASSUME_NONNULL_BEGIN
@interface MPPTextClassifierResult (Helpers)
+ (MPPTextClassifierResult *)textClassifierResultWithClassificationsPacket:
(const mediapipe::Packet &)packet;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,42 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
static const int kMicroSecondsPerMilliSecond = 1000;
namespace {
using ClassificationResultProto =
::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::Packet;
} // namespace
#define int kMicroSecondsPerMilliSecond = 1000;
@implementation MPPTextClassifierResult (Helpers)
+ (MPPTextClassifierResult *)textClassifierResultWithClassificationsPacket:(const Packet &)packet {
MPPClassificationResult *classificationResult = [MPPClassificationResult
classificationResultWithProto:packet.Get<ClassificationResultProto>()];
return [[MPPTextClassifierResult alloc]
initWithClassificationResult:classificationResult
timestampMs:(NSInteger)(packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond)];
}
@end