Inline formerly nested 'ClassifierOptions' in Java classifier APIs.

PiperOrigin-RevId: 492173060
This commit is contained in:
MediaPipe Team 2022-12-01 05:50:46 -08:00 committed by Copybara-Service
parent 460aee7933
commit 29c7702984
8 changed files with 305 additions and 69 deletions

View File

@ -66,10 +66,10 @@ android_library(
"//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue", "//third_party:autovalue",
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",

View File

@ -27,7 +27,7 @@ import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi;
import com.google.mediapipe.tasks.audio.core.RunningMode; import com.google.mediapipe.tasks.audio.core.RunningMode;
import com.google.mediapipe.tasks.components.containers.AudioData; import com.google.mediapipe.tasks.components.containers.AudioData;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler; import com.google.mediapipe.tasks.core.OutputHandler;
@ -320,10 +320,42 @@ public final class AudioClassifier extends BaseAudioTaskApi {
public abstract Builder setRunningMode(RunningMode runningMode); public abstract Builder setRunningMode(RunningMode runningMode);
/** /**
* Sets the optional {@link ClassifierOptions} controling classification behavior, such as * Sets the optional locale to use for display names specified through the TFLite Model
* score threshold, number of results, etc. * Metadata, if any.
*/ */
public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); public abstract Builder setDisplayNamesLocale(String locale);
/**
* Sets the optional maximum number of top-scored classification results to return.
*
* <p>If not set, all available results are returned. If set, must be > 0.
*/
public abstract Builder setMaxResults(Integer maxResults);
/**
* Sets the optional score threshold. Results with score below this value are rejected.
*
* <p>Overrides the score threshold specified in the TFLite Model Metadata, if any.
*/
public abstract Builder setScoreThreshold(Float scoreThreshold);
/**
* Sets the optional allowlist of category names.
*
* <p>If non-empty, detection results whose category name is not in this set will be filtered
* out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
* categoryDenylist}.
*/
public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist);
/**
* Sets the optional denylist of category names.
*
* <p>If non-empty, detection results whose category name is in this set will be filtered out.
* Duplicate or unknown category names are ignored. Mutually exclusive with {@code
* categoryAllowlist}.
*/
public abstract Builder setCategoryDenylist(List<String> categoryDenylist);
/** /**
* Sets the {@link ResultListener} to receive the classification results asynchronously when * Sets the {@link ResultListener} to receive the classification results asynchronously when
@ -340,9 +372,7 @@ public final class AudioClassifier extends BaseAudioTaskApi {
/** /**
* Validates and builds the {@link AudioClassifierOptions} instance. * Validates and builds the {@link AudioClassifierOptions} instance.
* *
* @throws IllegalArgumentException if the result listener and the running mode are not * @throws IllegalArgumentException if any of the set options are invalid.
* properly configured. The result listener should only be set when the audio classifier
* is in the audio stream mode.
*/ */
public final AudioClassifierOptions build() { public final AudioClassifierOptions build() {
AudioClassifierOptions options = autoBuild(); AudioClassifierOptions options = autoBuild();
@ -357,6 +387,13 @@ public final class AudioClassifier extends BaseAudioTaskApi {
"The audio classifier is in the audio clips mode, a user-defined result listener" "The audio classifier is in the audio clips mode, a user-defined result listener"
+ " shouldn't be provided in AudioClassifierOptions."); + " shouldn't be provided in AudioClassifierOptions.");
} }
if (options.maxResults().isPresent() && options.maxResults().get() <= 0) {
throw new IllegalArgumentException("If specified, maxResults must be > 0.");
}
if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) {
throw new IllegalArgumentException(
"Category allowlist and denylist are mutually exclusive.");
}
return options; return options;
} }
} }
@ -365,7 +402,15 @@ public final class AudioClassifier extends BaseAudioTaskApi {
abstract RunningMode runningMode(); abstract RunningMode runningMode();
abstract Optional<ClassifierOptions> classifierOptions(); abstract Optional<String> displayNamesLocale();
abstract Optional<Integer> maxResults();
abstract Optional<Float> scoreThreshold();
abstract List<String> categoryAllowlist();
abstract List<String> categoryDenylist();
abstract Optional<PureResultListener<AudioClassifierResult>> resultListener(); abstract Optional<PureResultListener<AudioClassifierResult>> resultListener();
@ -373,7 +418,9 @@ public final class AudioClassifier extends BaseAudioTaskApi {
public static Builder builder() { public static Builder builder() {
return new AutoValue_AudioClassifier_AudioClassifierOptions.Builder() return new AutoValue_AudioClassifier_AudioClassifierOptions.Builder()
.setRunningMode(RunningMode.AUDIO_CLIPS); .setRunningMode(RunningMode.AUDIO_CLIPS)
.setCategoryAllowlist(Collections.emptyList())
.setCategoryDenylist(Collections.emptyList());
} }
/** /**
@ -385,12 +432,21 @@ public final class AudioClassifier extends BaseAudioTaskApi {
BaseOptionsProto.BaseOptions.newBuilder(); BaseOptionsProto.BaseOptions.newBuilder();
baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM); baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM);
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder =
ClassifierOptionsProto.ClassifierOptions.newBuilder();
displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale);
maxResults().ifPresent(classifierOptionsBuilder::setMaxResults);
scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold);
if (!categoryAllowlist().isEmpty()) {
classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist());
}
if (!categoryDenylist().isEmpty()) {
classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist());
}
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.Builder taskOptionsBuilder = AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.Builder taskOptionsBuilder =
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.newBuilder() AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.newBuilder()
.setBaseOptions(baseOptionsBuilder); .setBaseOptions(baseOptionsBuilder)
if (classifierOptions().isPresent()) { .setClassifierOptions(classifierOptionsBuilder);
taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
}
return CalculatorOptions.newBuilder() return CalculatorOptions.newBuilder()
.setExtension( .setExtension(
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.ext, AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.ext,

View File

@ -49,10 +49,10 @@ android_library(
"//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib", "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",
"//third_party:autovalue", "//third_party:autovalue",

View File

@ -24,7 +24,7 @@ import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.tasks.components.containers.ClassificationResult; import com.google.mediapipe.tasks.components.containers.ClassificationResult;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.OutputHandler; import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.TaskInfo; import com.google.mediapipe.tasks.core.TaskInfo;
@ -216,20 +216,79 @@ public final class TextClassifier implements AutoCloseable {
public abstract Builder setBaseOptions(BaseOptions value); public abstract Builder setBaseOptions(BaseOptions value);
/** /**
* Sets the optional {@link ClassifierOptions} controling classification behavior, such as * Sets the optional locale to use for display names specified through the TFLite Model
* score threshold, number of results, etc. * Metadata, if any.
*/ */
public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); public abstract Builder setDisplayNamesLocale(String locale);
public abstract TextClassifierOptions build(); /**
* Sets the optional maximum number of top-scored classification results to return.
*
* <p>If not set, all available results are returned. If set, must be > 0.
*/
public abstract Builder setMaxResults(Integer maxResults);
/**
* Sets the optional score threshold. Results with score below this value are rejected.
*
* <p>Overrides the score threshold specified in the TFLite Model Metadata, if any.
*/
public abstract Builder setScoreThreshold(Float scoreThreshold);
/**
* Sets the optional allowlist of category names.
*
* <p>If non-empty, detection results whose category name is not in this set will be filtered
* out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
* categoryDenylist}.
*/
public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist);
/**
* Sets the optional denylist of category names.
*
* <p>If non-empty, detection results whose category name is in this set will be filtered out.
* Duplicate or unknown category names are ignored. Mutually exclusive with {@code
* categoryAllowlist}.
*/
public abstract Builder setCategoryDenylist(List<String> categoryDenylist);
abstract TextClassifierOptions autoBuild();
/**
* Validates and builds the {@link TextClassifierOptions} instance.
*
* @throws IllegalArgumentException if any of the set options are invalid.
*/
public final TextClassifierOptions build() {
TextClassifierOptions options = autoBuild();
if (options.maxResults().isPresent() && options.maxResults().get() <= 0) {
throw new IllegalArgumentException("If specified, maxResults must be > 0.");
}
if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) {
throw new IllegalArgumentException(
"Category allowlist and denylist are mutually exclusive.");
}
return options;
}
} }
abstract BaseOptions baseOptions(); abstract BaseOptions baseOptions();
abstract Optional<ClassifierOptions> classifierOptions(); abstract Optional<String> displayNamesLocale();
abstract Optional<Integer> maxResults();
abstract Optional<Float> scoreThreshold();
abstract List<String> categoryAllowlist();
abstract List<String> categoryDenylist();
public static Builder builder() { public static Builder builder() {
return new AutoValue_TextClassifier_TextClassifierOptions.Builder(); return new AutoValue_TextClassifier_TextClassifierOptions.Builder()
.setCategoryAllowlist(Collections.emptyList())
.setCategoryDenylist(Collections.emptyList());
} }
/** Converts a {@link TextClassifierOptions} to a {@link CalculatorOptions} protobuf message. */ /** Converts a {@link TextClassifierOptions} to a {@link CalculatorOptions} protobuf message. */
@ -238,12 +297,21 @@ public final class TextClassifier implements AutoCloseable {
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
BaseOptionsProto.BaseOptions.newBuilder(); BaseOptionsProto.BaseOptions.newBuilder();
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder =
ClassifierOptionsProto.ClassifierOptions.newBuilder();
displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale);
maxResults().ifPresent(classifierOptionsBuilder::setMaxResults);
scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold);
if (!categoryAllowlist().isEmpty()) {
classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist());
}
if (!categoryDenylist().isEmpty()) {
classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist());
}
TextClassifierGraphOptionsProto.TextClassifierGraphOptions.Builder taskOptionsBuilder = TextClassifierGraphOptionsProto.TextClassifierGraphOptions.Builder taskOptionsBuilder =
TextClassifierGraphOptionsProto.TextClassifierGraphOptions.newBuilder() TextClassifierGraphOptionsProto.TextClassifierGraphOptions.newBuilder()
.setBaseOptions(baseOptionsBuilder); .setBaseOptions(baseOptionsBuilder)
if (classifierOptions().isPresent()) { .setClassifierOptions(classifierOptionsBuilder);
taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
}
return CalculatorOptions.newBuilder() return CalculatorOptions.newBuilder()
.setExtension( .setExtension(
TextClassifierGraphOptionsProto.TextClassifierGraphOptions.ext, TextClassifierGraphOptionsProto.TextClassifierGraphOptions.ext,

View File

@ -98,10 +98,10 @@ android_library(
"//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue", "//third_party:autovalue",
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",

View File

@ -27,7 +27,7 @@ import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.components.containers.ClassificationResult; import com.google.mediapipe.tasks.components.containers.ClassificationResult;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler; import com.google.mediapipe.tasks.core.OutputHandler;
@ -376,10 +376,42 @@ public final class ImageClassifier extends BaseVisionTaskApi {
public abstract Builder setRunningMode(RunningMode runningMode); public abstract Builder setRunningMode(RunningMode runningMode);
/** /**
* Sets the optional {@link ClassifierOptions} controling classification behavior, such as * Sets the optional locale to use for display names specified through the TFLite Model
* score threshold, number of results, etc. * Metadata, if any.
*/ */
public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); public abstract Builder setDisplayNamesLocale(String locale);
/**
* Sets the optional maximum number of top-scored classification results to return.
*
* <p>If not set, all available results are returned. If set, must be > 0.
*/
public abstract Builder setMaxResults(Integer maxResults);
/**
* Sets the optional score threshold. Results with score below this value are rejected.
*
* <p>Overrides the score threshold specified in the TFLite Model Metadata, if any.
*/
public abstract Builder setScoreThreshold(Float scoreThreshold);
/**
* Sets the optional allowlist of category names.
*
* <p>If non-empty, detection results whose category name is not in this set will be filtered
* out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
* categoryDenylist}.
*/
public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist);
/**
* Sets the optional denylist of category names.
*
* <p>If non-empty, detection results whose category name is in this set will be filtered out.
* Duplicate or unknown category names are ignored. Mutually exclusive with {@code
* categoryAllowlist}.
*/
public abstract Builder setCategoryDenylist(List<String> categoryDenylist);
/** /**
* Sets the {@link ResultListener} to receive the classification results asynchronously when * Sets the {@link ResultListener} to receive the classification results asynchronously when
@ -396,9 +428,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
/** /**
* Validates and builds the {@link ImageClassifierOptions} instance. * * Validates and builds the {@link ImageClassifierOptions} instance. *
* *
* @throws IllegalArgumentException if the result listener and the running mode are not * @throws IllegalArgumentException if any of the set options are invalid.
* properly configured. The result listener should only be set when the image classifier
* is in the live stream mode.
*/ */
public final ImageClassifierOptions build() { public final ImageClassifierOptions build() {
ImageClassifierOptions options = autoBuild(); ImageClassifierOptions options = autoBuild();
@ -413,6 +443,13 @@ public final class ImageClassifier extends BaseVisionTaskApi {
"The image classifier is in the image or video mode, a user-defined result listener" "The image classifier is in the image or video mode, a user-defined result listener"
+ " shouldn't be provided in ImageClassifierOptions."); + " shouldn't be provided in ImageClassifierOptions.");
} }
if (options.maxResults().isPresent() && options.maxResults().get() <= 0) {
throw new IllegalArgumentException("If specified, maxResults must be > 0.");
}
if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) {
throw new IllegalArgumentException(
"Category allowlist and denylist are mutually exclusive.");
}
return options; return options;
} }
} }
@ -421,7 +458,15 @@ public final class ImageClassifier extends BaseVisionTaskApi {
abstract RunningMode runningMode(); abstract RunningMode runningMode();
abstract Optional<ClassifierOptions> classifierOptions(); abstract Optional<String> displayNamesLocale();
abstract Optional<Integer> maxResults();
abstract Optional<Float> scoreThreshold();
abstract List<String> categoryAllowlist();
abstract List<String> categoryDenylist();
abstract Optional<ResultListener<ImageClassifierResult, MPImage>> resultListener(); abstract Optional<ResultListener<ImageClassifierResult, MPImage>> resultListener();
@ -429,7 +474,9 @@ public final class ImageClassifier extends BaseVisionTaskApi {
public static Builder builder() { public static Builder builder() {
return new AutoValue_ImageClassifier_ImageClassifierOptions.Builder() return new AutoValue_ImageClassifier_ImageClassifierOptions.Builder()
.setRunningMode(RunningMode.IMAGE); .setRunningMode(RunningMode.IMAGE)
.setCategoryAllowlist(Collections.emptyList())
.setCategoryDenylist(Collections.emptyList());
} }
/** /**
@ -441,12 +488,21 @@ public final class ImageClassifier extends BaseVisionTaskApi {
BaseOptionsProto.BaseOptions.newBuilder(); BaseOptionsProto.BaseOptions.newBuilder();
baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE); baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE);
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder =
ClassifierOptionsProto.ClassifierOptions.newBuilder();
displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale);
maxResults().ifPresent(classifierOptionsBuilder::setMaxResults);
scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold);
if (!categoryAllowlist().isEmpty()) {
classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist());
}
if (!categoryDenylist().isEmpty()) {
classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist());
}
ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.Builder taskOptionsBuilder = ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.Builder taskOptionsBuilder =
ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.newBuilder() ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.newBuilder()
.setBaseOptions(baseOptionsBuilder); .setBaseOptions(baseOptionsBuilder)
if (classifierOptions().isPresent()) { .setClassifierOptions(classifierOptionsBuilder);
taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
}
return CalculatorOptions.newBuilder() return CalculatorOptions.newBuilder()
.setExtension( .setExtension(
ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.ext, ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.ext,

View File

@ -40,6 +40,37 @@ public class TextClassifierTest {
private static final String NEGATIVE_TEXT = "unflinchingly bleak and desperate"; private static final String NEGATIVE_TEXT = "unflinchingly bleak and desperate";
private static final String POSITIVE_TEXT = "it's a charming and often affecting journey"; private static final String POSITIVE_TEXT = "it's a charming and often affecting journey";
@Test
public void options_failsWithNegativeMaxResults() throws Exception {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
TextClassifierOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(BERT_MODEL_FILE).build())
.setMaxResults(-1)
.build());
assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0");
}
@Test
public void options_failsWithBothAllowlistAndDenylist() throws Exception {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
TextClassifierOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(BERT_MODEL_FILE).build())
.setCategoryAllowlist(Arrays.asList("foo"))
.setCategoryDenylist(Arrays.asList("bar"))
.build());
assertThat(exception)
.hasMessageThat()
.contains("Category allowlist and denylist are mutually exclusive");
}
@Test @Test
public void create_failsWithMissingModel() throws Exception { public void create_failsWithMissingModel() throws Exception {
String nonExistentFile = "/path/to/non/existent/file"; String nonExistentFile = "/path/to/non/existent/file";

View File

@ -26,7 +26,6 @@ import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.components.containers.Category; import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.TestUtils; import com.google.mediapipe.tasks.core.TestUtils;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
@ -55,6 +54,37 @@ public class ImageClassifierTest {
@RunWith(AndroidJUnit4.class) @RunWith(AndroidJUnit4.class)
public static final class General extends ImageClassifierTest { public static final class General extends ImageClassifierTest {
@Test
public void options_failsWithNegativeMaxResults() throws Exception {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
ImageClassifierOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setMaxResults(-1)
.build());
assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0");
}
@Test
public void options_failsWithBothAllowlistAndDenylist() throws Exception {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
ImageClassifierOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setCategoryAllowlist(Arrays.asList("foo"))
.setCategoryDenylist(Arrays.asList("bar"))
.build());
assertThat(exception)
.hasMessageThat()
.contains("Category allowlist and denylist are mutually exclusive");
}
@Test @Test
public void create_failsWithMissingModel() throws Exception { public void create_failsWithMissingModel() throws Exception {
String nonExistentFile = "/path/to/non/existent/file"; String nonExistentFile = "/path/to/non/existent/file";
@ -105,7 +135,7 @@ public class ImageClassifierTest {
ImageClassifierOptions options = ImageClassifierOptions options =
ImageClassifierOptions.builder() ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build()) .setMaxResults(3)
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -125,7 +155,7 @@ public class ImageClassifierTest {
ImageClassifierOptions options = ImageClassifierOptions options =
ImageClassifierOptions.builder() ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(QUANTIZED_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(QUANTIZED_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) .setMaxResults(1)
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -141,7 +171,7 @@ public class ImageClassifierTest {
ImageClassifierOptions options = ImageClassifierOptions options =
ImageClassifierOptions.builder() ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setScoreThreshold(0.02f).build()) .setScoreThreshold(0.02f)
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -160,10 +190,7 @@ public class ImageClassifierTest {
ImageClassifierOptions options = ImageClassifierOptions options =
ImageClassifierOptions.builder() ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(
ClassifierOptions.builder()
.setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf")) .setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf"))
.build())
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -183,11 +210,8 @@ public class ImageClassifierTest {
ImageClassifierOptions options = ImageClassifierOptions options =
ImageClassifierOptions.builder() ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(
ClassifierOptions.builder()
.setMaxResults(3) .setMaxResults(3)
.setCategoryDenylist(Arrays.asList("bagel")) .setCategoryDenylist(Arrays.asList("bagel"))
.build())
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -207,7 +231,7 @@ public class ImageClassifierTest {
ImageClassifierOptions options = ImageClassifierOptions options =
ImageClassifierOptions.builder() ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) .setMaxResults(1)
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -228,7 +252,7 @@ public class ImageClassifierTest {
ImageClassifierOptions options = ImageClassifierOptions options =
ImageClassifierOptions.builder() ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build()) .setMaxResults(3)
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -251,7 +275,7 @@ public class ImageClassifierTest {
ImageClassifierOptions options = ImageClassifierOptions options =
ImageClassifierOptions.builder() ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) .setMaxResults(1)
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -388,7 +412,7 @@ public class ImageClassifierTest {
ImageClassifierOptions options = ImageClassifierOptions options =
ImageClassifierOptions.builder() ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) .setMaxResults(1)
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -405,13 +429,14 @@ public class ImageClassifierTest {
ImageClassifierOptions options = ImageClassifierOptions options =
ImageClassifierOptions.builder() ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) .setMaxResults(1)
.setRunningMode(RunningMode.VIDEO) .setRunningMode(RunningMode.VIDEO)
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
ImageClassifierResult results = imageClassifier.classifyForVideo(image, /*timestampMs=*/ i); ImageClassifierResult results =
imageClassifier.classifyForVideo(image, /* timestampMs= */ i);
assertHasOneHead(results); assertHasOneHead(results);
assertCategoriesAre( assertCategoriesAre(
results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
@ -424,7 +449,7 @@ public class ImageClassifierTest {
ImageClassifierOptions options = ImageClassifierOptions options =
ImageClassifierOptions.builder() ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) .setMaxResults(1)
.setRunningMode(RunningMode.LIVE_STREAM) .setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener( .setResultListener(
(imageClassificationResult, inputImage) -> { (imageClassificationResult, inputImage) -> {
@ -453,7 +478,7 @@ public class ImageClassifierTest {
ImageClassifierOptions options = ImageClassifierOptions options =
ImageClassifierOptions.builder() ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) .setMaxResults(1)
.setRunningMode(RunningMode.LIVE_STREAM) .setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener( .setResultListener(
(imageClassificationResult, inputImage) -> { (imageClassificationResult, inputImage) -> {