Merge pull request #4941 from priankakariatyml:ios-language-detector-tests

PiperOrigin-RevId: 580577290
This commit is contained in:
Copybara-Service 2023-11-08 10:34:06 -08:00
commit 8d4407b04e
3 changed files with 220 additions and 2 deletions

View File

@ -0,0 +1,57 @@
load(
"//mediapipe/framework/tool:ios.bzl",
"MPP_TASK_MINIMUM_OS_VERSION",
)
load(
"@build_bazel_rules_apple//apple:ios.bzl",
"ios_unit_test",
)
load(
"@org_tensorflow//tensorflow/lite:special_rules.bzl",
"tflite_ios_lab_runner",
)
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
TFL_DEFAULT_TAGS = [
"apple",
]
# Following sanitizer tests are not supported by iOS test targets.
TFL_DISABLED_SANITIZER_TAGS = [
"noasan",
"nomsan",
"notsan",
]
objc_library(
name = "MPPLanguageDetectorObjcTestLibrary",
testonly = 1,
srcs = ["MPPLanguageDetectorTests.mm"],
copts = [
"-ObjC++",
"-std=c++17",
"-x objective-c++",
],
data = [
"//mediapipe/tasks/testdata/text:language_detector",
],
deps = [
"//mediapipe/tasks/ios/common:MPPCommon",
"//mediapipe/tasks/ios/test/utils:MPPFileInfo",
"//mediapipe/tasks/ios/text/language_detector:MPPLanguageDetector",
],
)
ios_unit_test(
name = "MPPLanguageDetectorObjcTest",
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
deps = [
":MPPLanguageDetectorObjcTestLibrary",
],
)

View File

@ -0,0 +1,162 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <XCTest/XCTest.h>
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
#import "mediapipe/tasks/ios/test/utils/sources/MPPFileInfo.h"
#import "mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetector.h"
static MPPFileInfo *const kLanguageDetectorModelFileInfo =
[[MPPFileInfo alloc] initWithName:@"language_detector" type:@"tflite"];
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
#define AssertEqualErrors(error, expectedError) \
XCTAssertNotNil(error); \
XCTAssertEqualObjects(error.domain, expectedError.domain); \
XCTAssertEqual(error.code, expectedError.code); \
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
@interface MPPLanguageDetectorTests : XCTestCase
@end
@implementation MPPLanguageDetectorTests
- (void)testCreateLanguageDetectorFailsWithMissingModelPath {
MPPFileInfo *fileInfo = [[MPPFileInfo alloc] initWithName:@"" type:@""];
NSError *error = nil;
MPPLanguageDetector *languageDetector =
[[MPPLanguageDetector alloc] initWithModelPath:fileInfo.path error:&error];
XCTAssertNil(languageDetector);
NSError *expectedError = [NSError
errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey :
@"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', "
@"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."
}];
AssertEqualErrors(error, expectedError);
}
- (void)testCreateLanguageDetectorFailsWithBothAllowlistAndDenylist {
MPPLanguageDetectorOptions *options =
[self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo];
options.categoryAllowlist = @[ @"en" ];
options.categoryDenylist = @[ @"en" ];
[self assertCreateLanguageDetectorWithOptions:options
failsWithExpectedError:
[NSError
errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey :
@"INVALID_ARGUMENT: `category_allowlist` and "
@"`category_denylist` are mutually exclusive options."
}]];
}
- (void)testCreateLanguageDetectorFailsWithInvalidMaxResults {
MPPLanguageDetectorOptions *options =
[self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo];
options.maxResults = 0;
[self
assertCreateLanguageDetectorWithOptions:options
failsWithExpectedError:
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey :
@"INVALID_ARGUMENT: Invalid `max_results` option: "
@"value must be != 0."
}]];
}
- (void)testClassifyWithL2CModelSucceeds {
MPPLanguageDetectorOptions *options =
[self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo];
MPPLanguageDetector *languageDetector = [self createLanguageDetectorWithOptionsSucceeds:options];
NSString *enText = @"To be, or not to be, that is the question";
NSArray<MPPLanguagePrediction *> *expectedEnLanguagePredictions =
@[ [[MPPLanguagePrediction alloc] initWithLanguageCode:@"en" probability:0.9998559f] ];
[self assertResultsOfDetectLanguageOfText:enText
usingLanguageDetector:languageDetector
approximatelyEqualsExpectedLanguagePredictions:expectedEnLanguagePredictions];
NSString *frText = @"Il y a beaucoup de bouches qui parlent et fort peu de têtes qui pensent.";
NSArray<MPPLanguagePrediction *> *expectedFrLanguagePredictions =
@[ [[MPPLanguagePrediction alloc] initWithLanguageCode:@"fr" probability:0.9997813f] ];
[self assertResultsOfDetectLanguageOfText:frText
usingLanguageDetector:languageDetector
approximatelyEqualsExpectedLanguagePredictions:expectedFrLanguagePredictions];
NSString *ruText = @"это какой-то английский язык";
NSArray<MPPLanguagePrediction *> *expectedRuLanguagePredictions =
@[ [[MPPLanguagePrediction alloc] initWithLanguageCode:@"ru" probability:0.9933616f] ];
[self assertResultsOfDetectLanguageOfText:ruText
usingLanguageDetector:languageDetector
approximatelyEqualsExpectedLanguagePredictions:expectedRuLanguagePredictions];
}
#pragma mark Assert Segmenter Results
- (void)assertResultsOfDetectLanguageOfText:(NSString *)text
usingLanguageDetector:(MPPLanguageDetector *)languageDetector
approximatelyEqualsExpectedLanguagePredictions:
(NSArray<MPPLanguagePrediction *> *)expectedLanguagePredictions {
MPPLanguageDetectorResult *result = [languageDetector detectText:text error:nil];
XCTAssertNotNil(result);
XCTAssertEqualWithAccuracy(result.languagePredictions[0].probability,
expectedLanguagePredictions[0].probability, 1e-3);
XCTAssertEqualObjects(result.languagePredictions[0].languageCode,
expectedLanguagePredictions[0].languageCode);
}
#pragma mark Language Detector Initializers
- (MPPLanguageDetectorOptions *)languageDetectorOptionsWithModelFileInfo:(MPPFileInfo *)fileInfo {
MPPLanguageDetectorOptions *options = [[MPPLanguageDetectorOptions alloc] init];
options.baseOptions.modelAssetPath = fileInfo.path;
return options;
}
- (MPPLanguageDetector *)createLanguageDetectorWithOptionsSucceeds:
(MPPLanguageDetectorOptions *)options {
NSError *error;
MPPLanguageDetector *languageDetector = [[MPPLanguageDetector alloc] initWithOptions:options
error:&error];
XCTAssertNotNil(languageDetector);
XCTAssertNil(error);
return languageDetector;
}
- (void)assertCreateLanguageDetectorWithOptions:(MPPLanguageDetectorOptions *)options
failsWithExpectedError:(NSError *)expectedError {
NSError *error = nil;
MPPLanguageDetector *languageDetector = [[MPPLanguageDetector alloc] initWithOptions:options
error:&error];
XCTAssertNil(languageDetector);
AssertEqualErrors(error, expectedError);
}
@end

View File

@ -31,8 +31,7 @@ static NSString *const kClassificationsStreamName = @"classifications_out";
static NSString *const kClassificationsTag = @"CLASSIFICATIONS"; static NSString *const kClassificationsTag = @"CLASSIFICATIONS";
static NSString *const kTextInStreamName = @"text_in"; static NSString *const kTextInStreamName = @"text_in";
static NSString *const kTextTag = @"TEXT"; static NSString *const kTextTag = @"TEXT";
static NSString *const kTaskGraphName = static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.TextClassifierGraph";
@"mediapipe.tasks.text.language_detector.LanguageDetectorGraph";
@interface MPPLanguageDetector () { @interface MPPLanguageDetector () {
/** iOS Text Task Runner */ /** iOS Text Task Runner */