Fixed issue with iOS Language Detector Prediction Count

This commit is contained in:
Prianka Liz Kariat 2023-12-14 09:21:40 +05:30
parent 47e217896c
commit a7f3321dbb
3 changed files with 16 additions and 5 deletions

View File

@ -113,9 +113,15 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
NSArray<MPPLanguagePrediction *> *expectedRuLanguagePredictions =
@[ [[MPPLanguagePrediction alloc] initWithLanguageCode:@"ru" probability:0.9933616f] ];
[self assertResultsOfDetectLanguageOfText:ruText
NSString *zhText = @"分久必合合久必分";
NSArray<MPPLanguagePrediction *> *expectedZhLanguagePredictions = @[
[[MPPLanguagePrediction alloc] initWithLanguageCode:@"zh" probability:0.505424f],
[[MPPLanguagePrediction alloc] initWithLanguageCode:@"ja" probability:0.481617f]
];
[self assertResultsOfDetectLanguageOfText:zhText
usingLanguageDetector:languageDetector
approximatelyEqualsExpectedLanguagePredictions:expectedRuLanguagePredictions];
approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions];
}
#pragma mark Assert Segmenter Results
@ -125,6 +131,8 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
(NSArray<MPPLanguagePrediction *> *)expectedLanguagePredictions {
MPPLanguageDetectorResult *result = [languageDetector detectText:text error:nil];
XCTAssertNotNil(result);
XCTAssertEqual(result.languagePredictions.count, expectedLanguagePredictions.count);
XCTAssertEqualWithAccuracy(result.languagePredictions[0].probability,
expectedLanguagePredictions[0].probability, 1e-3);
XCTAssertEqualObjects(result.languagePredictions[0].languageCode,

View File

@ -20,7 +20,7 @@
self = [super init];
if (self) {
_maxResults = -1;
_scoreThreshold = 0;
_scoreThreshold = -1.0f;
}
return self;
}

View File

@ -40,9 +40,12 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto
if (self.displayNamesLocale) {
classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString);
}
classifierOptionsProto->set_max_results((int)self.maxResults);
classifierOptionsProto->set_score_threshold(self.scoreThreshold);
if (self.scoreThreshold >= 0) {
classifierOptionsProto->set_score_threshold(self.scoreThreshold);
}
for (NSString *category in self.categoryAllowlist) {
classifierOptionsProto->add_category_allowlist(category.cppString);