This commit is contained in:
priankakariatyml 2024-01-02 10:36:34 +01:00 committed by GitHub
commit 1b27c12087
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 91 additions and 3 deletions

View File

@ -83,6 +83,9 @@ strip_api_include_path_prefix(
"//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedder.h", "//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedder.h",
"//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedderOptions.h", "//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedderOptions.h",
"//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedderResult.h", "//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedderResult.h",
"//mediapipe/tasks/ios/text/language_detector:sources/MPPLanguageDetector.h",
"//mediapipe/tasks/ios/text/language_detector:sources/MPPLanguageDetectorOptions.h",
"//mediapipe/tasks/ios/text/language_detector:sources/MPPLanguageDetectorResult.h",
"//mediapipe/tasks/ios/vision/core:sources/MPPRunningMode.h", "//mediapipe/tasks/ios/vision/core:sources/MPPRunningMode.h",
"//mediapipe/tasks/ios/vision/core:sources/MPPImage.h", "//mediapipe/tasks/ios/vision/core:sources/MPPImage.h",
"//mediapipe/tasks/ios/vision/core:sources/MPPMask.h", "//mediapipe/tasks/ios/vision/core:sources/MPPMask.h",
@ -147,10 +150,14 @@ apple_static_xcframework(
":MPPTextEmbedder.h", ":MPPTextEmbedder.h",
":MPPTextEmbedderOptions.h", ":MPPTextEmbedderOptions.h",
":MPPTextEmbedderResult.h", ":MPPTextEmbedderResult.h",
":MPPLanguageDetector.h",
":MPPLanguageDetectorOptions.h",
":MPPLanguageDetectorResult.h",
], ],
deps = [ deps = [
"//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier", "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier",
"//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedder", "//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedder",
"//mediapipe/tasks/ios/text/language_detector:MPPLanguageDetector",
], ],
) )

View File

@ -116,6 +116,82 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
[self assertResultsOfDetectLanguageOfText:ruText [self assertResultsOfDetectLanguageOfText:ruText
usingLanguageDetector:languageDetector usingLanguageDetector:languageDetector
approximatelyEqualsExpectedLanguagePredictions:expectedRuLanguagePredictions]; approximatelyEqualsExpectedLanguagePredictions:expectedRuLanguagePredictions];
NSString *zhText = @"分久必合合久必分";
NSArray<MPPLanguagePrediction *> *expectedZhLanguagePredictions = @[
[[MPPLanguagePrediction alloc] initWithLanguageCode:@"zh" probability:0.505424f],
[[MPPLanguagePrediction alloc] initWithLanguageCode:@"ja" probability:0.481617f]
];
[self assertResultsOfDetectLanguageOfText:zhText
usingLanguageDetector:languageDetector
approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions];
}
- (void)testClassifyWithMaxResultsSucceeds {
MPPLanguageDetectorOptions *options =
[self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo];
options.maxResults = 1;
MPPLanguageDetector *languageDetector = [self createLanguageDetectorWithOptionsSucceeds:options];
NSString *zhText = @"分久必合合久必分";
NSArray<MPPLanguagePrediction *> *expectedZhLanguagePredictions = @[
[[MPPLanguagePrediction alloc] initWithLanguageCode:@"zh" probability:0.505424f],
];
[self assertResultsOfDetectLanguageOfText:zhText
usingLanguageDetector:languageDetector
approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions];
}
- (void)testClassifyWithScoreThresholdSucceeds {
MPPLanguageDetectorOptions *options =
[self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo];
options.scoreThreshold = 0.5f;
MPPLanguageDetector *languageDetector = [self createLanguageDetectorWithOptionsSucceeds:options];
NSString *zhText = @"分久必合合久必分";
NSArray<MPPLanguagePrediction *> *expectedZhLanguagePredictions = @[
[[MPPLanguagePrediction alloc] initWithLanguageCode:@"zh" probability:0.505424f],
];
[self assertResultsOfDetectLanguageOfText:zhText
usingLanguageDetector:languageDetector
approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions];
}
- (void)testClassifyWithCategoryAllowListSucceeds {
MPPLanguageDetectorOptions *options =
[self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo];
options.categoryAllowlist = @[ @"zh" ];
MPPLanguageDetector *languageDetector = [self createLanguageDetectorWithOptionsSucceeds:options];
NSString *zhText = @"分久必合合久必分";
NSArray<MPPLanguagePrediction *> *expectedZhLanguagePredictions = @[
[[MPPLanguagePrediction alloc] initWithLanguageCode:@"zh" probability:0.505424f],
];
[self assertResultsOfDetectLanguageOfText:zhText
usingLanguageDetector:languageDetector
approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions];
}
- (void)testClassifyWithCategoryDenyListSucceeds {
MPPLanguageDetectorOptions *options =
[self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo];
options.categoryDenylist = @[ @"zh" ];
MPPLanguageDetector *languageDetector = [self createLanguageDetectorWithOptionsSucceeds:options];
NSString *zhText = @"分久必合合久必分";
NSArray<MPPLanguagePrediction *> *expectedZhLanguagePredictions = @[
[[MPPLanguagePrediction alloc] initWithLanguageCode:@"ja" probability:0.481617f],
];
[self assertResultsOfDetectLanguageOfText:zhText
usingLanguageDetector:languageDetector
approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions];
} }
#pragma mark Assert Segmenter Results #pragma mark Assert Segmenter Results
@ -125,6 +201,8 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
(NSArray<MPPLanguagePrediction *> *)expectedLanguagePredictions { (NSArray<MPPLanguagePrediction *> *)expectedLanguagePredictions {
MPPLanguageDetectorResult *result = [languageDetector detectText:text error:nil]; MPPLanguageDetectorResult *result = [languageDetector detectText:text error:nil];
XCTAssertNotNil(result); XCTAssertNotNil(result);
XCTAssertEqual(result.languagePredictions.count, expectedLanguagePredictions.count);
XCTAssertEqualWithAccuracy(result.languagePredictions[0].probability, XCTAssertEqualWithAccuracy(result.languagePredictions[0].probability,
expectedLanguagePredictions[0].probability, 1e-3); expectedLanguagePredictions[0].probability, 1e-3);
XCTAssertEqualObjects(result.languagePredictions[0].languageCode, XCTAssertEqualObjects(result.languagePredictions[0].languageCode,

View File

@ -20,7 +20,7 @@
self = [super init]; self = [super init];
if (self) { if (self) {
_maxResults = -1; _maxResults = -1;
_scoreThreshold = 0; _scoreThreshold = -1.0f;
} }
return self; return self;
} }

View File

@ -42,7 +42,10 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto
} }
classifierOptionsProto->set_max_results((int)self.maxResults); classifierOptionsProto->set_max_results((int)self.maxResults);
if (self.scoreThreshold >= 0) {
classifierOptionsProto->set_score_threshold(self.scoreThreshold); classifierOptionsProto->set_score_threshold(self.scoreThreshold);
}
for (NSString *category in self.categoryAllowlist) { for (NSString *category in self.categoryAllowlist) {
classifierOptionsProto->add_category_allowlist(category.cppString); classifierOptionsProto->add_category_allowlist(category.cppString);