Updated implementation of text classifier
This commit is contained in:
parent
7f7776ef80
commit
7e0fec7c28
|
@ -28,6 +28,5 @@ objc_library(
|
|||
hdrs = ["sources/MPPClassificationResult.h"],
|
||||
deps = [
|
||||
":MPPCategory",
|
||||
"//mediapipe/tasks/ios/core:MPPTaskResult",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
|
||||
#import <Foundation/Foundation.h>
|
||||
#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
|
@ -72,7 +71,7 @@ NS_SWIFT_NAME(Classifications)
|
|||
|
||||
/** Encapsulates results of any classification task. */
|
||||
NS_SWIFT_NAME(ClassificationResult)
|
||||
@interface MPPClassificationResult : MPPTaskResult
|
||||
@interface MPPClassificationResult : NSObject
|
||||
|
||||
/** Array of MPPClassifications objects containing classifier predictions per image classifier
|
||||
* head.
|
||||
|
@ -88,8 +87,7 @@ NS_SWIFT_NAME(ClassificationResult)
|
|||
* @return An instance of MPPClassificationResult initialized with the given array of
|
||||
* classifications.
|
||||
*/
|
||||
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
|
||||
timeStamp:(long)timeStamp;
|
||||
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications;
|
||||
|
||||
@end
|
||||
|
||||
|
|
|
@ -39,9 +39,9 @@
|
|||
NSArray<MPPClassifications *> *_classifications;
|
||||
}
|
||||
|
||||
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
|
||||
timeStamp:(long)timeStamp {
|
||||
self = [super initWithTimeStamp:timeStamp];
|
||||
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications {
|
||||
|
||||
self = [super init];
|
||||
if (self) {
|
||||
_classifications = classifications;
|
||||
}
|
||||
|
|
|
@ -53,14 +53,7 @@ using ClassificationResultProto =
|
|||
[classifications addObject:[MPPClassifications classificationsWithProto:classifications_proto]];
|
||||
}
|
||||
|
||||
long timeStamp;
|
||||
|
||||
if (classificationResultProto.has_timestamp_ms()) {
|
||||
timeStamp = classificationResultProto.timestamp_ms();
|
||||
}
|
||||
|
||||
return [[MPPClassificationResult alloc] initWithClassifications:classifications
|
||||
timeStamp:timeStamp];
|
||||
return [[MPPClassificationResult alloc] initWithClassifications:classifications];
|
||||
}
|
||||
|
||||
@end
|
||||
|
|
|
@ -1,47 +0,0 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/**
|
||||
* The base class of the user-facing iOS mediapipe task api classes.
|
||||
*/
|
||||
@interface MPPTaskManager : NSObject
|
||||
/**
|
||||
* Initializes a new `MPPTaskManager` with the mediapipe task graph config proto.
|
||||
*
|
||||
* @param graphConfig A mediapipe task graph config proto.
|
||||
*
|
||||
* @return An instance of `MPPTaskManager` initialized to the given graph config proto.
|
||||
*/
|
||||
- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
|
||||
error:(NSError **)error;
|
||||
|
||||
- (absl::StatusOr<mediapipe::tasks::core::PacketMap>)process:(const mediapipe::tasks::core::PacketMap&)packetMap error:(NSError **)error;
|
||||
|
||||
- (void)close;
|
||||
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
+ (instancetype)new NS_UNAVAILABLE;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
|
@ -1,56 +0,0 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskManager.h"
|
||||
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
||||
|
||||
namespace {
|
||||
using ::mediapipe::CalculatorGraphConfig;
|
||||
using ::mediapipe::Packet;
|
||||
using ::mediapipe::tasks::core::PacketMap;
|
||||
using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
|
||||
} // namespace
|
||||
|
||||
@interface MPPTaskManager () {
|
||||
/** TextSearcher backed by C++ API */
|
||||
std::unique_ptr<TaskRunnerCpp> _cppTaskRunner;
|
||||
}
|
||||
@end
|
||||
|
||||
@implementation MPPTaskManager
|
||||
|
||||
- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig
|
||||
error:(NSError **)error {
|
||||
self = [super init];
|
||||
if (self) {
|
||||
auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig));
|
||||
|
||||
if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) {
|
||||
return nil;
|
||||
}
|
||||
|
||||
_cppTaskRunner = std::move(taskRunnerResult.value());
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
- (absl::StatusOr<PacketMap>)process:(const PacketMap&)packetMap {
|
||||
return _cppTaskRunner->Process(packetMap);
|
||||
}
|
||||
|
||||
- (void)close {
|
||||
_cppTaskRunner->Close();
|
||||
}
|
||||
|
||||
@end
|
|
@ -19,6 +19,28 @@
|
|||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
static NSString *const kBertTextClassifierModelName = @"bert_text_classifier";
|
||||
static NSString *const kNegativeText = @"unflinchingly bleak and desperate";
|
||||
|
||||
#define VerifyCategory(category, expectedIndex, expectedScore, expectedLabel, expectedDisplayName) \
|
||||
XCTAssertEqual(category.index, expectedIndex); \
|
||||
XCTAssertEqualWithAccuracy(category.score, expectedScore, 1e-6); \
|
||||
XCTAssertEqualObjects(category.label, expectedLabel); \
|
||||
XCTAssertEqualObjects(category.displayName, expectedDisplayName);
|
||||
|
||||
#define VerifyClassifications(classifications, expectedHeadIndex, expectedCategoryCount) \
|
||||
XCTAssertEqual(classifications.categories.count, expectedCategoryCount);
|
||||
|
||||
#define VerifyClassificationResult(classificationResult, expectedClassificationsCount) \
|
||||
XCTAssertNotNil(classificationResult); \
|
||||
XCTAssertEqual(classificationResult.classifications.count, expectedClassificationsCount)
|
||||
|
||||
#define AssertClassificationResultHasOneHead(classificationResult) \
|
||||
XCTAssertNotNil(classificationResult); \
|
||||
XCTAssertEqual(classificationResult.classifications.count, 1);
|
||||
XCTAssertEqual(classificationResult.classifications[0].headIndex, 1);
|
||||
|
||||
#define AssertTextClassifierResultIsNotNil(textClassifierResult) \
|
||||
XCTAssertNotNil(textClassifierResult);
|
||||
|
||||
@interface MPPTextClassifierTests : XCTestCase
|
||||
@end
|
||||
|
@ -41,15 +63,25 @@ static NSString *const kBertTextClassifierModelName = @"bert_text_classifier";
|
|||
- (MPPTextClassifierOptions *)textClassifierOptionsWithModelName:(NSString *)modelName {
|
||||
NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
|
||||
MPPTextClassifierOptions *textClassifierOptions =
|
||||
[[MPPTextClassifierOptions alloc] initWithModelPath:modelPath];
|
||||
[[MPPTextClassifierOptions alloc] init];
|
||||
textClassifierOptions.baseOptions.modelAssetPath = modelPath;
|
||||
|
||||
return textClassifierOptions;
|
||||
}
|
||||
|
||||
- (void)testCreateTextClassifierOptionsSucceeds {
|
||||
MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
|
||||
kBertTextClassifierModelName
|
||||
|
||||
- (MPPTextClassifier *)createTextClassifierFromOptionsWithModelName:(NSString *)modelName {
|
||||
MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:modelName];
|
||||
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil];
|
||||
XCTAssertNotNil(textClassifier);
|
||||
|
||||
return textClassifier
|
||||
}
|
||||
|
||||
- (void)classifyWithBertSucceeds {
|
||||
MPPTextClassifier *textClassifier = [self createTextClassifierWithModelName:kBertTextClassifierModelName];
|
||||
MPPTextClassifierResult *textClassifierResult = [textClassifier classifyWithText:kNegativeText];
|
||||
}
|
||||
|
||||
@end
|
||||
|
|
|
@ -27,10 +27,10 @@ objc_library(
|
|||
deps = [
|
||||
"//mediapipe/tasks/ios/core:MPPTaskOptions",
|
||||
"//mediapipe/tasks/ios/core:MPPTaskInfo",
|
||||
"//mediapipe/tasks/ios/core:MPPTaskManager",
|
||||
"//mediapipe/tasks/ios/core:MPPPacketCreator",
|
||||
"//mediapipe/tasks/ios/core:MPPTaskRunner",
|
||||
"//mediapipe/tasks/ios/core:MPPTextPacketCreator",
|
||||
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers",
|
||||
"//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers",
|
||||
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers",
|
||||
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
||||
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||
":MPPTextClassifierOptions",
|
||||
|
@ -51,3 +51,12 @@ objc_library(
|
|||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "MPPTextClassifierResult",
|
||||
srcs = ["sources/MPPTextClassifierResult.m"],
|
||||
hdrs = ["sources/MPPTextClassifierResult.h"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/ios/core:MPPTaskResult",
|
||||
"//mediapipe/tasks/ios/components/containers:MPPClassificationResult",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
==============================================================================*/
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h"
|
||||
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
|
||||
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h"
|
||||
|
||||
|
@ -52,7 +52,7 @@ NS_SWIFT_NAME(TextClassifier)
|
|||
*/
|
||||
- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error;
|
||||
|
||||
- (nullable MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error;
|
||||
- (nullable MPPTextClassifierResult *)classifyWithText:(NSString *)text error:(NSError **)error;
|
||||
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
|
|
|
@ -15,9 +15,9 @@
|
|||
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h"
|
||||
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
||||
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
||||
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskManager.h"
|
||||
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
|
||||
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h"
|
||||
|
||||
|
@ -30,14 +30,14 @@ using ::mediapipe::tasks::core::PacketMap;
|
|||
} // namespace
|
||||
|
||||
static NSString *const kClassificationsStreamName = @"classifications_out";
|
||||
static NSString *const kClassificationsTag = @"classifications";
|
||||
static NSString *const kClassificationsTag = @"CLASSIFICATIONS";
|
||||
static NSString *const kTextInStreamName = @"text_in";
|
||||
static NSString *const kTextTag = @"TEXT";
|
||||
static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.TextClassifierGraph";
|
||||
|
||||
@interface MPPTextClassifier () {
|
||||
/** TextSearcher backed by C++ API */
|
||||
MPPTaskManager *_taskManager;
|
||||
MPPTaskRunner *_taskRunner;
|
||||
}
|
||||
@end
|
||||
|
||||
|
@ -47,8 +47,8 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T
|
|||
|
||||
MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc]
|
||||
initWithTaskGraphName:kTaskGraphName
|
||||
inputStreams:@[ [NSString stringWithFormat:@"@:@", kTextTag, kTextInStreamName] ]
|
||||
outputStreams:@[ [NSString stringWithFormat:@"@:@", kClassificationsTag,
|
||||
inputStreams:@[ [NSString stringWithFormat:@"%@:%@", kTextTag, kTextInStreamName] ]
|
||||
outputStreams:@[ [NSString stringWithFormat:@"%@:%@", kClassificationsTag,
|
||||
kClassificationsStreamName] ]
|
||||
taskOptions:options
|
||||
enableFlowLimiting:NO
|
||||
|
@ -58,7 +58,7 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T
|
|||
return nil;
|
||||
}
|
||||
|
||||
_taskManager = [[MPPTaskManager alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error];
|
||||
_taskRunner = [[MPPTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error];
|
||||
|
||||
self = [super init];
|
||||
|
||||
|
@ -66,22 +66,23 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T
|
|||
}
|
||||
|
||||
- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error {
|
||||
MPPTextClassifierOptions *options =
|
||||
[[MPPTextClassifierOptions alloc] initWithModelPath:modelPath];
|
||||
MPPTextClassifierOptions *options = [[MPPTextClassifierOptions alloc] init];
|
||||
|
||||
options.baseOptions.modelAssetPath = modelPath;
|
||||
|
||||
return [self initWithOptions:options error:error];
|
||||
}
|
||||
|
||||
- (nullable MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error {
|
||||
Packet packet = [MPPPacketCreator createWithText:text];
|
||||
- (nullable MPPTextClassifierResult *)classifyWithText:(NSString *)text error:(NSError **)error {
|
||||
Packet packet = [MPPTextPacketCreator createWithText:text];
|
||||
|
||||
absl::StatusOr<PacketMap> output_packet_map = [_taskManager process:{{kTextInStreamName.cppString, packet}} error:error];
|
||||
absl::StatusOr<PacketMap> output_packet_map = [_taskRunner process:{{kTextInStreamName.cppString, packet}} error:error];
|
||||
if (![MPPCommonUtils checkCppError:output_packet_map.status() toError:error]) {
|
||||
return nil;
|
||||
}
|
||||
|
||||
return [MPPClassificationResult
|
||||
classificationResultWithProto:output_packet_map.value()[kClassificationsStreamName.cppString]
|
||||
return [MPPTextClassifierResult
|
||||
textClassifierResultWithProto:output_packet_map.value()[kClassificationsStreamName.cppString]
|
||||
.Get<ClassificationResult>()];
|
||||
}
|
||||
|
||||
|
|
|
@ -31,20 +31,20 @@ NS_SWIFT_NAME(TextClassifierOptions)
|
|||
*/
|
||||
@property(nonatomic, copy) MPPClassifierOptions *classifierOptions;
|
||||
|
||||
/**
|
||||
* Initializes a new `MPPTextClassifierOptions` with the absolute path to the model file
|
||||
* stored locally on the device, set to the given the model path.
|
||||
*
|
||||
* @discussion The external model file must be a single standalone TFLite file. It could be packed
|
||||
* with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the
|
||||
* necessary metadata and associated files might result in errors. Check the [documentation]
|
||||
* (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement.
|
||||
*
|
||||
* @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
|
||||
*
|
||||
* @return An instance of `MPPTextClassifierOptions` initialized to the given model path.
|
||||
*/
|
||||
- (instancetype)initWithModelPath:(NSString *)modelPath;
|
||||
// /**
|
||||
// * Initializes a new `MPPTextClassifierOptions` with the absolute path to the model file
|
||||
// * stored locally on the device, set to the given the model path.
|
||||
// *
|
||||
// * @discussion The external model file must be a single standalone TFLite file. It could be packed
|
||||
// * with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the
|
||||
// * necessary metadata and associated files might result in errors. Check the [documentation]
|
||||
// * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement.
|
||||
// *
|
||||
// * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
|
||||
// *
|
||||
// * @return An instance of `MPPTextClassifierOptions` initialized to the given model path.
|
||||
// */
|
||||
// - (instancetype)initWithModelPath:(NSString *)modelPath;
|
||||
|
||||
@end
|
||||
|
||||
|
|
|
@ -16,12 +16,12 @@
|
|||
|
||||
@implementation MPPTextClassifierOptions
|
||||
|
||||
- (instancetype)initWithModelPath:(NSString *)modelPath {
|
||||
self = [super initWithModelPath:modelPath];
|
||||
if (self) {
|
||||
_classifierOptions = [[MPPClassifierOptions alloc] init];
|
||||
}
|
||||
return self;
|
||||
}
|
||||
// - (instancetype)initWithModelPath:(NSString *)modelPath {
|
||||
// self = [super initWithModelPath:modelPath];
|
||||
// if (self) {
|
||||
// _classifierOptions = [[MPPClassifierOptions alloc] init];
|
||||
// }
|
||||
// return self;
|
||||
// }
|
||||
|
||||
@end
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/** Encapsulates list of predicted classes (aka labels) for a given image classifier head. */
|
||||
NS_SWIFT_NAME(TextClassifierResult)
|
||||
@interface MPPTextClassifierResult : MPPTaskResult
|
||||
|
||||
@property(nonatomic, readonly) MPPClassificationResult *classificationResult;
|
||||
|
||||
/**
|
||||
* Initializes a new `MPPClassificationResult` with the given array of classifications.
|
||||
*
|
||||
* @param classifications An Aaray of `MPPClassifications` objects containing classifier
|
||||
* predictions per classifier head.
|
||||
*
|
||||
* @return An instance of MPPClassificationResult initialized with the given array of
|
||||
* classifications.
|
||||
*/
|
||||
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
|
||||
timeStamp:(long)timeStamp;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
|
@ -0,0 +1,28 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h"
|
||||
|
||||
@implementation MPPTextClassifierResult
|
||||
|
||||
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
|
||||
timeStamp:(long)timeStamp {
|
||||
self = [super initWithTimestamp:timeStamp];
|
||||
if (self) {
|
||||
_classificationResult = classificationResult;
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
@end
|
|
@ -28,3 +28,13 @@ objc_library(
|
|||
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "MPPTextClassifierResultHelpers",
|
||||
srcs = ["sources/MPPTextClassifierResult+Helpers.mm"],
|
||||
hdrs = ["sources/MPPTextClassifierResult+Helpers.h"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierResult",
|
||||
"//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -1,17 +1,17 @@
|
|||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h"
|
||||
|
||||
|
|
|
@ -1,17 +1,17 @@
|
|||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h"
|
||||
#import "mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h"
|
||||
#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h"
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
@interface MPPTextClassifierResult (Helpers)
|
||||
|
||||
+ (MPPTextClassifierResult *)textClassifierResultWithProto:
|
||||
(const mediapipe::tasks::components::containers::proto::ClassificationResult &)
|
||||
classificationResultProto;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
|
@ -0,0 +1,39 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h"
|
||||
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
|
||||
|
||||
namespace {
|
||||
using ClassificationResultProto =
|
||||
::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
} // namespace
|
||||
|
||||
@implementation MPPTextClassifierResult (Helpers)
|
||||
|
||||
+ (MPPTextClassifierResult *)textClassifierResultWithProto:
|
||||
(const ClassificationResultProto &)classificationResultProto {
|
||||
long timeStamp;
|
||||
|
||||
if (classificationResultProto.has_timestamp_ms()) {
|
||||
timeStamp = classificationResultProto.timestamp_ms();
|
||||
}
|
||||
|
||||
MPPClassificationResult *classificationResult = [MPPClassificationResult classificationResultWithProto:classificationResultProto];
|
||||
|
||||
return [[MPPTextClassifierResult alloc] initWithClassificationResult:classificationResult
|
||||
timeStamp:timeStamp];
|
||||
}
|
||||
|
||||
@end
|
Loading…
Reference in New Issue
Block a user