Fixed issue with iOS Language Detector Prediction Count

This commit is contained in:
Prianka Liz Kariat 2023-12-14 09:21:40 +05:30
parent 47e217896c
commit a7f3321dbb
3 changed files with 16 additions and 5 deletions

View File

@ -113,9 +113,15 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
NSArray<MPPLanguagePrediction *> *expectedRuLanguagePredictions = NSArray<MPPLanguagePrediction *> *expectedRuLanguagePredictions =
@[ [[MPPLanguagePrediction alloc] initWithLanguageCode:@"ru" probability:0.9933616f] ]; @[ [[MPPLanguagePrediction alloc] initWithLanguageCode:@"ru" probability:0.9933616f] ];
[self assertResultsOfDetectLanguageOfText:ruText NSString *zhText = @"分久必合合久必分";
NSArray<MPPLanguagePrediction *> *expectedZhLanguagePredictions = @[
[[MPPLanguagePrediction alloc] initWithLanguageCode:@"zh" probability:0.505424f],
[[MPPLanguagePrediction alloc] initWithLanguageCode:@"ja" probability:0.481617f]
];
[self assertResultsOfDetectLanguageOfText:zhText
usingLanguageDetector:languageDetector usingLanguageDetector:languageDetector
approximatelyEqualsExpectedLanguagePredictions:expectedRuLanguagePredictions]; approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions];
} }
#pragma mark Assert Segmenter Results #pragma mark Assert Segmenter Results
@ -125,6 +131,8 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
(NSArray<MPPLanguagePrediction *> *)expectedLanguagePredictions { (NSArray<MPPLanguagePrediction *> *)expectedLanguagePredictions {
MPPLanguageDetectorResult *result = [languageDetector detectText:text error:nil]; MPPLanguageDetectorResult *result = [languageDetector detectText:text error:nil];
XCTAssertNotNil(result); XCTAssertNotNil(result);
XCTAssertEqual(result.languagePredictions.count, expectedLanguagePredictions.count);
XCTAssertEqualWithAccuracy(result.languagePredictions[0].probability, XCTAssertEqualWithAccuracy(result.languagePredictions[0].probability,
expectedLanguagePredictions[0].probability, 1e-3); expectedLanguagePredictions[0].probability, 1e-3);
XCTAssertEqualObjects(result.languagePredictions[0].languageCode, XCTAssertEqualObjects(result.languagePredictions[0].languageCode,

View File

@ -20,7 +20,7 @@
self = [super init]; self = [super init];
if (self) { if (self) {
_maxResults = -1; _maxResults = -1;
_scoreThreshold = 0; _scoreThreshold = -1.0f;
} }
return self; return self;
} }

View File

@ -40,9 +40,12 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto
if (self.displayNamesLocale) { if (self.displayNamesLocale) {
classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString); classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString);
} }
classifierOptionsProto->set_max_results((int)self.maxResults); classifierOptionsProto->set_max_results((int)self.maxResults);
classifierOptionsProto->set_score_threshold(self.scoreThreshold);
if (self.scoreThreshold >= 0) {
classifierOptionsProto->set_score_threshold(self.scoreThreshold);
}
for (NSString *category in self.categoryAllowlist) { for (NSString *category in self.categoryAllowlist) {
classifierOptionsProto->add_category_allowlist(category.cppString); classifierOptionsProto->add_category_allowlist(category.cppString);