Merge a0c482f720
into ec43bea176
This commit is contained in:
commit
1b27c12087
|
@ -83,6 +83,9 @@ strip_api_include_path_prefix(
|
||||||
"//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedder.h",
|
"//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedder.h",
|
||||||
"//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedderOptions.h",
|
"//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedderOptions.h",
|
||||||
"//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedderResult.h",
|
"//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedderResult.h",
|
||||||
|
"//mediapipe/tasks/ios/text/language_detector:sources/MPPLanguageDetector.h",
|
||||||
|
"//mediapipe/tasks/ios/text/language_detector:sources/MPPLanguageDetectorOptions.h",
|
||||||
|
"//mediapipe/tasks/ios/text/language_detector:sources/MPPLanguageDetectorResult.h",
|
||||||
"//mediapipe/tasks/ios/vision/core:sources/MPPRunningMode.h",
|
"//mediapipe/tasks/ios/vision/core:sources/MPPRunningMode.h",
|
||||||
"//mediapipe/tasks/ios/vision/core:sources/MPPImage.h",
|
"//mediapipe/tasks/ios/vision/core:sources/MPPImage.h",
|
||||||
"//mediapipe/tasks/ios/vision/core:sources/MPPMask.h",
|
"//mediapipe/tasks/ios/vision/core:sources/MPPMask.h",
|
||||||
|
@ -147,10 +150,14 @@ apple_static_xcframework(
|
||||||
":MPPTextEmbedder.h",
|
":MPPTextEmbedder.h",
|
||||||
":MPPTextEmbedderOptions.h",
|
":MPPTextEmbedderOptions.h",
|
||||||
":MPPTextEmbedderResult.h",
|
":MPPTextEmbedderResult.h",
|
||||||
|
":MPPLanguageDetector.h",
|
||||||
|
":MPPLanguageDetectorOptions.h",
|
||||||
|
":MPPLanguageDetectorResult.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier",
|
"//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier",
|
||||||
"//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedder",
|
"//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedder",
|
||||||
|
"//mediapipe/tasks/ios/text/language_detector:MPPLanguageDetector",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -116,6 +116,82 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
||||||
[self assertResultsOfDetectLanguageOfText:ruText
|
[self assertResultsOfDetectLanguageOfText:ruText
|
||||||
usingLanguageDetector:languageDetector
|
usingLanguageDetector:languageDetector
|
||||||
approximatelyEqualsExpectedLanguagePredictions:expectedRuLanguagePredictions];
|
approximatelyEqualsExpectedLanguagePredictions:expectedRuLanguagePredictions];
|
||||||
|
|
||||||
|
NSString *zhText = @"分久必合合久必分";
|
||||||
|
NSArray<MPPLanguagePrediction *> *expectedZhLanguagePredictions = @[
|
||||||
|
[[MPPLanguagePrediction alloc] initWithLanguageCode:@"zh" probability:0.505424f],
|
||||||
|
[[MPPLanguagePrediction alloc] initWithLanguageCode:@"ja" probability:0.481617f]
|
||||||
|
];
|
||||||
|
|
||||||
|
[self assertResultsOfDetectLanguageOfText:zhText
|
||||||
|
usingLanguageDetector:languageDetector
|
||||||
|
approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testClassifyWithMaxResultsSucceeds {
|
||||||
|
MPPLanguageDetectorOptions *options =
|
||||||
|
[self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo];
|
||||||
|
options.maxResults = 1;
|
||||||
|
MPPLanguageDetector *languageDetector = [self createLanguageDetectorWithOptionsSucceeds:options];
|
||||||
|
|
||||||
|
NSString *zhText = @"分久必合合久必分";
|
||||||
|
NSArray<MPPLanguagePrediction *> *expectedZhLanguagePredictions = @[
|
||||||
|
[[MPPLanguagePrediction alloc] initWithLanguageCode:@"zh" probability:0.505424f],
|
||||||
|
];
|
||||||
|
|
||||||
|
[self assertResultsOfDetectLanguageOfText:zhText
|
||||||
|
usingLanguageDetector:languageDetector
|
||||||
|
approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testClassifyWithScoreThresholdSucceeds {
|
||||||
|
MPPLanguageDetectorOptions *options =
|
||||||
|
[self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo];
|
||||||
|
options.scoreThreshold = 0.5f;
|
||||||
|
MPPLanguageDetector *languageDetector = [self createLanguageDetectorWithOptionsSucceeds:options];
|
||||||
|
|
||||||
|
NSString *zhText = @"分久必合合久必分";
|
||||||
|
NSArray<MPPLanguagePrediction *> *expectedZhLanguagePredictions = @[
|
||||||
|
[[MPPLanguagePrediction alloc] initWithLanguageCode:@"zh" probability:0.505424f],
|
||||||
|
];
|
||||||
|
|
||||||
|
[self assertResultsOfDetectLanguageOfText:zhText
|
||||||
|
usingLanguageDetector:languageDetector
|
||||||
|
approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testClassifyWithCategoryAllowListSucceeds {
|
||||||
|
MPPLanguageDetectorOptions *options =
|
||||||
|
[self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo];
|
||||||
|
options.categoryAllowlist = @[ @"zh" ];
|
||||||
|
|
||||||
|
MPPLanguageDetector *languageDetector = [self createLanguageDetectorWithOptionsSucceeds:options];
|
||||||
|
|
||||||
|
NSString *zhText = @"分久必合合久必分";
|
||||||
|
NSArray<MPPLanguagePrediction *> *expectedZhLanguagePredictions = @[
|
||||||
|
[[MPPLanguagePrediction alloc] initWithLanguageCode:@"zh" probability:0.505424f],
|
||||||
|
];
|
||||||
|
|
||||||
|
[self assertResultsOfDetectLanguageOfText:zhText
|
||||||
|
usingLanguageDetector:languageDetector
|
||||||
|
approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testClassifyWithCategoryDenyListSucceeds {
|
||||||
|
MPPLanguageDetectorOptions *options =
|
||||||
|
[self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo];
|
||||||
|
options.categoryDenylist = @[ @"zh" ];
|
||||||
|
|
||||||
|
MPPLanguageDetector *languageDetector = [self createLanguageDetectorWithOptionsSucceeds:options];
|
||||||
|
|
||||||
|
NSString *zhText = @"分久必合合久必分";
|
||||||
|
NSArray<MPPLanguagePrediction *> *expectedZhLanguagePredictions = @[
|
||||||
|
[[MPPLanguagePrediction alloc] initWithLanguageCode:@"ja" probability:0.481617f],
|
||||||
|
];
|
||||||
|
|
||||||
|
[self assertResultsOfDetectLanguageOfText:zhText
|
||||||
|
usingLanguageDetector:languageDetector
|
||||||
|
approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions];
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma mark Assert Segmenter Results
|
#pragma mark Assert Segmenter Results
|
||||||
|
@ -125,6 +201,8 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
||||||
(NSArray<MPPLanguagePrediction *> *)expectedLanguagePredictions {
|
(NSArray<MPPLanguagePrediction *> *)expectedLanguagePredictions {
|
||||||
MPPLanguageDetectorResult *result = [languageDetector detectText:text error:nil];
|
MPPLanguageDetectorResult *result = [languageDetector detectText:text error:nil];
|
||||||
XCTAssertNotNil(result);
|
XCTAssertNotNil(result);
|
||||||
|
|
||||||
|
XCTAssertEqual(result.languagePredictions.count, expectedLanguagePredictions.count);
|
||||||
XCTAssertEqualWithAccuracy(result.languagePredictions[0].probability,
|
XCTAssertEqualWithAccuracy(result.languagePredictions[0].probability,
|
||||||
expectedLanguagePredictions[0].probability, 1e-3);
|
expectedLanguagePredictions[0].probability, 1e-3);
|
||||||
XCTAssertEqualObjects(result.languagePredictions[0].languageCode,
|
XCTAssertEqualObjects(result.languagePredictions[0].languageCode,
|
||||||
|
|
|
@ -20,7 +20,7 @@
|
||||||
self = [super init];
|
self = [super init];
|
||||||
if (self) {
|
if (self) {
|
||||||
_maxResults = -1;
|
_maxResults = -1;
|
||||||
_scoreThreshold = 0;
|
_scoreThreshold = -1.0f;
|
||||||
}
|
}
|
||||||
return self;
|
return self;
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,7 +42,10 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto
|
||||||
}
|
}
|
||||||
|
|
||||||
classifierOptionsProto->set_max_results((int)self.maxResults);
|
classifierOptionsProto->set_max_results((int)self.maxResults);
|
||||||
|
|
||||||
|
if (self.scoreThreshold >= 0) {
|
||||||
classifierOptionsProto->set_score_threshold(self.scoreThreshold);
|
classifierOptionsProto->set_score_threshold(self.scoreThreshold);
|
||||||
|
}
|
||||||
|
|
||||||
for (NSString *category in self.categoryAllowlist) {
|
for (NSString *category in self.categoryAllowlist) {
|
||||||
classifierOptionsProto->add_category_allowlist(category.cppString);
|
classifierOptionsProto->add_category_allowlist(category.cppString);
|
||||||
|
|
Loading…
Reference in New Issue
Block a user