Inline formerly nested 'ClassifierOptions' in Java classifier APIs.

PiperOrigin-RevId: 492173060
This commit is contained in:
MediaPipe Team 2022-12-01 05:50:46 -08:00 committed by Copybara-Service
parent 460aee7933
commit 29c7702984
8 changed files with 305 additions and 69 deletions

View File

@ -66,10 +66,10 @@ android_library(
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",

View File

@ -27,7 +27,7 @@ import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi;
import com.google.mediapipe.tasks.audio.core.RunningMode;
import com.google.mediapipe.tasks.components.containers.AudioData;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler;
@ -320,10 +320,42 @@ public final class AudioClassifier extends BaseAudioTaskApi {
public abstract Builder setRunningMode(RunningMode runningMode);
/**
* Sets the optional {@link ClassifierOptions} controling classification behavior, such as
* score threshold, number of results, etc.
* Sets the optional locale to use for display names specified through the TFLite Model
* Metadata, if any.
*/
public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
public abstract Builder setDisplayNamesLocale(String locale);
/**
* Sets the optional maximum number of top-scored classification results to return.
*
* <p>If not set, all available results are returned. If set, must be > 0.
*/
public abstract Builder setMaxResults(Integer maxResults);
/**
* Sets the optional score threshold. Results with score below this value are rejected.
*
* <p>Overrides the score threshold specified in the TFLite Model Metadata, if any.
*/
public abstract Builder setScoreThreshold(Float scoreThreshold);
/**
* Sets the optional allowlist of category names.
*
* <p>If non-empty, detection results whose category name is not in this set will be filtered
* out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
* categoryDenylist}.
*/
public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist);
/**
* Sets the optional denylist of category names.
*
* <p>If non-empty, detection results whose category name is in this set will be filtered out.
* Duplicate or unknown category names are ignored. Mutually exclusive with {@code
* categoryAllowlist}.
*/
public abstract Builder setCategoryDenylist(List<String> categoryDenylist);
/**
* Sets the {@link ResultListener} to receive the classification results asynchronously when
@ -340,9 +372,7 @@ public final class AudioClassifier extends BaseAudioTaskApi {
/**
* Validates and builds the {@link AudioClassifierOptions} instance.
*
* @throws IllegalArgumentException if the result listener and the running mode are not
* properly configured. The result listener should only be set when the audio classifier
* is in the audio stream mode.
* @throws IllegalArgumentException if any of the set options are invalid.
*/
public final AudioClassifierOptions build() {
AudioClassifierOptions options = autoBuild();
@ -357,6 +387,13 @@ public final class AudioClassifier extends BaseAudioTaskApi {
"The audio classifier is in the audio clips mode, a user-defined result listener"
+ " shouldn't be provided in AudioClassifierOptions.");
}
if (options.maxResults().isPresent() && options.maxResults().get() <= 0) {
throw new IllegalArgumentException("If specified, maxResults must be > 0.");
}
if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) {
throw new IllegalArgumentException(
"Category allowlist and denylist are mutually exclusive.");
}
return options;
}
}
@ -365,7 +402,15 @@ public final class AudioClassifier extends BaseAudioTaskApi {
abstract RunningMode runningMode();
abstract Optional<ClassifierOptions> classifierOptions();
abstract Optional<String> displayNamesLocale();
abstract Optional<Integer> maxResults();
abstract Optional<Float> scoreThreshold();
abstract List<String> categoryAllowlist();
abstract List<String> categoryDenylist();
abstract Optional<PureResultListener<AudioClassifierResult>> resultListener();
@ -373,7 +418,9 @@ public final class AudioClassifier extends BaseAudioTaskApi {
public static Builder builder() {
return new AutoValue_AudioClassifier_AudioClassifierOptions.Builder()
.setRunningMode(RunningMode.AUDIO_CLIPS);
.setRunningMode(RunningMode.AUDIO_CLIPS)
.setCategoryAllowlist(Collections.emptyList())
.setCategoryDenylist(Collections.emptyList());
}
/**
@ -385,12 +432,21 @@ public final class AudioClassifier extends BaseAudioTaskApi {
BaseOptionsProto.BaseOptions.newBuilder();
baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM);
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder =
ClassifierOptionsProto.ClassifierOptions.newBuilder();
displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale);
maxResults().ifPresent(classifierOptionsBuilder::setMaxResults);
scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold);
if (!categoryAllowlist().isEmpty()) {
classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist());
}
if (!categoryDenylist().isEmpty()) {
classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist());
}
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.Builder taskOptionsBuilder =
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.newBuilder()
.setBaseOptions(baseOptionsBuilder);
if (classifierOptions().isPresent()) {
taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
}
.setBaseOptions(baseOptionsBuilder)
.setClassifierOptions(classifierOptionsBuilder);
return CalculatorOptions.newBuilder()
.setExtension(
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.ext,

View File

@ -49,10 +49,10 @@ android_library(
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",
"//third_party:autovalue",

View File

@ -24,7 +24,7 @@ import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.TaskInfo;
@ -216,20 +216,79 @@ public final class TextClassifier implements AutoCloseable {
public abstract Builder setBaseOptions(BaseOptions value);
/**
* Sets the optional {@link ClassifierOptions} controling classification behavior, such as
* score threshold, number of results, etc.
* Sets the optional locale to use for display names specified through the TFLite Model
* Metadata, if any.
*/
public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
public abstract Builder setDisplayNamesLocale(String locale);
public abstract TextClassifierOptions build();
/**
* Sets the optional maximum number of top-scored classification results to return.
*
* <p>If not set, all available results are returned. If set, must be > 0.
*/
public abstract Builder setMaxResults(Integer maxResults);
/**
* Sets the optional score threshold. Results with score below this value are rejected.
*
* <p>Overrides the score threshold specified in the TFLite Model Metadata, if any.
*/
public abstract Builder setScoreThreshold(Float scoreThreshold);
/**
* Sets the optional allowlist of category names.
*
* <p>If non-empty, detection results whose category name is not in this set will be filtered
* out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
* categoryDenylist}.
*/
public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist);
/**
* Sets the optional denylist of category names.
*
* <p>If non-empty, detection results whose category name is in this set will be filtered out.
* Duplicate or unknown category names are ignored. Mutually exclusive with {@code
* categoryAllowlist}.
*/
public abstract Builder setCategoryDenylist(List<String> categoryDenylist);
abstract TextClassifierOptions autoBuild();
/**
* Validates and builds the {@link TextClassifierOptions} instance.
*
* @throws IllegalArgumentException if any of the set options are invalid.
*/
public final TextClassifierOptions build() {
TextClassifierOptions options = autoBuild();
if (options.maxResults().isPresent() && options.maxResults().get() <= 0) {
throw new IllegalArgumentException("If specified, maxResults must be > 0.");
}
if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) {
throw new IllegalArgumentException(
"Category allowlist and denylist are mutually exclusive.");
}
return options;
}
}
abstract BaseOptions baseOptions();
abstract Optional<ClassifierOptions> classifierOptions();
abstract Optional<String> displayNamesLocale();
abstract Optional<Integer> maxResults();
abstract Optional<Float> scoreThreshold();
abstract List<String> categoryAllowlist();
abstract List<String> categoryDenylist();
public static Builder builder() {
return new AutoValue_TextClassifier_TextClassifierOptions.Builder();
return new AutoValue_TextClassifier_TextClassifierOptions.Builder()
.setCategoryAllowlist(Collections.emptyList())
.setCategoryDenylist(Collections.emptyList());
}
/** Converts a {@link TextClassifierOptions} to a {@link CalculatorOptions} protobuf message. */
@ -238,12 +297,21 @@ public final class TextClassifier implements AutoCloseable {
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
BaseOptionsProto.BaseOptions.newBuilder();
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder =
ClassifierOptionsProto.ClassifierOptions.newBuilder();
displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale);
maxResults().ifPresent(classifierOptionsBuilder::setMaxResults);
scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold);
if (!categoryAllowlist().isEmpty()) {
classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist());
}
if (!categoryDenylist().isEmpty()) {
classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist());
}
TextClassifierGraphOptionsProto.TextClassifierGraphOptions.Builder taskOptionsBuilder =
TextClassifierGraphOptionsProto.TextClassifierGraphOptions.newBuilder()
.setBaseOptions(baseOptionsBuilder);
if (classifierOptions().isPresent()) {
taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
}
.setBaseOptions(baseOptionsBuilder)
.setClassifierOptions(classifierOptionsBuilder);
return CalculatorOptions.newBuilder()
.setExtension(
TextClassifierGraphOptionsProto.TextClassifierGraphOptions.ext,

View File

@ -98,10 +98,10 @@ android_library(
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",

View File

@ -27,7 +27,7 @@ import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler;
@ -376,10 +376,42 @@ public final class ImageClassifier extends BaseVisionTaskApi {
public abstract Builder setRunningMode(RunningMode runningMode);
/**
* Sets the optional {@link ClassifierOptions} controling classification behavior, such as
* score threshold, number of results, etc.
* Sets the optional locale to use for display names specified through the TFLite Model
* Metadata, if any.
*/
public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
public abstract Builder setDisplayNamesLocale(String locale);
/**
* Sets the optional maximum number of top-scored classification results to return.
*
* <p>If not set, all available results are returned. If set, must be > 0.
*/
public abstract Builder setMaxResults(Integer maxResults);
/**
* Sets the optional score threshold. Results with score below this value are rejected.
*
* <p>Overrides the score threshold specified in the TFLite Model Metadata, if any.
*/
public abstract Builder setScoreThreshold(Float scoreThreshold);
/**
* Sets the optional allowlist of category names.
*
* <p>If non-empty, detection results whose category name is not in this set will be filtered
* out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
* categoryDenylist}.
*/
public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist);
/**
* Sets the optional denylist of category names.
*
* <p>If non-empty, detection results whose category name is in this set will be filtered out.
* Duplicate or unknown category names are ignored. Mutually exclusive with {@code
* categoryAllowlist}.
*/
public abstract Builder setCategoryDenylist(List<String> categoryDenylist);
/**
* Sets the {@link ResultListener} to receive the classification results asynchronously when
@ -396,9 +428,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
/**
* Validates and builds the {@link ImageClassifierOptions} instance. *
*
* @throws IllegalArgumentException if the result listener and the running mode are not
* properly configured. The result listener should only be set when the image classifier
* is in the live stream mode.
* @throws IllegalArgumentException if any of the set options are invalid.
*/
public final ImageClassifierOptions build() {
ImageClassifierOptions options = autoBuild();
@ -413,6 +443,13 @@ public final class ImageClassifier extends BaseVisionTaskApi {
"The image classifier is in the image or video mode, a user-defined result listener"
+ " shouldn't be provided in ImageClassifierOptions.");
}
if (options.maxResults().isPresent() && options.maxResults().get() <= 0) {
throw new IllegalArgumentException("If specified, maxResults must be > 0.");
}
if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) {
throw new IllegalArgumentException(
"Category allowlist and denylist are mutually exclusive.");
}
return options;
}
}
@ -421,7 +458,15 @@ public final class ImageClassifier extends BaseVisionTaskApi {
abstract RunningMode runningMode();
abstract Optional<ClassifierOptions> classifierOptions();
abstract Optional<String> displayNamesLocale();
abstract Optional<Integer> maxResults();
abstract Optional<Float> scoreThreshold();
abstract List<String> categoryAllowlist();
abstract List<String> categoryDenylist();
abstract Optional<ResultListener<ImageClassifierResult, MPImage>> resultListener();
@ -429,7 +474,9 @@ public final class ImageClassifier extends BaseVisionTaskApi {
public static Builder builder() {
return new AutoValue_ImageClassifier_ImageClassifierOptions.Builder()
.setRunningMode(RunningMode.IMAGE);
.setRunningMode(RunningMode.IMAGE)
.setCategoryAllowlist(Collections.emptyList())
.setCategoryDenylist(Collections.emptyList());
}
/**
@ -441,12 +488,21 @@ public final class ImageClassifier extends BaseVisionTaskApi {
BaseOptionsProto.BaseOptions.newBuilder();
baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE);
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder =
ClassifierOptionsProto.ClassifierOptions.newBuilder();
displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale);
maxResults().ifPresent(classifierOptionsBuilder::setMaxResults);
scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold);
if (!categoryAllowlist().isEmpty()) {
classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist());
}
if (!categoryDenylist().isEmpty()) {
classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist());
}
ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.Builder taskOptionsBuilder =
ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.newBuilder()
.setBaseOptions(baseOptionsBuilder);
if (classifierOptions().isPresent()) {
taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
}
.setBaseOptions(baseOptionsBuilder)
.setClassifierOptions(classifierOptionsBuilder);
return CalculatorOptions.newBuilder()
.setExtension(
ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.ext,

View File

@ -40,6 +40,37 @@ public class TextClassifierTest {
private static final String NEGATIVE_TEXT = "unflinchingly bleak and desperate";
private static final String POSITIVE_TEXT = "it's a charming and often affecting journey";
@Test
public void options_failsWithNegativeMaxResults() throws Exception {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
TextClassifierOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(BERT_MODEL_FILE).build())
.setMaxResults(-1)
.build());
assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0");
}
@Test
public void options_failsWithBothAllowlistAndDenylist() throws Exception {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
TextClassifierOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(BERT_MODEL_FILE).build())
.setCategoryAllowlist(Arrays.asList("foo"))
.setCategoryDenylist(Arrays.asList("bar"))
.build());
assertThat(exception)
.hasMessageThat()
.contains("Category allowlist and denylist are mutually exclusive");
}
@Test
public void create_failsWithMissingModel() throws Exception {
String nonExistentFile = "/path/to/non/existent/file";

View File

@ -26,7 +26,6 @@ import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.TestUtils;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
@ -55,6 +54,37 @@ public class ImageClassifierTest {
@RunWith(AndroidJUnit4.class)
public static final class General extends ImageClassifierTest {
@Test
public void options_failsWithNegativeMaxResults() throws Exception {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
ImageClassifierOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setMaxResults(-1)
.build());
assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0");
}
@Test
public void options_failsWithBothAllowlistAndDenylist() throws Exception {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
ImageClassifierOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setCategoryAllowlist(Arrays.asList("foo"))
.setCategoryDenylist(Arrays.asList("bar"))
.build());
assertThat(exception)
.hasMessageThat()
.contains("Category allowlist and denylist are mutually exclusive");
}
@Test
public void create_failsWithMissingModel() throws Exception {
String nonExistentFile = "/path/to/non/existent/file";
@ -105,7 +135,7 @@ public class ImageClassifierTest {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build())
.setMaxResults(3)
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -125,7 +155,7 @@ public class ImageClassifierTest {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(QUANTIZED_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
.setMaxResults(1)
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -141,7 +171,7 @@ public class ImageClassifierTest {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setScoreThreshold(0.02f).build())
.setScoreThreshold(0.02f)
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -160,10 +190,7 @@ public class ImageClassifierTest {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(
ClassifierOptions.builder()
.setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf"))
.build())
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -183,11 +210,8 @@ public class ImageClassifierTest {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(
ClassifierOptions.builder()
.setMaxResults(3)
.setCategoryDenylist(Arrays.asList("bagel"))
.build())
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -207,7 +231,7 @@ public class ImageClassifierTest {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
.setMaxResults(1)
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -228,7 +252,7 @@ public class ImageClassifierTest {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build())
.setMaxResults(3)
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -251,7 +275,7 @@ public class ImageClassifierTest {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
.setMaxResults(1)
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -322,14 +346,14 @@ public class ImageClassifierTest {
MediaPipeException.class,
() ->
imageClassifier.classifyForVideo(
getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
exception =
assertThrows(
MediaPipeException.class,
() ->
imageClassifier.classifyAsync(
getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@ -353,7 +377,7 @@ public class ImageClassifierTest {
MediaPipeException.class,
() ->
imageClassifier.classifyAsync(
getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@ -379,7 +403,7 @@ public class ImageClassifierTest {
MediaPipeException.class,
() ->
imageClassifier.classifyForVideo(
getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
}
@ -388,7 +412,7 @@ public class ImageClassifierTest {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
.setMaxResults(1)
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -405,13 +429,14 @@ public class ImageClassifierTest {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
.setMaxResults(1)
.setRunningMode(RunningMode.VIDEO)
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
for (int i = 0; i < 3; i++) {
ImageClassifierResult results = imageClassifier.classifyForVideo(image, /*timestampMs=*/ i);
ImageClassifierResult results =
imageClassifier.classifyForVideo(image, /* timestampMs= */ i);
assertHasOneHead(results);
assertCategoriesAre(
results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
@ -424,7 +449,7 @@ public class ImageClassifierTest {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
.setMaxResults(1)
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener(
(imageClassificationResult, inputImage) -> {
@ -436,11 +461,11 @@ public class ImageClassifierTest {
.build();
try (ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1);
imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 1);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> imageClassifier.classifyAsync(image, /*timestampMs=*/ 0));
() -> imageClassifier.classifyAsync(image, /* timestampMs= */ 0));
assertThat(exception)
.hasMessageThat()
.contains("having a smaller timestamp than the processed timestamp");
@ -453,7 +478,7 @@ public class ImageClassifierTest {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
.setMaxResults(1)
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener(
(imageClassificationResult, inputImage) -> {
@ -466,7 +491,7 @@ public class ImageClassifierTest {
try (ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
for (int i = 0; i < 3; ++i) {
imageClassifier.classifyAsync(image, /*timestampMs=*/ i);
imageClassifier.classifyAsync(image, /* timestampMs= */ i);
}
}
}