Updated implementation of text classifier

This commit is contained in:
Prianka Liz Kariat 2022-12-23 17:52:00 +05:30
parent 7f7776ef80
commit 7e0fec7c28
19 changed files with 264 additions and 189 deletions

View File

@ -28,6 +28,5 @@ objc_library(
hdrs = ["sources/MPPClassificationResult.h"],
deps = [
":MPPCategory",
"//mediapipe/tasks/ios/core:MPPTaskResult",
],
)

View File

@ -14,7 +14,6 @@
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h"
NS_ASSUME_NONNULL_BEGIN
@ -72,7 +71,7 @@ NS_SWIFT_NAME(Classifications)
/** Encapsulates results of any classification task. */
NS_SWIFT_NAME(ClassificationResult)
@interface MPPClassificationResult : MPPTaskResult
@interface MPPClassificationResult : NSObject
/** Array of MPPClassifications objects containing classifier predictions per image classifier
* head.
@ -88,8 +87,7 @@ NS_SWIFT_NAME(ClassificationResult)
* @return An instance of MPPClassificationResult initialized with the given array of
* classifications.
*/
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
timeStamp:(long)timeStamp;
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications;
@end

View File

@ -39,9 +39,9 @@
NSArray<MPPClassifications *> *_classifications;
}
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
timeStamp:(long)timeStamp {
self = [super initWithTimeStamp:timeStamp];
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications {
self = [super init];
if (self) {
_classifications = classifications;
}

View File

@ -53,14 +53,7 @@ using ClassificationResultProto =
[classifications addObject:[MPPClassifications classificationsWithProto:classifications_proto]];
}
long timeStamp;
if (classificationResultProto.has_timestamp_ms()) {
timeStamp = classificationResultProto.timestamp_ms();
}
return [[MPPClassificationResult alloc] initWithClassifications:classifications
timeStamp:timeStamp];
return [[MPPClassificationResult alloc] initWithClassifications:classifications];
}
@end

View File

@ -1,47 +0,0 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
NS_ASSUME_NONNULL_BEGIN
/**
* The base class of the user-facing iOS mediapipe task api classes.
*/
@interface MPPTaskManager : NSObject
/**
* Initializes a new `MPPTaskManager` with the mediapipe task graph config proto.
*
* @param graphConfig A mediapipe task graph config proto.
*
* @return An instance of `MPPTaskManager` initialized to the given graph config proto.
*/
- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
error:(NSError **)error;
- (absl::StatusOr<mediapipe::tasks::core::PacketMap>)process:(const mediapipe::tasks::core::PacketMap&)packetMap error:(NSError **)error;
- (void)close;
- (instancetype)init NS_UNAVAILABLE;
+ (instancetype)new NS_UNAVAILABLE;
@end
NS_ASSUME_NONNULL_END

View File

@ -1,56 +0,0 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/core/sources/MPPTaskManager.h"
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
namespace {
using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::Packet;
using ::mediapipe::tasks::core::PacketMap;
using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
} // namespace
@interface MPPTaskManager () {
/** TextSearcher backed by C++ API */
std::unique_ptr<TaskRunnerCpp> _cppTaskRunner;
}
@end
@implementation MPPTaskManager
- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig
error:(NSError **)error {
self = [super init];
if (self) {
auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig));
if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) {
return nil;
}
_cppTaskRunner = std::move(taskRunnerResult.value());
}
return self;
}
- (absl::StatusOr<PacketMap>)process:(const PacketMap&)packetMap {
return _cppTaskRunner->Process(packetMap);
}
- (void)close {
_cppTaskRunner->Close();
}
@end

View File

@ -19,6 +19,28 @@
NS_ASSUME_NONNULL_BEGIN
static NSString *const kBertTextClassifierModelName = @"bert_text_classifier";
static NSString *const kNegativeText = @"unflinchingly bleak and desperate";
#define VerifyCategory(category, expectedIndex, expectedScore, expectedLabel, expectedDisplayName) \
XCTAssertEqual(category.index, expectedIndex); \
XCTAssertEqualWithAccuracy(category.score, expectedScore, 1e-6); \
XCTAssertEqualObjects(category.label, expectedLabel); \
XCTAssertEqualObjects(category.displayName, expectedDisplayName);
#define VerifyClassifications(classifications, expectedHeadIndex, expectedCategoryCount) \
XCTAssertEqual(classifications.categories.count, expectedCategoryCount);
#define VerifyClassificationResult(classificationResult, expectedClassificationsCount) \
XCTAssertNotNil(classificationResult); \
XCTAssertEqual(classificationResult.classifications.count, expectedClassificationsCount)
#define AssertClassificationResultHasOneHead(classificationResult) \
XCTAssertNotNil(classificationResult); \
XCTAssertEqual(classificationResult.classifications.count, 1);
XCTAssertEqual(classificationResult.classifications[0].headIndex, 1);
#define AssertTextClassifierResultIsNotNil(textClassifierResult) \
XCTAssertNotNil(textClassifierResult);
@interface MPPTextClassifierTests : XCTestCase
@end
@ -41,15 +63,25 @@ static NSString *const kBertTextClassifierModelName = @"bert_text_classifier";
- (MPPTextClassifierOptions *)textClassifierOptionsWithModelName:(NSString *)modelName {
NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
MPPTextClassifierOptions *textClassifierOptions =
[[MPPTextClassifierOptions alloc] initWithModelPath:modelPath];
[[MPPTextClassifierOptions alloc] init];
textClassifierOptions.baseOptions.modelAssetPath = modelPath;
return textClassifierOptions;
}
- (void)testCreateTextClassifierOptionsSucceeds {
MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:kBertTextClassifierModelName];
kBertTextClassifierModelName
- (MPPTextClassifier *)createTextClassifierFromOptionsWithModelName:(NSString *)modelName {
MPPTextClassifierOptions *options = [self textClassifierOptionsWithModelName:modelName];
MPPTextClassifier *textClassifier = [[MPPTextClassifier alloc] initWithOptions:options error:nil];
XCTAssertNotNil(textClassifier);
return textClassifier
}
- (void)classifyWithBertSucceeds {
MPPTextClassifier *textClassifier = [self createTextClassifierWithModelName:kBertTextClassifierModelName];
MPPTextClassifierResult *textClassifierResult = [textClassifier classifyWithText:kNegativeText];
}
@end

View File

@ -27,10 +27,10 @@ objc_library(
deps = [
"//mediapipe/tasks/ios/core:MPPTaskOptions",
"//mediapipe/tasks/ios/core:MPPTaskInfo",
"//mediapipe/tasks/ios/core:MPPTaskManager",
"//mediapipe/tasks/ios/core:MPPPacketCreator",
"//mediapipe/tasks/ios/core:MPPTaskRunner",
"//mediapipe/tasks/ios/core:MPPTextPacketCreator",
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers",
"//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers",
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
":MPPTextClassifierOptions",
@ -51,3 +51,12 @@ objc_library(
],
)
objc_library(
name = "MPPTextClassifierResult",
srcs = ["sources/MPPTextClassifierResult.m"],
hdrs = ["sources/MPPTextClassifierResult.h"],
deps = [
"//mediapipe/tasks/ios/core:MPPTaskResult",
"//mediapipe/tasks/ios/components/containers:MPPClassificationResult",
],
)

View File

@ -14,7 +14,7 @@
==============================================================================*/
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h"
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h"
@ -52,7 +52,7 @@ NS_SWIFT_NAME(TextClassifier)
*/
- (instancetype)initWithOptions:(MPPTextClassifierOptions *)options error:(NSError **)error;
- (nullable MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error;
- (nullable MPPTextClassifierResult *)classifyWithText:(NSString *)text error:(NSError **)error;
- (instancetype)init NS_UNAVAILABLE;

View File

@ -15,9 +15,9 @@
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.h"
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
#import "mediapipe/tasks/ios/core/sources/MPPPacketCreator.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskManager.h"
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h"
#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h"
@ -30,14 +30,14 @@ using ::mediapipe::tasks::core::PacketMap;
} // namespace
static NSString *const kClassificationsStreamName = @"classifications_out";
static NSString *const kClassificationsTag = @"classifications";
static NSString *const kClassificationsTag = @"CLASSIFICATIONS";
static NSString *const kTextInStreamName = @"text_in";
static NSString *const kTextTag = @"TEXT";
static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.TextClassifierGraph";
@interface MPPTextClassifier () {
/** TextSearcher backed by C++ API */
MPPTaskManager *_taskManager;
MPPTaskRunner *_taskRunner;
}
@end
@ -47,8 +47,8 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T
MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc]
initWithTaskGraphName:kTaskGraphName
inputStreams:@[ [NSString stringWithFormat:@"@:@", kTextTag, kTextInStreamName] ]
outputStreams:@[ [NSString stringWithFormat:@"@:@", kClassificationsTag,
inputStreams:@[ [NSString stringWithFormat:@"%@:%@", kTextTag, kTextInStreamName] ]
outputStreams:@[ [NSString stringWithFormat:@"%@:%@", kClassificationsTag,
kClassificationsStreamName] ]
taskOptions:options
enableFlowLimiting:NO
@ -58,7 +58,7 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T
return nil;
}
_taskManager = [[MPPTaskManager alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error];
_taskRunner = [[MPPTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] error:error];
self = [super init];
@ -66,22 +66,23 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T
}
- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error {
MPPTextClassifierOptions *options =
[[MPPTextClassifierOptions alloc] initWithModelPath:modelPath];
MPPTextClassifierOptions *options = [[MPPTextClassifierOptions alloc] init];
options.baseOptions.modelAssetPath = modelPath;
return [self initWithOptions:options error:error];
}
- (nullable MPPClassificationResult *)classifyWithText:(NSString *)text error:(NSError **)error {
Packet packet = [MPPPacketCreator createWithText:text];
- (nullable MPPTextClassifierResult *)classifyWithText:(NSString *)text error:(NSError **)error {
Packet packet = [MPPTextPacketCreator createWithText:text];
absl::StatusOr<PacketMap> output_packet_map = [_taskManager process:{{kTextInStreamName.cppString, packet}} error:error];
absl::StatusOr<PacketMap> output_packet_map = [_taskRunner process:{{kTextInStreamName.cppString, packet}} error:error];
if (![MPPCommonUtils checkCppError:output_packet_map.status() toError:error]) {
return nil;
}
return [MPPClassificationResult
classificationResultWithProto:output_packet_map.value()[kClassificationsStreamName.cppString]
return [MPPTextClassifierResult
textClassifierResultWithProto:output_packet_map.value()[kClassificationsStreamName.cppString]
.Get<ClassificationResult>()];
}

View File

@ -31,20 +31,20 @@ NS_SWIFT_NAME(TextClassifierOptions)
*/
@property(nonatomic, copy) MPPClassifierOptions *classifierOptions;
/**
* Initializes a new `MPPTextClassifierOptions` with the absolute path to the model file
* stored locally on the device, set to the given the model path.
*
* @discussion The external model file must be a single standalone TFLite file. It could be packed
* with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the
* necessary metadata and associated files might result in errors. Check the [documentation]
* (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement.
*
* @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
*
* @return An instance of `MPPTextClassifierOptions` initialized to the given model path.
*/
- (instancetype)initWithModelPath:(NSString *)modelPath;
// /**
// * Initializes a new `MPPTextClassifierOptions` with the absolute path to the model file
// * stored locally on the device, set to the given the model path.
// *
// * @discussion The external model file must be a single standalone TFLite file. It could be packed
// * with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the
// * necessary metadata and associated files might result in errors. Check the [documentation]
// * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement.
// *
// * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
// *
// * @return An instance of `MPPTextClassifierOptions` initialized to the given model path.
// */
// - (instancetype)initWithModelPath:(NSString *)modelPath;
@end

View File

@ -16,12 +16,12 @@
@implementation MPPTextClassifierOptions
- (instancetype)initWithModelPath:(NSString *)modelPath {
self = [super initWithModelPath:modelPath];
if (self) {
_classifierOptions = [[MPPClassifierOptions alloc] init];
}
return self;
}
// - (instancetype)initWithModelPath:(NSString *)modelPath {
// self = [super initWithModelPath:modelPath];
// if (self) {
// _classifierOptions = [[MPPClassifierOptions alloc] init];
// }
// return self;
// }
@end

View File

@ -0,0 +1,41 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h"
NS_ASSUME_NONNULL_BEGIN
/** Encapsulates list of predicted classes (aka labels) for a given image classifier head. */
NS_SWIFT_NAME(TextClassifierResult)
@interface MPPTextClassifierResult : MPPTaskResult
@property(nonatomic, readonly) MPPClassificationResult *classificationResult;
/**
* Initializes a new `MPPClassificationResult` with the given array of classifications.
*
* @param classifications An Aaray of `MPPClassifications` objects containing classifier
* predictions per classifier head.
*
* @return An instance of MPPClassificationResult initialized with the given array of
* classifications.
*/
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
timeStamp:(long)timeStamp;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,28 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h"
@implementation MPPTextClassifierResult
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
timeStamp:(long)timeStamp {
self = [super initWithTimestamp:timeStamp];
if (self) {
_classificationResult = classificationResult;
}
return self;
}
@end

View File

@ -28,3 +28,13 @@ objc_library(
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto",
],
)
objc_library(
name = "MPPTextClassifierResultHelpers",
srcs = ["sources/MPPTextClassifierResult+Helpers.mm"],
hdrs = ["sources/MPPTextClassifierResult+Helpers.h"],
deps = [
"//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifierResult",
"//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers",
],
)

View File

@ -1,17 +1,17 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierOptions.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h"

View File

@ -1,17 +1,17 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h"
#import "mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.h"
#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h"

View File

@ -0,0 +1,28 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#import "mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h"
NS_ASSUME_NONNULL_BEGIN
@interface MPPTextClassifierResult (Helpers)
+ (MPPTextClassifierResult *)textClassifierResultWithProto:
(const mediapipe::tasks::components::containers::proto::ClassificationResult &)
classificationResultProto;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,39 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h"
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
namespace {
using ClassificationResultProto =
::mediapipe::tasks::components::containers::proto::ClassificationResult;
} // namespace
@implementation MPPTextClassifierResult (Helpers)
+ (MPPTextClassifierResult *)textClassifierResultWithProto:
(const ClassificationResultProto &)classificationResultProto {
long timeStamp;
if (classificationResultProto.has_timestamp_ms()) {
timeStamp = classificationResultProto.timestamp_ms();
}
MPPClassificationResult *classificationResult = [MPPClassificationResult classificationResultWithProto:classificationResultProto];
return [[MPPTextClassifierResult alloc] initWithClassificationResult:classificationResult
timeStamp:timeStamp];
}
@end