Updated implementation of text classifier
This commit is contained in:
parent
7f7776ef80
commit
7e0fec7c28
|
@ -28,6 +28,5 @@ objc_library(
|
||||||
hdrs = ["sources/MPPClassificationResult.h"],
|
hdrs = ["sources/MPPClassificationResult.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":MPPCategory",
|
":MPPCategory",
|
||||||
"//mediapipe/tasks/ios/core:MPPTaskResult",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
|
|
||||||
#import <Foundation/Foundation.h>
|
#import <Foundation/Foundation.h>
|
||||||
#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h"
|
#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h"
|
||||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h"
|
|
||||||
|
|
||||||
NS_ASSUME_NONNULL_BEGIN
|
NS_ASSUME_NONNULL_BEGIN
|
||||||
|
|
||||||
|
@ -72,7 +71,7 @@ NS_SWIFT_NAME(Classifications)
|
||||||
|
|
||||||
/** Encapsulates results of any classification task. */
|
/** Encapsulates results of any classification task. */
|
||||||
NS_SWIFT_NAME(ClassificationResult)
|
NS_SWIFT_NAME(ClassificationResult)
|
||||||
@interface MPPClassificationResult : MPPTaskResult
|
@interface MPPClassificationResult : NSObject
|
||||||
|
|
||||||
/** Array of MPPClassifications objects containing classifier predictions per image classifier
|
/** Array of MPPClassifications objects containing classifier predictions per image classifier
|
||||||
* head.
|
* head.
|
||||||
|
@ -88,8 +87,7 @@ NS_SWIFT_NAME(ClassificationResult)
|
||||||
* @return An instance of MPPClassificationResult initialized with the given array of
|
* @return An instance of MPPClassificationResult initialized with the given array of
|
||||||
* classifications.
|
* classifications.
|
||||||
*/
|
*/
|
||||||
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
|
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications;
|
||||||
timeStamp:(long)timeStamp;
|
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
||||||
|
|
|
@ -39,9 +39,9 @@
|
||||||
NSArray<MPPClassifications *> *_classifications;
|
NSArray<MPPClassifications *> *_classifications;
|
||||||
}
|
}
|
||||||
|
|
||||||
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
|
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications {
|
||||||
timeStamp:(long)timeStamp {
|
|
||||||
self = [super initWithTimeStamp:timeStamp];
|
self = [super init];
|
||||||
if (self) {
|
if (self) {
|
||||||
_classifications = classifications;
|
_classifications = classifications;
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,14 +53,7 @@ using ClassificationResultProto =
|
||||||
[classifications addObject:[MPPClassifications classificationsWithProto:classifications_proto]];
|
[classifications addObject:[MPPClassifications classificationsWithProto:classifications_proto]];
|
||||||
}
|
}
|
||||||
|
|
||||||
long timeStamp;
|
return [[MPPClassificationResult alloc] initWithClassifications:classifications];
|
||||||
|
|
||||||
if (classificationResultProto.has_timestamp_ms()) {
|
|
||||||
timeStamp = classificationResultProto.timestamp_ms();
|
|
||||||
}
|
|
||||||
|
|
||||||
return [[MPPClassificationResult alloc] initWithClassifications:classifications
|
|
||||||
timeStamp:timeStamp];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
|
@ -1,47 +0,0 @@
|
||||||
// Copyright 2022 The MediaPipe Authors.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#import <Foundation/Foundation.h>
|
|
||||||
|
|
||||||
#include "mediapipe/framework/calculator.pb.h"
|
|
||||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
|
||||||
|
|
||||||
|
|
||||||
NS_ASSUME_NONNULL_BEGIN
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The base class of the user-facing iOS mediapipe task api classes.
|
|
||||||
*/
|
|
||||||
@interface MPPTaskManager : NSObject
|
|
||||||
/**
|
|
||||||
* Initializes a new `MPPTaskManager` with the mediapipe task graph config proto.
|
|
||||||
*
|
|
||||||
* @param graphConfig A mediapipe task graph config proto.
|
|
||||||
*
|
|
||||||
* @return An instance of `MPPTaskManager` initialized to the given graph config proto.
|
|
||||||
*/
|
|
||||||
- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
|
|
||||||
error:(NSError **)error;
|
|
||||||
|
|
||||||
- (absl::StatusOr<mediapipe::tasks::core::PacketMap>)process:(const mediapipe::tasks::core::PacketMap&)packetMap error:(NSError **)error;
|
|
||||||
|
|
||||||
- (void)close;
|
|
||||||
|
|
||||||
- (instancetype)init NS_UNAVAILABLE;
|
|
||||||
|
|
||||||
+ (instancetype)new NS_UNAVAILABLE;
|
|
||||||
|
|
||||||
@end
|
|
||||||
|
|
||||||
NS_ASSUME_NONNULL_END
|
|
|
@ -1,56 +0,0 @@
|
||||||
// Copyright 2022 The MediaPipe Authors.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskManager.h"
|
|
||||||
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
using ::mediapipe::CalculatorGraphConfig;
|
|
||||||
using ::mediapipe::Packet;
|
|
||||||
using ::mediapipe::tasks::core::PacketMap;
|
|
||||||
using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
@interface MPPTaskManager () {
|
|
||||||
/** TextSearcher backed by C++ API */
|
|
||||||
std::unique_ptr<TaskRunnerCpp> _cppTaskRunner;
|
|
||||||
}
|
|
||||||
@end
|
|
||||||
|
|
||||||
@implementation MPPTaskManager
|
|
||||||
|
|
||||||
- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig
|
|
||||||
error:(NSError **)error {
|
|
||||||
self = [super init];
|
|
||||||
if (self) {
|
|
||||||
auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig));
|
|
||||||
|
|
||||||
if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) {
|
|
||||||
return nil;
|
|
||||||
}
|
|
||||||
|
|
||||||
_cppTaskRunner = std::move(taskRunnerResult.value());
|
|
||||||
}
|
|
||||||
return self;
|
|
||||||
}
|
|
||||||
|
|
||||||
- (absl::StatusOr<PacketMap>)process:(const PacketMap&)packetMap {
|
|
||||||
return _cppTaskRunner->Process(packetMap);
|
|
||||||
}
|
|
||||||
|
|
||||||
- (void)close {
|
|
||||||
_cppTaskRunner->Close();
|
|
||||||
}
|
|
||||||
|
|
||||||
@end
|
|
|
@ -19,6 +19,28 @@
|
||||||
NS_ASSUME_NONNULL_BEGIN
|
NS_ASSUME_NONNULL_BEGIN
|
||||||
|
|
||||||
static NSString *const kBertTextClassifierModelName = @"bert_text_classifier";
|
static NSString *const kBertTextClassifierModelName = @"bert_text_classifier";
|
||||||
|
static NSString *const kNegativeText = @"unflinchingly bleak and desperate";
|
||||||
|
|
||||||
|
#define VerifyCategory(category, expectedIndex, expectedScore, expectedLabel, expectedDisplayName) \
|
||||||
|
XCTAssertEqual(category.index, expectedIndex); \
|
||||||
|
XCTAssertEqualWithAccuracy(category.score, expectedScore, 1e-6); \
|
||||||
|
XCTAssertEqualObjects(category.label, expectedLabel); \
|
||||||
|
XCTAssertEqualObjects(category.displayName, expectedDisplayName);
|
||||||
|
|
||||||
|
#define VerifyClassifications(classifications, expectedHeadIndex, expectedCategoryCount) \
|
||||||
|
XCTAssertEqual(classifications.categories.count, expectedCategoryCount);
|
||||||
|
|
||||||
|
#define VerifyClassificationResult(classificationResult, expectedClassificationsCount) \
|
||||||
|
XCTAssertNotNil(classificationResult); \
|
||||||
|
XCTAssertEqual(classificationResult.classifications.count, expectedClassificationsCount)
|
||||||
|
|
||||||
|
#define AssertClassificationResultHasOneHead(classificationResult) \
|
||||||
|
XCTAssertNotNil(classificationResult); \
|
||||||
|
XCTAssertEqual(classificationResult.classifications.count, 1);
|
||||||
|
XCTAssertEqual(classificationResult.classifications[0].headIndex, 1);
|
||||||
|
|
||||||
|
#define AssertTextClassifierResultIsNotNil(textClassifierResult) \
|
||||||
|
XCTAssertNotNil(textClassifierResult);
|
||||||
|
|
||||||
@interface MPPTextClassifierTests : XCTestCase
|
@interface MPPTextClassifierTests : XCTestCase
|
||||||
@end
|
@end
|
||||||
|
@ -41,15 +63,25 @@ static NSString *const kBertTextClassifierModelName = @"bert_text_classifier";
|
||||||
- (MPPTextClassifierOptions *)textClassifierOptionsWithModelName:(NSString *)modelName {
|
- (MPPTextClassifierOptions *)textClassifierOptionsWithModelName:(NSString *)modelName {
|
||||||
NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
|
NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
|
||||||
MPPTextClassifierOptions *textClassifierOptions =
|
MPPTextClassifierOptions *textClassifierOptions =
|
||||||
[[MPPTextClassifierOptions alloc] initWithModelPath:modelPath];
|
[[MPPTextClassifierOptions alloc] init];
|
||||||
|
textClassifierOptions.baseOptions.modelAssetPath = modelPath;
|
||||||
|
|
||||||
return textClassifierOptions;
|
return textClassifierOptions;
|
||||||
}
|
}
|
||||||
|
|
||||||
- (void)testCreateTextClassifierOptionsSucceeds {
|
kBertTextClassifierModelName
|
||||||
MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
|
|
||||||
|
- (MPPTextClassifier *)createTextClassifierFromOptionsWithModelName:(NSString *)modelName {
|
||||||
|
MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:modelName];
|
||||||
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil];
|
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil];
|
||||||
XCTAssertNotNil(textClassifier);
|
XCTAssertNotNil(textClassifier);
|
||||||
|
|
||||||
|
return textClassifier
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)classifyWithBertSucceeds {
|
||||||
|
MPPTextClassifier *textClassifier = [self createTextClassifierWithModelName:kBertTextClassifierModelName];
|
||||||
|
MPPTextClassifierResult *textClassifierResult = [textClassifier classifyWithText:kNegativeText];
|
||||||
}
|
}
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
|
@ -27,10 +27,10 @@ objc_library(
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/ios/core:MPPTaskOptions",
|
"//mediapipe/tasks/ios/core:MPPTaskOptions",
|
||||||
"//mediapipe/tasks/ios/core:MPPTaskInfo",
|
"//mediapipe/tasks/ios/core:MPPTaskInfo",
|
||||||
"//mediapipe/tasks/ios/core:MPPTaskManager",
|
"//mediapipe/tasks/ios/core:MPPTaskRunner",
|
||||||
"//mediapipe/tasks/ios/core:MPPPacketCreator",
|
"//mediapipe/tasks/ios/core:MPPTextPacketCreator",
|
||||||
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers",
|
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers",
|
||||||
"//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers",
|
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers",
|
||||||
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
||||||
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||||
":MPPTextClassifierOptions",
|
":MPPTextClassifierOptions",
|
||||||
|
@ -51,3 +51,12 @@ objc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "MPPTextClassifierResult",
|
||||||
|
srcs = ["sources/MPPTextClassifierResult.m"],
|
||||||
|
hdrs = ["sources/MPPTextClassifierResult.h"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/ios/core:MPPTaskResult",
|
||||||
|
"//mediapipe/tasks/ios/components/containers:MPPClassificationResult",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#import <Foundation/Foundation.h>
|
#import <Foundation/Foundation.h>
|
||||||
|
|
||||||
#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h"
|
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h"
|
||||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
|
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
|
||||||
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h"
|
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h"
|
||||||
|
|
||||||
|
@ -52,7 +52,7 @@ NS_SWIFT_NAME(TextClassifier)
|
||||||
*/
|
*/
|
||||||
- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error;
|
- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error;
|
||||||
|
|
||||||
- (nullable MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error;
|
- (nullable MPPTextClassifierResult *)classifyWithText:(NSString *)text error:(NSError **)error;
|
||||||
|
|
||||||
- (instancetype)init NS_UNAVAILABLE;
|
- (instancetype)init NS_UNAVAILABLE;
|
||||||
|
|
||||||
|
|
|
@ -15,9 +15,9 @@
|
||||||
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h"
|
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h"
|
||||||
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
||||||
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
||||||
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
|
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h"
|
||||||
#import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h"
|
#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h"
|
||||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskManager.h"
|
#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h"
|
||||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
|
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
|
||||||
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h"
|
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h"
|
||||||
|
|
||||||
|
@ -30,14 +30,14 @@ using ::mediapipe::tasks::core::PacketMap;
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
static NSString *const kClassificationsStreamName = @"classifications_out";
|
static NSString *const kClassificationsStreamName = @"classifications_out";
|
||||||
static NSString *const kClassificationsTag = @"classifications";
|
static NSString *const kClassificationsTag = @"CLASSIFICATIONS";
|
||||||
static NSString *const kTextInStreamName = @"text_in";
|
static NSString *const kTextInStreamName = @"text_in";
|
||||||
static NSString *const kTextTag = @"TEXT";
|
static NSString *const kTextTag = @"TEXT";
|
||||||
static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.TextClassifierGraph";
|
static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.TextClassifierGraph";
|
||||||
|
|
||||||
@interface MPPTextClassifier () {
|
@interface MPPTextClassifier () {
|
||||||
/** TextSearcher backed by C++ API */
|
/** TextSearcher backed by C++ API */
|
||||||
MPPTaskManager *_taskManager;
|
MPPTaskRunner *_taskRunner;
|
||||||
}
|
}
|
||||||
@end
|
@end
|
||||||
|
|
||||||
|
@ -47,8 +47,8 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T
|
||||||
|
|
||||||
MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc]
|
MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc]
|
||||||
initWithTaskGraphName:kTaskGraphName
|
initWithTaskGraphName:kTaskGraphName
|
||||||
inputStreams:@[ [NSString stringWithFormat:@"@:@", kTextTag, kTextInStreamName] ]
|
inputStreams:@[ [NSString stringWithFormat:@"%@:%@", kTextTag, kTextInStreamName] ]
|
||||||
outputStreams:@[ [NSString stringWithFormat:@"@:@", kClassificationsTag,
|
outputStreams:@[ [NSString stringWithFormat:@"%@:%@", kClassificationsTag,
|
||||||
kClassificationsStreamName] ]
|
kClassificationsStreamName] ]
|
||||||
taskOptions:options
|
taskOptions:options
|
||||||
enableFlowLimiting:NO
|
enableFlowLimiting:NO
|
||||||
|
@ -58,7 +58,7 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T
|
||||||
return nil;
|
return nil;
|
||||||
}
|
}
|
||||||
|
|
||||||
_taskManager = [[MPPTaskManager alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error];
|
_taskRunner = [[MPPTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error];
|
||||||
|
|
||||||
self = [super init];
|
self = [super init];
|
||||||
|
|
||||||
|
@ -66,22 +66,23 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T
|
||||||
}
|
}
|
||||||
|
|
||||||
- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error {
|
- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error {
|
||||||
MPPTextClassifierOptions *options =
|
MPPTextClassifierOptions *options = [[MPPTextClassifierOptions alloc] init];
|
||||||
[[MPPTextClassifierOptions alloc] initWithModelPath:modelPath];
|
|
||||||
|
options.baseOptions.modelAssetPath = modelPath;
|
||||||
|
|
||||||
return [self initWithOptions:options error:error];
|
return [self initWithOptions:options error:error];
|
||||||
}
|
}
|
||||||
|
|
||||||
- (nullable MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error {
|
- (nullable MPPTextClassifierResult *)classifyWithText:(NSString *)text error:(NSError **)error {
|
||||||
Packet packet = [MPPPacketCreator createWithText:text];
|
Packet packet = [MPPTextPacketCreator createWithText:text];
|
||||||
|
|
||||||
absl::StatusOr<PacketMap> output_packet_map = [_taskManager process:{{kTextInStreamName.cppString, packet}} error:error];
|
absl::StatusOr<PacketMap> output_packet_map = [_taskRunner process:{{kTextInStreamName.cppString, packet}} error:error];
|
||||||
if (![MPPCommonUtils checkCppError:output_packet_map.status() toError:error]) {
|
if (![MPPCommonUtils checkCppError:output_packet_map.status() toError:error]) {
|
||||||
return nil;
|
return nil;
|
||||||
}
|
}
|
||||||
|
|
||||||
return [MPPClassificationResult
|
return [MPPTextClassifierResult
|
||||||
classificationResultWithProto:output_packet_map.value()[kClassificationsStreamName.cppString]
|
textClassifierResultWithProto:output_packet_map.value()[kClassificationsStreamName.cppString]
|
||||||
.Get<ClassificationResult>()];
|
.Get<ClassificationResult>()];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,20 +31,20 @@ NS_SWIFT_NAME(TextClassifierOptions)
|
||||||
*/
|
*/
|
||||||
@property(nonatomic, copy) MPPClassifierOptions *classifierOptions;
|
@property(nonatomic, copy) MPPClassifierOptions *classifierOptions;
|
||||||
|
|
||||||
/**
|
// /**
|
||||||
* Initializes a new `MPPTextClassifierOptions` with the absolute path to the model file
|
// * Initializes a new `MPPTextClassifierOptions` with the absolute path to the model file
|
||||||
* stored locally on the device, set to the given the model path.
|
// * stored locally on the device, set to the given the model path.
|
||||||
*
|
// *
|
||||||
* @discussion The external model file must be a single standalone TFLite file. It could be packed
|
// * @discussion The external model file must be a single standalone TFLite file. It could be packed
|
||||||
* with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the
|
// * with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the
|
||||||
* necessary metadata and associated files might result in errors. Check the [documentation]
|
// * necessary metadata and associated files might result in errors. Check the [documentation]
|
||||||
* (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement.
|
// * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement.
|
||||||
*
|
// *
|
||||||
* @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
|
// * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
|
||||||
*
|
// *
|
||||||
* @return An instance of `MPPTextClassifierOptions` initialized to the given model path.
|
// * @return An instance of `MPPTextClassifierOptions` initialized to the given model path.
|
||||||
*/
|
// */
|
||||||
- (instancetype)initWithModelPath:(NSString *)modelPath;
|
// - (instancetype)initWithModelPath:(NSString *)modelPath;
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
||||||
|
|
|
@ -16,12 +16,12 @@
|
||||||
|
|
||||||
@implementation MPPTextClassifierOptions
|
@implementation MPPTextClassifierOptions
|
||||||
|
|
||||||
- (instancetype)initWithModelPath:(NSString *)modelPath {
|
// - (instancetype)initWithModelPath:(NSString *)modelPath {
|
||||||
self = [super initWithModelPath:modelPath];
|
// self = [super initWithModelPath:modelPath];
|
||||||
if (self) {
|
// if (self) {
|
||||||
_classifierOptions = [[MPPClassifierOptions alloc] init];
|
// _classifierOptions = [[MPPClassifierOptions alloc] init];
|
||||||
}
|
// }
|
||||||
return self;
|
// return self;
|
||||||
}
|
// }
|
||||||
|
|
||||||
@end
|
@end
|
|
@ -0,0 +1,41 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#import <Foundation/Foundation.h>
|
||||||
|
#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h"
|
||||||
|
#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h"
|
||||||
|
|
||||||
|
NS_ASSUME_NONNULL_BEGIN
|
||||||
|
|
||||||
|
/** Encapsulates list of predicted classes (aka labels) for a given image classifier head. */
|
||||||
|
NS_SWIFT_NAME(TextClassifierResult)
|
||||||
|
@interface MPPTextClassifierResult : MPPTaskResult
|
||||||
|
|
||||||
|
@property(nonatomic, readonly) MPPClassificationResult *classificationResult;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes a new `MPPClassificationResult` with the given array of classifications.
|
||||||
|
*
|
||||||
|
* @param classifications An Aaray of `MPPClassifications` objects containing classifier
|
||||||
|
* predictions per classifier head.
|
||||||
|
*
|
||||||
|
* @return An instance of MPPClassificationResult initialized with the given array of
|
||||||
|
* classifications.
|
||||||
|
*/
|
||||||
|
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
|
||||||
|
timeStamp:(long)timeStamp;
|
||||||
|
|
||||||
|
@end
|
||||||
|
|
||||||
|
NS_ASSUME_NONNULL_END
|
|
@ -0,0 +1,28 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h"
|
||||||
|
|
||||||
|
@implementation MPPTextClassifierResult
|
||||||
|
|
||||||
|
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
|
||||||
|
timeStamp:(long)timeStamp {
|
||||||
|
self = [super initWithTimestamp:timeStamp];
|
||||||
|
if (self) {
|
||||||
|
_classificationResult = classificationResult;
|
||||||
|
}
|
||||||
|
return self;
|
||||||
|
}
|
||||||
|
|
||||||
|
@end
|
|
@ -28,3 +28,13 @@ objc_library(
|
||||||
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "MPPTextClassifierResultHelpers",
|
||||||
|
srcs = ["sources/MPPTextClassifierResult+Helpers.mm"],
|
||||||
|
hdrs = ["sources/MPPTextClassifierResult+Helpers.h"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierResult",
|
||||||
|
"//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -1,17 +1,17 @@
|
||||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h"
|
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h"
|
||||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h"
|
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h"
|
||||||
|
|
||||||
|
|
|
@ -1,17 +1,17 @@
|
||||||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h"
|
||||||
#import "mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h"
|
#import "mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h"
|
||||||
#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h"
|
#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h"
|
||||||
|
|
|
@ -0,0 +1,28 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
|
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h"
|
||||||
|
|
||||||
|
NS_ASSUME_NONNULL_BEGIN
|
||||||
|
|
||||||
|
@interface MPPTextClassifierResult (Helpers)
|
||||||
|
|
||||||
|
+ (MPPTextClassifierResult *)textClassifierResultWithProto:
|
||||||
|
(const mediapipe::tasks::components::containers::proto::ClassificationResult &)
|
||||||
|
classificationResultProto;
|
||||||
|
|
||||||
|
@end
|
||||||
|
|
||||||
|
NS_ASSUME_NONNULL_END
|
|
@ -0,0 +1,39 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h"
|
||||||
|
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
using ClassificationResultProto =
|
||||||
|
::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
@implementation MPPTextClassifierResult (Helpers)
|
||||||
|
|
||||||
|
+ (MPPTextClassifierResult *)textClassifierResultWithProto:
|
||||||
|
(const ClassificationResultProto &)classificationResultProto {
|
||||||
|
long timeStamp;
|
||||||
|
|
||||||
|
if (classificationResultProto.has_timestamp_ms()) {
|
||||||
|
timeStamp = classificationResultProto.timestamp_ms();
|
||||||
|
}
|
||||||
|
|
||||||
|
MPPClassificationResult *classificationResult = [MPPClassificationResult classificationResultWithProto:classificationResultProto];
|
||||||
|
|
||||||
|
return [[MPPTextClassifierResult alloc] initWithClassificationResult:classificationResult
|
||||||
|
timeStamp:timeStamp];
|
||||||
|
}
|
||||||
|
|
||||||
|
@end
|
Loading…
Reference in New Issue
Block a user