diff --git a/mediapipe/tasks/ios/BUILD b/mediapipe/tasks/ios/BUILD index 009431eb1..e165fd61e 100644 --- a/mediapipe/tasks/ios/BUILD +++ b/mediapipe/tasks/ios/BUILD @@ -83,6 +83,9 @@ strip_api_include_path_prefix( "//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedder.h", "//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedderOptions.h", "//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedderResult.h", + "//mediapipe/tasks/ios/text/language_detector:sources/MPPLanguageDetector.h", + "//mediapipe/tasks/ios/text/language_detector:sources/MPPLanguageDetectorOptions.h", + "//mediapipe/tasks/ios/text/language_detector:sources/MPPLanguageDetectorResult.h", "//mediapipe/tasks/ios/vision/core:sources/MPPRunningMode.h", "//mediapipe/tasks/ios/vision/core:sources/MPPImage.h", "//mediapipe/tasks/ios/vision/core:sources/MPPMask.h", @@ -147,10 +150,14 @@ apple_static_xcframework( ":MPPTextEmbedder.h", ":MPPTextEmbedderOptions.h", ":MPPTextEmbedderResult.h", + ":MPPLanguageDetector.h", + ":MPPLanguageDetectorOptions.h", + ":MPPLanguageDetectorResult.h", ], deps = [ "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier", "//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedder", + "//mediapipe/tasks/ios/text/language_detector:MPPLanguageDetector", ], ) diff --git a/mediapipe/tasks/ios/test/text/language_detector/MPPLanguageDetectorTests.mm b/mediapipe/tasks/ios/test/text/language_detector/MPPLanguageDetectorTests.mm index 28d2ea5c0..1b738a124 100644 --- a/mediapipe/tasks/ios/test/text/language_detector/MPPLanguageDetectorTests.mm +++ b/mediapipe/tasks/ios/test/text/language_detector/MPPLanguageDetectorTests.mm @@ -116,6 +116,82 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; [self assertResultsOfDetectLanguageOfText:ruText usingLanguageDetector:languageDetector approximatelyEqualsExpectedLanguagePredictions:expectedRuLanguagePredictions]; + + NSString *zhText = @"分久必合合久必分"; + NSArray *expectedZhLanguagePredictions = @[ + [[MPPLanguagePrediction alloc] initWithLanguageCode:@"zh" probability:0.505424f], + [[MPPLanguagePrediction alloc] initWithLanguageCode:@"ja" probability:0.481617f] + ]; + + [self assertResultsOfDetectLanguageOfText:zhText + usingLanguageDetector:languageDetector + approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions]; +} + +- (void)testClassifyWithMaxResultsSucceeds { + MPPLanguageDetectorOptions *options = + [self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo]; + options.maxResults = 1; + MPPLanguageDetector *languageDetector = [self createLanguageDetectorWithOptionsSucceeds:options]; + + NSString *zhText = @"分久必合合久必分"; + NSArray *expectedZhLanguagePredictions = @[ + [[MPPLanguagePrediction alloc] initWithLanguageCode:@"zh" probability:0.505424f], + ]; + + [self assertResultsOfDetectLanguageOfText:zhText + usingLanguageDetector:languageDetector + approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions]; +} + +- (void)testClassifyWithScoreThresholdSucceeds { + MPPLanguageDetectorOptions *options = + [self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo]; + options.scoreThreshold = 0.5f; + MPPLanguageDetector *languageDetector = [self createLanguageDetectorWithOptionsSucceeds:options]; + + NSString *zhText = @"分久必合合久必分"; + NSArray *expectedZhLanguagePredictions = @[ + [[MPPLanguagePrediction alloc] initWithLanguageCode:@"zh" probability:0.505424f], + ]; + + [self assertResultsOfDetectLanguageOfText:zhText + usingLanguageDetector:languageDetector + approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions]; +} + +- (void)testClassifyWithCategoryAllowListSucceeds { + MPPLanguageDetectorOptions *options = + [self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo]; + options.categoryAllowlist = @[ @"zh" ]; + + MPPLanguageDetector *languageDetector = [self createLanguageDetectorWithOptionsSucceeds:options]; + + NSString *zhText = @"分久必合合久必分"; + NSArray *expectedZhLanguagePredictions = @[ + [[MPPLanguagePrediction alloc] initWithLanguageCode:@"zh" probability:0.505424f], + ]; + + [self assertResultsOfDetectLanguageOfText:zhText + usingLanguageDetector:languageDetector + approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions]; +} + +- (void)testClassifyWithCategoryDenyListSucceeds { + MPPLanguageDetectorOptions *options = + [self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo]; + options.categoryDenylist = @[ @"zh" ]; + + MPPLanguageDetector *languageDetector = [self createLanguageDetectorWithOptionsSucceeds:options]; + + NSString *zhText = @"分久必合合久必分"; + NSArray *expectedZhLanguagePredictions = @[ + [[MPPLanguagePrediction alloc] initWithLanguageCode:@"ja" probability:0.481617f], + ]; + + [self assertResultsOfDetectLanguageOfText:zhText + usingLanguageDetector:languageDetector + approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions]; } #pragma mark Assert Segmenter Results @@ -125,6 +201,8 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; (NSArray *)expectedLanguagePredictions { MPPLanguageDetectorResult *result = [languageDetector detectText:text error:nil]; XCTAssertNotNil(result); + + XCTAssertEqual(result.languagePredictions.count, expectedLanguagePredictions.count); XCTAssertEqualWithAccuracy(result.languagePredictions[0].probability, expectedLanguagePredictions[0].probability, 1e-3); XCTAssertEqualObjects(result.languagePredictions[0].languageCode, diff --git a/mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetectorOptions.m b/mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetectorOptions.m index df36493ef..9113a9a47 100644 --- a/mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetectorOptions.m +++ b/mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetectorOptions.m @@ -20,7 +20,7 @@ self = [super init]; if (self) { _maxResults = -1; - _scoreThreshold = 0; + _scoreThreshold = -1.0f; } return self; } diff --git a/mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorOptions+Helpers.mm b/mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorOptions+Helpers.mm index 9d75105b4..45cad7a18 100644 --- a/mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorOptions+Helpers.mm +++ b/mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorOptions+Helpers.mm @@ -40,9 +40,12 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto if (self.displayNamesLocale) { classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString); } - + classifierOptionsProto->set_max_results((int)self.maxResults); - classifierOptionsProto->set_score_threshold(self.scoreThreshold); + + if (self.scoreThreshold >= 0) { + classifierOptionsProto->set_score_threshold(self.scoreThreshold); + } for (NSString *category in self.categoryAllowlist) { classifierOptionsProto->add_category_allowlist(category.cppString);