From 29c7702984fd0309fbadf64347fdd7cb5604b52f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 1 Dec 2022 05:50:46 -0800 Subject: [PATCH] Inline formerly nested 'ClassifierOptions' in Java classifier APIs. PiperOrigin-RevId: 492173060 --- .../com/google/mediapipe/tasks/audio/BUILD | 2 +- .../audioclassifier/AudioClassifier.java | 84 ++++++++++++++--- .../com/google/mediapipe/tasks/text/BUILD | 2 +- .../text/textclassifier/TextClassifier.java | 90 ++++++++++++++++--- .../com/google/mediapipe/tasks/vision/BUILD | 2 +- .../imageclassifier/ImageClassifier.java | 82 ++++++++++++++--- .../textclassifier/TextClassifierTest.java | 31 +++++++ .../imageclassifier/ImageClassifierTest.java | 81 +++++++++++------ 8 files changed, 305 insertions(+), 69 deletions(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD index 6771335ad..2afc75ec0 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD @@ -66,10 +66,10 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", "@maven//:com_google_guava_guava", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java index 0f3374175..d78685fe3 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java @@ -27,7 +27,7 @@ import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi; import com.google.mediapipe.tasks.audio.core.RunningMode; import com.google.mediapipe.tasks.components.containers.AudioData; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.OutputHandler; @@ -266,7 +266,7 @@ public final class AudioClassifier extends BaseAudioTaskApi { /* * Sends audio data (a block in a continuous audio stream) to perform audio classification, and - * the results will be available via the {@link ResultListener} provided in the + * the results will be available via the {@link ResultListener} provided in the * {@link AudioClassifierOptions}. Only use this method when the AudioClassifier is created with * the audio stream mode. * @@ -320,10 +320,42 @@ public final class AudioClassifier extends BaseAudioTaskApi { public abstract Builder setRunningMode(RunningMode runningMode); /** - * Sets the optional {@link ClassifierOptions} controling classification behavior, such as - * score threshold, number of results, etc. + * Sets the optional locale to use for display names specified through the TFLite Model + * Metadata, if any. */ - public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); + public abstract Builder setDisplayNamesLocale(String locale); + + /** + * Sets the optional maximum number of top-scored classification results to return. + * + *

If not set, all available results are returned. If set, must be > 0. + */ + public abstract Builder setMaxResults(Integer maxResults); + + /** + * Sets the optional score threshold. Results with score below this value are rejected. + * + *

Overrides the score threshold specified in the TFLite Model Metadata, if any. + */ + public abstract Builder setScoreThreshold(Float scoreThreshold); + + /** + * Sets the optional allowlist of category names. + * + *

If non-empty, detection results whose category name is not in this set will be filtered + * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryDenylist}. + */ + public abstract Builder setCategoryAllowlist(List categoryAllowlist); + + /** + * Sets the optional denylist of category names. + * + *

If non-empty, detection results whose category name is in this set will be filtered out. + * Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryAllowlist}. + */ + public abstract Builder setCategoryDenylist(List categoryDenylist); /** * Sets the {@link ResultListener} to receive the classification results asynchronously when @@ -340,9 +372,7 @@ public final class AudioClassifier extends BaseAudioTaskApi { /** * Validates and builds the {@link AudioClassifierOptions} instance. * - * @throws IllegalArgumentException if the result listener and the running mode are not - * properly configured. The result listener should only be set when the audio classifier - * is in the audio stream mode. + * @throws IllegalArgumentException if any of the set options are invalid. */ public final AudioClassifierOptions build() { AudioClassifierOptions options = autoBuild(); @@ -357,6 +387,13 @@ public final class AudioClassifier extends BaseAudioTaskApi { "The audio classifier is in the audio clips mode, a user-defined result listener" + " shouldn't be provided in AudioClassifierOptions."); } + if (options.maxResults().isPresent() && options.maxResults().get() <= 0) { + throw new IllegalArgumentException("If specified, maxResults must be > 0."); + } + if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) { + throw new IllegalArgumentException( + "Category allowlist and denylist are mutually exclusive."); + } return options; } } @@ -365,7 +402,15 @@ public final class AudioClassifier extends BaseAudioTaskApi { abstract RunningMode runningMode(); - abstract Optional classifierOptions(); + abstract Optional displayNamesLocale(); + + abstract Optional maxResults(); + + abstract Optional scoreThreshold(); + + abstract List categoryAllowlist(); + + abstract List categoryDenylist(); abstract Optional> resultListener(); @@ -373,7 +418,9 @@ public final class AudioClassifier extends BaseAudioTaskApi { public static Builder builder() { return new AutoValue_AudioClassifier_AudioClassifierOptions.Builder() - .setRunningMode(RunningMode.AUDIO_CLIPS); + .setRunningMode(RunningMode.AUDIO_CLIPS) + .setCategoryAllowlist(Collections.emptyList()) + .setCategoryDenylist(Collections.emptyList()); } /** @@ -385,12 +432,21 @@ public final class AudioClassifier extends BaseAudioTaskApi { BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = + ClassifierOptionsProto.ClassifierOptions.newBuilder(); + displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale); + maxResults().ifPresent(classifierOptionsBuilder::setMaxResults); + scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold); + if (!categoryAllowlist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist()); + } + if (!categoryDenylist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist()); + } AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.Builder taskOptionsBuilder = AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (classifierOptions().isPresent()) { - taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + .setClassifierOptions(classifierOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.ext, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD index 023a1f286..f9c8e7c76 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -49,10 +49,10 @@ android_library( "//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib", "//third_party:autovalue", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java index 341d6bf91..0ea91a9f8 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java @@ -24,7 +24,7 @@ import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.tasks.components.containers.ClassificationResult; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.OutputHandler; import com.google.mediapipe.tasks.core.TaskInfo; @@ -216,20 +216,79 @@ public final class TextClassifier implements AutoCloseable { public abstract Builder setBaseOptions(BaseOptions value); /** - * Sets the optional {@link ClassifierOptions} controling classification behavior, such as - * score threshold, number of results, etc. + * Sets the optional locale to use for display names specified through the TFLite Model + * Metadata, if any. */ - public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); + public abstract Builder setDisplayNamesLocale(String locale); - public abstract TextClassifierOptions build(); + /** + * Sets the optional maximum number of top-scored classification results to return. + * + *

If not set, all available results are returned. If set, must be > 0. + */ + public abstract Builder setMaxResults(Integer maxResults); + + /** + * Sets the optional score threshold. Results with score below this value are rejected. + * + *

Overrides the score threshold specified in the TFLite Model Metadata, if any. + */ + public abstract Builder setScoreThreshold(Float scoreThreshold); + + /** + * Sets the optional allowlist of category names. + * + *

If non-empty, detection results whose category name is not in this set will be filtered + * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryDenylist}. + */ + public abstract Builder setCategoryAllowlist(List categoryAllowlist); + + /** + * Sets the optional denylist of category names. + * + *

If non-empty, detection results whose category name is in this set will be filtered out. + * Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryAllowlist}. + */ + public abstract Builder setCategoryDenylist(List categoryDenylist); + + abstract TextClassifierOptions autoBuild(); + + /** + * Validates and builds the {@link TextClassifierOptions} instance. + * + * @throws IllegalArgumentException if any of the set options are invalid. + */ + public final TextClassifierOptions build() { + TextClassifierOptions options = autoBuild(); + if (options.maxResults().isPresent() && options.maxResults().get() <= 0) { + throw new IllegalArgumentException("If specified, maxResults must be > 0."); + } + if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) { + throw new IllegalArgumentException( + "Category allowlist and denylist are mutually exclusive."); + } + return options; + } } abstract BaseOptions baseOptions(); - abstract Optional classifierOptions(); + abstract Optional displayNamesLocale(); + + abstract Optional maxResults(); + + abstract Optional scoreThreshold(); + + abstract List categoryAllowlist(); + + abstract List categoryDenylist(); public static Builder builder() { - return new AutoValue_TextClassifier_TextClassifierOptions.Builder(); + return new AutoValue_TextClassifier_TextClassifierOptions.Builder() + .setCategoryAllowlist(Collections.emptyList()) + .setCategoryDenylist(Collections.emptyList()); } /** Converts a {@link TextClassifierOptions} to a {@link CalculatorOptions} protobuf message. */ @@ -238,12 +297,21 @@ public final class TextClassifier implements AutoCloseable { BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = + ClassifierOptionsProto.ClassifierOptions.newBuilder(); + displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale); + maxResults().ifPresent(classifierOptionsBuilder::setMaxResults); + scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold); + if (!categoryAllowlist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist()); + } + if (!categoryDenylist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist()); + } TextClassifierGraphOptionsProto.TextClassifierGraphOptions.Builder taskOptionsBuilder = TextClassifierGraphOptionsProto.TextClassifierGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (classifierOptions().isPresent()) { - taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + .setClassifierOptions(classifierOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( TextClassifierGraphOptionsProto.TextClassifierGraphOptions.ext, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index b7febb118..2d130ff05 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -98,10 +98,10 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", "@maven//:com_google_guava_guava", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java index 5e278804b..8990f46fd 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java @@ -27,7 +27,7 @@ import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.ClassificationResult; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.OutputHandler; @@ -376,10 +376,42 @@ public final class ImageClassifier extends BaseVisionTaskApi { public abstract Builder setRunningMode(RunningMode runningMode); /** - * Sets the optional {@link ClassifierOptions} controling classification behavior, such as - * score threshold, number of results, etc. + * Sets the optional locale to use for display names specified through the TFLite Model + * Metadata, if any. */ - public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); + public abstract Builder setDisplayNamesLocale(String locale); + + /** + * Sets the optional maximum number of top-scored classification results to return. + * + *

If not set, all available results are returned. If set, must be > 0. + */ + public abstract Builder setMaxResults(Integer maxResults); + + /** + * Sets the optional score threshold. Results with score below this value are rejected. + * + *

Overrides the score threshold specified in the TFLite Model Metadata, if any. + */ + public abstract Builder setScoreThreshold(Float scoreThreshold); + + /** + * Sets the optional allowlist of category names. + * + *

If non-empty, detection results whose category name is not in this set will be filtered + * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryDenylist}. + */ + public abstract Builder setCategoryAllowlist(List categoryAllowlist); + + /** + * Sets the optional denylist of category names. + * + *

If non-empty, detection results whose category name is in this set will be filtered out. + * Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryAllowlist}. + */ + public abstract Builder setCategoryDenylist(List categoryDenylist); /** * Sets the {@link ResultListener} to receive the classification results asynchronously when @@ -396,9 +428,7 @@ public final class ImageClassifier extends BaseVisionTaskApi { /** * Validates and builds the {@link ImageClassifierOptions} instance. * * - * @throws IllegalArgumentException if the result listener and the running mode are not - * properly configured. The result listener should only be set when the image classifier - * is in the live stream mode. + * @throws IllegalArgumentException if any of the set options are invalid. */ public final ImageClassifierOptions build() { ImageClassifierOptions options = autoBuild(); @@ -413,6 +443,13 @@ public final class ImageClassifier extends BaseVisionTaskApi { "The image classifier is in the image or video mode, a user-defined result listener" + " shouldn't be provided in ImageClassifierOptions."); } + if (options.maxResults().isPresent() && options.maxResults().get() <= 0) { + throw new IllegalArgumentException("If specified, maxResults must be > 0."); + } + if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) { + throw new IllegalArgumentException( + "Category allowlist and denylist are mutually exclusive."); + } return options; } } @@ -421,7 +458,15 @@ public final class ImageClassifier extends BaseVisionTaskApi { abstract RunningMode runningMode(); - abstract Optional classifierOptions(); + abstract Optional displayNamesLocale(); + + abstract Optional maxResults(); + + abstract Optional scoreThreshold(); + + abstract List categoryAllowlist(); + + abstract List categoryDenylist(); abstract Optional> resultListener(); @@ -429,7 +474,9 @@ public final class ImageClassifier extends BaseVisionTaskApi { public static Builder builder() { return new AutoValue_ImageClassifier_ImageClassifierOptions.Builder() - .setRunningMode(RunningMode.IMAGE); + .setRunningMode(RunningMode.IMAGE) + .setCategoryAllowlist(Collections.emptyList()) + .setCategoryDenylist(Collections.emptyList()); } /** @@ -441,12 +488,21 @@ public final class ImageClassifier extends BaseVisionTaskApi { BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = + ClassifierOptionsProto.ClassifierOptions.newBuilder(); + displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale); + maxResults().ifPresent(classifierOptionsBuilder::setMaxResults); + scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold); + if (!categoryAllowlist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist()); + } + if (!categoryDenylist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist()); + } ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.Builder taskOptionsBuilder = ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (classifierOptions().isPresent()) { - taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + .setClassifierOptions(classifierOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.ext, diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java index 5e03d2a4c..5ed413f6a 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java @@ -40,6 +40,37 @@ public class TextClassifierTest { private static final String NEGATIVE_TEXT = "unflinchingly bleak and desperate"; private static final String POSITIVE_TEXT = "it's a charming and often affecting journey"; + @Test + public void options_failsWithNegativeMaxResults() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + TextClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(BERT_MODEL_FILE).build()) + .setMaxResults(-1) + .build()); + assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0"); + } + + @Test + public void options_failsWithBothAllowlistAndDenylist() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + TextClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(BERT_MODEL_FILE).build()) + .setCategoryAllowlist(Arrays.asList("foo")) + .setCategoryDenylist(Arrays.asList("bar")) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("Category allowlist and denylist are mutually exclusive"); + } + @Test public void create_failsWithMissingModel() throws Exception { String nonExistentFile = "/path/to/non/existent/file"; diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java index 69820ce2d..dac11bf02 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java @@ -26,7 +26,6 @@ import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Category; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.TestUtils; import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; @@ -55,6 +54,37 @@ public class ImageClassifierTest { @RunWith(AndroidJUnit4.class) public static final class General extends ImageClassifierTest { + @Test + public void options_failsWithNegativeMaxResults() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setMaxResults(-1) + .build()); + assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0"); + } + + @Test + public void options_failsWithBothAllowlistAndDenylist() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setCategoryAllowlist(Arrays.asList("foo")) + .setCategoryDenylist(Arrays.asList("bar")) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("Category allowlist and denylist are mutually exclusive"); + } + @Test public void create_failsWithMissingModel() throws Exception { String nonExistentFile = "/path/to/non/existent/file"; @@ -105,7 +135,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build()) + .setMaxResults(3) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -125,7 +155,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(QUANTIZED_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -141,7 +171,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setScoreThreshold(0.02f).build()) + .setScoreThreshold(0.02f) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -160,10 +190,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions( - ClassifierOptions.builder() - .setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf")) - .build()) + .setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf")) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -183,11 +210,8 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions( - ClassifierOptions.builder() - .setMaxResults(3) - .setCategoryDenylist(Arrays.asList("bagel")) - .build()) + .setMaxResults(3) + .setCategoryDenylist(Arrays.asList("bagel")) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -207,7 +231,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -228,7 +252,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build()) + .setMaxResults(3) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -251,7 +275,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -322,14 +346,14 @@ public class ImageClassifierTest { MediaPipeException.class, () -> imageClassifier.classifyForVideo( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); exception = assertThrows( MediaPipeException.class, () -> imageClassifier.classifyAsync( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -353,7 +377,7 @@ public class ImageClassifierTest { MediaPipeException.class, () -> imageClassifier.classifyAsync( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -379,7 +403,7 @@ public class ImageClassifierTest { MediaPipeException.class, () -> imageClassifier.classifyForVideo( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); } @@ -388,7 +412,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -405,13 +429,14 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .setRunningMode(RunningMode.VIDEO) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); for (int i = 0; i < 3; i++) { - ImageClassifierResult results = imageClassifier.classifyForVideo(image, /*timestampMs=*/ i); + ImageClassifierResult results = + imageClassifier.classifyForVideo(image, /* timestampMs= */ i); assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); @@ -424,7 +449,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (imageClassificationResult, inputImage) -> { @@ -436,11 +461,11 @@ public class ImageClassifierTest { .build(); try (ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { - imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1); + imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 1); MediaPipeException exception = assertThrows( MediaPipeException.class, - () -> imageClassifier.classifyAsync(image, /*timestampMs=*/ 0)); + () -> imageClassifier.classifyAsync(image, /* timestampMs= */ 0)); assertThat(exception) .hasMessageThat() .contains("having a smaller timestamp than the processed timestamp"); @@ -453,7 +478,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (imageClassificationResult, inputImage) -> { @@ -466,7 +491,7 @@ public class ImageClassifierTest { try (ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { for (int i = 0; i < 3; ++i) { - imageClassifier.classifyAsync(image, /*timestampMs=*/ i); + imageClassifier.classifyAsync(image, /* timestampMs= */ i); } } }