Inline formerly nested 'ClassifierOptions' in Java classifier APIs.
PiperOrigin-RevId: 492173060
This commit is contained in:
parent
460aee7933
commit
29c7702984
|
@ -66,10 +66,10 @@ android_library(
|
|||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
|
|
|
@ -27,7 +27,7 @@ import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi;
|
|||
import com.google.mediapipe.tasks.audio.core.RunningMode;
|
||||
import com.google.mediapipe.tasks.components.containers.AudioData;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
||||
import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||
|
@ -266,7 +266,7 @@ public final class AudioClassifier extends BaseAudioTaskApi {
|
|||
|
||||
/*
|
||||
* Sends audio data (a block in a continuous audio stream) to perform audio classification, and
|
||||
* the results will be available via the {@link ResultListener} provided in the
|
||||
* the results will be available via the {@link ResultListener} provided in the
|
||||
* {@link AudioClassifierOptions}. Only use this method when the AudioClassifier is created with
|
||||
* the audio stream mode.
|
||||
*
|
||||
|
@ -320,10 +320,42 @@ public final class AudioClassifier extends BaseAudioTaskApi {
|
|||
public abstract Builder setRunningMode(RunningMode runningMode);
|
||||
|
||||
/**
|
||||
* Sets the optional {@link ClassifierOptions} controling classification behavior, such as
|
||||
* score threshold, number of results, etc.
|
||||
* Sets the optional locale to use for display names specified through the TFLite Model
|
||||
* Metadata, if any.
|
||||
*/
|
||||
public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
|
||||
public abstract Builder setDisplayNamesLocale(String locale);
|
||||
|
||||
/**
|
||||
* Sets the optional maximum number of top-scored classification results to return.
|
||||
*
|
||||
* <p>If not set, all available results are returned. If set, must be > 0.
|
||||
*/
|
||||
public abstract Builder setMaxResults(Integer maxResults);
|
||||
|
||||
/**
|
||||
* Sets the optional score threshold. Results with score below this value are rejected.
|
||||
*
|
||||
* <p>Overrides the score threshold specified in the TFLite Model Metadata, if any.
|
||||
*/
|
||||
public abstract Builder setScoreThreshold(Float scoreThreshold);
|
||||
|
||||
/**
|
||||
* Sets the optional allowlist of category names.
|
||||
*
|
||||
* <p>If non-empty, detection results whose category name is not in this set will be filtered
|
||||
* out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
|
||||
* categoryDenylist}.
|
||||
*/
|
||||
public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist);
|
||||
|
||||
/**
|
||||
* Sets the optional denylist of category names.
|
||||
*
|
||||
* <p>If non-empty, detection results whose category name is in this set will be filtered out.
|
||||
* Duplicate or unknown category names are ignored. Mutually exclusive with {@code
|
||||
* categoryAllowlist}.
|
||||
*/
|
||||
public abstract Builder setCategoryDenylist(List<String> categoryDenylist);
|
||||
|
||||
/**
|
||||
* Sets the {@link ResultListener} to receive the classification results asynchronously when
|
||||
|
@ -340,9 +372,7 @@ public final class AudioClassifier extends BaseAudioTaskApi {
|
|||
/**
|
||||
* Validates and builds the {@link AudioClassifierOptions} instance.
|
||||
*
|
||||
* @throws IllegalArgumentException if the result listener and the running mode are not
|
||||
* properly configured. The result listener should only be set when the audio classifier
|
||||
* is in the audio stream mode.
|
||||
* @throws IllegalArgumentException if any of the set options are invalid.
|
||||
*/
|
||||
public final AudioClassifierOptions build() {
|
||||
AudioClassifierOptions options = autoBuild();
|
||||
|
@ -357,6 +387,13 @@ public final class AudioClassifier extends BaseAudioTaskApi {
|
|||
"The audio classifier is in the audio clips mode, a user-defined result listener"
|
||||
+ " shouldn't be provided in AudioClassifierOptions.");
|
||||
}
|
||||
if (options.maxResults().isPresent() && options.maxResults().get() <= 0) {
|
||||
throw new IllegalArgumentException("If specified, maxResults must be > 0.");
|
||||
}
|
||||
if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) {
|
||||
throw new IllegalArgumentException(
|
||||
"Category allowlist and denylist are mutually exclusive.");
|
||||
}
|
||||
return options;
|
||||
}
|
||||
}
|
||||
|
@ -365,7 +402,15 @@ public final class AudioClassifier extends BaseAudioTaskApi {
|
|||
|
||||
abstract RunningMode runningMode();
|
||||
|
||||
abstract Optional<ClassifierOptions> classifierOptions();
|
||||
abstract Optional<String> displayNamesLocale();
|
||||
|
||||
abstract Optional<Integer> maxResults();
|
||||
|
||||
abstract Optional<Float> scoreThreshold();
|
||||
|
||||
abstract List<String> categoryAllowlist();
|
||||
|
||||
abstract List<String> categoryDenylist();
|
||||
|
||||
abstract Optional<PureResultListener<AudioClassifierResult>> resultListener();
|
||||
|
||||
|
@ -373,7 +418,9 @@ public final class AudioClassifier extends BaseAudioTaskApi {
|
|||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_AudioClassifier_AudioClassifierOptions.Builder()
|
||||
.setRunningMode(RunningMode.AUDIO_CLIPS);
|
||||
.setRunningMode(RunningMode.AUDIO_CLIPS)
|
||||
.setCategoryAllowlist(Collections.emptyList())
|
||||
.setCategoryDenylist(Collections.emptyList());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -385,12 +432,21 @@ public final class AudioClassifier extends BaseAudioTaskApi {
|
|||
BaseOptionsProto.BaseOptions.newBuilder();
|
||||
baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM);
|
||||
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
|
||||
ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder =
|
||||
ClassifierOptionsProto.ClassifierOptions.newBuilder();
|
||||
displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale);
|
||||
maxResults().ifPresent(classifierOptionsBuilder::setMaxResults);
|
||||
scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold);
|
||||
if (!categoryAllowlist().isEmpty()) {
|
||||
classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist());
|
||||
}
|
||||
if (!categoryDenylist().isEmpty()) {
|
||||
classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist());
|
||||
}
|
||||
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.Builder taskOptionsBuilder =
|
||||
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.newBuilder()
|
||||
.setBaseOptions(baseOptionsBuilder);
|
||||
if (classifierOptions().isPresent()) {
|
||||
taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
|
||||
}
|
||||
.setBaseOptions(baseOptionsBuilder)
|
||||
.setClassifierOptions(classifierOptionsBuilder);
|
||||
return CalculatorOptions.newBuilder()
|
||||
.setExtension(
|
||||
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.ext,
|
||||
|
|
|
@ -49,10 +49,10 @@ android_library(
|
|||
"//mediapipe/framework:calculator_options_java_proto_lite",
|
||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",
|
||||
"//third_party:autovalue",
|
||||
|
|
|
@ -24,7 +24,7 @@ import com.google.mediapipe.framework.PacketGetter;
|
|||
import com.google.mediapipe.framework.ProtoUtil;
|
||||
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
||||
import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||
import com.google.mediapipe.tasks.core.TaskInfo;
|
||||
|
@ -216,20 +216,79 @@ public final class TextClassifier implements AutoCloseable {
|
|||
public abstract Builder setBaseOptions(BaseOptions value);
|
||||
|
||||
/**
|
||||
* Sets the optional {@link ClassifierOptions} controling classification behavior, such as
|
||||
* score threshold, number of results, etc.
|
||||
* Sets the optional locale to use for display names specified through the TFLite Model
|
||||
* Metadata, if any.
|
||||
*/
|
||||
public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
|
||||
public abstract Builder setDisplayNamesLocale(String locale);
|
||||
|
||||
public abstract TextClassifierOptions build();
|
||||
/**
|
||||
* Sets the optional maximum number of top-scored classification results to return.
|
||||
*
|
||||
* <p>If not set, all available results are returned. If set, must be > 0.
|
||||
*/
|
||||
public abstract Builder setMaxResults(Integer maxResults);
|
||||
|
||||
/**
|
||||
* Sets the optional score threshold. Results with score below this value are rejected.
|
||||
*
|
||||
* <p>Overrides the score threshold specified in the TFLite Model Metadata, if any.
|
||||
*/
|
||||
public abstract Builder setScoreThreshold(Float scoreThreshold);
|
||||
|
||||
/**
|
||||
* Sets the optional allowlist of category names.
|
||||
*
|
||||
* <p>If non-empty, detection results whose category name is not in this set will be filtered
|
||||
* out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
|
||||
* categoryDenylist}.
|
||||
*/
|
||||
public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist);
|
||||
|
||||
/**
|
||||
* Sets the optional denylist of category names.
|
||||
*
|
||||
* <p>If non-empty, detection results whose category name is in this set will be filtered out.
|
||||
* Duplicate or unknown category names are ignored. Mutually exclusive with {@code
|
||||
* categoryAllowlist}.
|
||||
*/
|
||||
public abstract Builder setCategoryDenylist(List<String> categoryDenylist);
|
||||
|
||||
abstract TextClassifierOptions autoBuild();
|
||||
|
||||
/**
|
||||
* Validates and builds the {@link TextClassifierOptions} instance.
|
||||
*
|
||||
* @throws IllegalArgumentException if any of the set options are invalid.
|
||||
*/
|
||||
public final TextClassifierOptions build() {
|
||||
TextClassifierOptions options = autoBuild();
|
||||
if (options.maxResults().isPresent() && options.maxResults().get() <= 0) {
|
||||
throw new IllegalArgumentException("If specified, maxResults must be > 0.");
|
||||
}
|
||||
if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) {
|
||||
throw new IllegalArgumentException(
|
||||
"Category allowlist and denylist are mutually exclusive.");
|
||||
}
|
||||
return options;
|
||||
}
|
||||
}
|
||||
|
||||
abstract BaseOptions baseOptions();
|
||||
|
||||
abstract Optional<ClassifierOptions> classifierOptions();
|
||||
abstract Optional<String> displayNamesLocale();
|
||||
|
||||
abstract Optional<Integer> maxResults();
|
||||
|
||||
abstract Optional<Float> scoreThreshold();
|
||||
|
||||
abstract List<String> categoryAllowlist();
|
||||
|
||||
abstract List<String> categoryDenylist();
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_TextClassifier_TextClassifierOptions.Builder();
|
||||
return new AutoValue_TextClassifier_TextClassifierOptions.Builder()
|
||||
.setCategoryAllowlist(Collections.emptyList())
|
||||
.setCategoryDenylist(Collections.emptyList());
|
||||
}
|
||||
|
||||
/** Converts a {@link TextClassifierOptions} to a {@link CalculatorOptions} protobuf message. */
|
||||
|
@ -238,12 +297,21 @@ public final class TextClassifier implements AutoCloseable {
|
|||
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
|
||||
BaseOptionsProto.BaseOptions.newBuilder();
|
||||
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
|
||||
ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder =
|
||||
ClassifierOptionsProto.ClassifierOptions.newBuilder();
|
||||
displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale);
|
||||
maxResults().ifPresent(classifierOptionsBuilder::setMaxResults);
|
||||
scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold);
|
||||
if (!categoryAllowlist().isEmpty()) {
|
||||
classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist());
|
||||
}
|
||||
if (!categoryDenylist().isEmpty()) {
|
||||
classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist());
|
||||
}
|
||||
TextClassifierGraphOptionsProto.TextClassifierGraphOptions.Builder taskOptionsBuilder =
|
||||
TextClassifierGraphOptionsProto.TextClassifierGraphOptions.newBuilder()
|
||||
.setBaseOptions(baseOptionsBuilder);
|
||||
if (classifierOptions().isPresent()) {
|
||||
taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
|
||||
}
|
||||
.setBaseOptions(baseOptionsBuilder)
|
||||
.setClassifierOptions(classifierOptionsBuilder);
|
||||
return CalculatorOptions.newBuilder()
|
||||
.setExtension(
|
||||
TextClassifierGraphOptionsProto.TextClassifierGraphOptions.ext,
|
||||
|
|
|
@ -98,10 +98,10 @@ android_library(
|
|||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
|
|
|
@ -27,7 +27,7 @@ import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
|||
import com.google.mediapipe.framework.image.MPImage;
|
||||
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
||||
import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||
|
@ -376,10 +376,42 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
|||
public abstract Builder setRunningMode(RunningMode runningMode);
|
||||
|
||||
/**
|
||||
* Sets the optional {@link ClassifierOptions} controling classification behavior, such as
|
||||
* score threshold, number of results, etc.
|
||||
* Sets the optional locale to use for display names specified through the TFLite Model
|
||||
* Metadata, if any.
|
||||
*/
|
||||
public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
|
||||
public abstract Builder setDisplayNamesLocale(String locale);
|
||||
|
||||
/**
|
||||
* Sets the optional maximum number of top-scored classification results to return.
|
||||
*
|
||||
* <p>If not set, all available results are returned. If set, must be > 0.
|
||||
*/
|
||||
public abstract Builder setMaxResults(Integer maxResults);
|
||||
|
||||
/**
|
||||
* Sets the optional score threshold. Results with score below this value are rejected.
|
||||
*
|
||||
* <p>Overrides the score threshold specified in the TFLite Model Metadata, if any.
|
||||
*/
|
||||
public abstract Builder setScoreThreshold(Float scoreThreshold);
|
||||
|
||||
/**
|
||||
* Sets the optional allowlist of category names.
|
||||
*
|
||||
* <p>If non-empty, detection results whose category name is not in this set will be filtered
|
||||
* out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
|
||||
* categoryDenylist}.
|
||||
*/
|
||||
public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist);
|
||||
|
||||
/**
|
||||
* Sets the optional denylist of category names.
|
||||
*
|
||||
* <p>If non-empty, detection results whose category name is in this set will be filtered out.
|
||||
* Duplicate or unknown category names are ignored. Mutually exclusive with {@code
|
||||
* categoryAllowlist}.
|
||||
*/
|
||||
public abstract Builder setCategoryDenylist(List<String> categoryDenylist);
|
||||
|
||||
/**
|
||||
* Sets the {@link ResultListener} to receive the classification results asynchronously when
|
||||
|
@ -396,9 +428,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
|||
/**
|
||||
* Validates and builds the {@link ImageClassifierOptions} instance. *
|
||||
*
|
||||
* @throws IllegalArgumentException if the result listener and the running mode are not
|
||||
* properly configured. The result listener should only be set when the image classifier
|
||||
* is in the live stream mode.
|
||||
* @throws IllegalArgumentException if any of the set options are invalid.
|
||||
*/
|
||||
public final ImageClassifierOptions build() {
|
||||
ImageClassifierOptions options = autoBuild();
|
||||
|
@ -413,6 +443,13 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
|||
"The image classifier is in the image or video mode, a user-defined result listener"
|
||||
+ " shouldn't be provided in ImageClassifierOptions.");
|
||||
}
|
||||
if (options.maxResults().isPresent() && options.maxResults().get() <= 0) {
|
||||
throw new IllegalArgumentException("If specified, maxResults must be > 0.");
|
||||
}
|
||||
if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) {
|
||||
throw new IllegalArgumentException(
|
||||
"Category allowlist and denylist are mutually exclusive.");
|
||||
}
|
||||
return options;
|
||||
}
|
||||
}
|
||||
|
@ -421,7 +458,15 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
|||
|
||||
abstract RunningMode runningMode();
|
||||
|
||||
abstract Optional<ClassifierOptions> classifierOptions();
|
||||
abstract Optional<String> displayNamesLocale();
|
||||
|
||||
abstract Optional<Integer> maxResults();
|
||||
|
||||
abstract Optional<Float> scoreThreshold();
|
||||
|
||||
abstract List<String> categoryAllowlist();
|
||||
|
||||
abstract List<String> categoryDenylist();
|
||||
|
||||
abstract Optional<ResultListener<ImageClassifierResult, MPImage>> resultListener();
|
||||
|
||||
|
@ -429,7 +474,9 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
|||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_ImageClassifier_ImageClassifierOptions.Builder()
|
||||
.setRunningMode(RunningMode.IMAGE);
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.setCategoryAllowlist(Collections.emptyList())
|
||||
.setCategoryDenylist(Collections.emptyList());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -441,12 +488,21 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
|||
BaseOptionsProto.BaseOptions.newBuilder();
|
||||
baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE);
|
||||
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
|
||||
ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder =
|
||||
ClassifierOptionsProto.ClassifierOptions.newBuilder();
|
||||
displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale);
|
||||
maxResults().ifPresent(classifierOptionsBuilder::setMaxResults);
|
||||
scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold);
|
||||
if (!categoryAllowlist().isEmpty()) {
|
||||
classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist());
|
||||
}
|
||||
if (!categoryDenylist().isEmpty()) {
|
||||
classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist());
|
||||
}
|
||||
ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.Builder taskOptionsBuilder =
|
||||
ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.newBuilder()
|
||||
.setBaseOptions(baseOptionsBuilder);
|
||||
if (classifierOptions().isPresent()) {
|
||||
taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
|
||||
}
|
||||
.setBaseOptions(baseOptionsBuilder)
|
||||
.setClassifierOptions(classifierOptionsBuilder);
|
||||
return CalculatorOptions.newBuilder()
|
||||
.setExtension(
|
||||
ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.ext,
|
||||
|
|
|
@ -40,6 +40,37 @@ public class TextClassifierTest {
|
|||
private static final String NEGATIVE_TEXT = "unflinchingly bleak and desperate";
|
||||
private static final String POSITIVE_TEXT = "it's a charming and often affecting journey";
|
||||
|
||||
@Test
|
||||
public void options_failsWithNegativeMaxResults() throws Exception {
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
() ->
|
||||
TextClassifierOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder().setModelAssetPath(BERT_MODEL_FILE).build())
|
||||
.setMaxResults(-1)
|
||||
.build());
|
||||
assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void options_failsWithBothAllowlistAndDenylist() throws Exception {
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
() ->
|
||||
TextClassifierOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder().setModelAssetPath(BERT_MODEL_FILE).build())
|
||||
.setCategoryAllowlist(Arrays.asList("foo"))
|
||||
.setCategoryDenylist(Arrays.asList("bar"))
|
||||
.build());
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("Category allowlist and denylist are mutually exclusive");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void create_failsWithMissingModel() throws Exception {
|
||||
String nonExistentFile = "/path/to/non/existent/file";
|
||||
|
|
|
@ -26,7 +26,6 @@ import com.google.mediapipe.framework.MediaPipeException;
|
|||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||
import com.google.mediapipe.framework.image.MPImage;
|
||||
import com.google.mediapipe.tasks.components.containers.Category;
|
||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.TestUtils;
|
||||
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
|
||||
|
@ -55,6 +54,37 @@ public class ImageClassifierTest {
|
|||
@RunWith(AndroidJUnit4.class)
|
||||
public static final class General extends ImageClassifierTest {
|
||||
|
||||
@Test
|
||||
public void options_failsWithNegativeMaxResults() throws Exception {
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
() ->
|
||||
ImageClassifierOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||
.setMaxResults(-1)
|
||||
.build());
|
||||
assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void options_failsWithBothAllowlistAndDenylist() throws Exception {
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
() ->
|
||||
ImageClassifierOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||
.setCategoryAllowlist(Arrays.asList("foo"))
|
||||
.setCategoryDenylist(Arrays.asList("bar"))
|
||||
.build());
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("Category allowlist and denylist are mutually exclusive");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void create_failsWithMissingModel() throws Exception {
|
||||
String nonExistentFile = "/path/to/non/existent/file";
|
||||
|
@ -105,7 +135,7 @@ public class ImageClassifierTest {
|
|||
ImageClassifierOptions options =
|
||||
ImageClassifierOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build())
|
||||
.setMaxResults(3)
|
||||
.build();
|
||||
ImageClassifier imageClassifier =
|
||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
|
@ -125,7 +155,7 @@ public class ImageClassifierTest {
|
|||
ImageClassifierOptions options =
|
||||
ImageClassifierOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(QUANTIZED_MODEL_FILE).build())
|
||||
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
|
||||
.setMaxResults(1)
|
||||
.build();
|
||||
ImageClassifier imageClassifier =
|
||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
|
@ -141,7 +171,7 @@ public class ImageClassifierTest {
|
|||
ImageClassifierOptions options =
|
||||
ImageClassifierOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||
.setClassifierOptions(ClassifierOptions.builder().setScoreThreshold(0.02f).build())
|
||||
.setScoreThreshold(0.02f)
|
||||
.build();
|
||||
ImageClassifier imageClassifier =
|
||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
|
@ -160,10 +190,7 @@ public class ImageClassifierTest {
|
|||
ImageClassifierOptions options =
|
||||
ImageClassifierOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||
.setClassifierOptions(
|
||||
ClassifierOptions.builder()
|
||||
.setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf"))
|
||||
.build())
|
||||
.setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf"))
|
||||
.build();
|
||||
ImageClassifier imageClassifier =
|
||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
|
@ -183,11 +210,8 @@ public class ImageClassifierTest {
|
|||
ImageClassifierOptions options =
|
||||
ImageClassifierOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||
.setClassifierOptions(
|
||||
ClassifierOptions.builder()
|
||||
.setMaxResults(3)
|
||||
.setCategoryDenylist(Arrays.asList("bagel"))
|
||||
.build())
|
||||
.setMaxResults(3)
|
||||
.setCategoryDenylist(Arrays.asList("bagel"))
|
||||
.build();
|
||||
ImageClassifier imageClassifier =
|
||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
|
@ -207,7 +231,7 @@ public class ImageClassifierTest {
|
|||
ImageClassifierOptions options =
|
||||
ImageClassifierOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
|
||||
.setMaxResults(1)
|
||||
.build();
|
||||
ImageClassifier imageClassifier =
|
||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
|
@ -228,7 +252,7 @@ public class ImageClassifierTest {
|
|||
ImageClassifierOptions options =
|
||||
ImageClassifierOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build())
|
||||
.setMaxResults(3)
|
||||
.build();
|
||||
ImageClassifier imageClassifier =
|
||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
|
@ -251,7 +275,7 @@ public class ImageClassifierTest {
|
|||
ImageClassifierOptions options =
|
||||
ImageClassifierOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
|
||||
.setMaxResults(1)
|
||||
.build();
|
||||
ImageClassifier imageClassifier =
|
||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
|
@ -322,14 +346,14 @@ public class ImageClassifierTest {
|
|||
MediaPipeException.class,
|
||||
() ->
|
||||
imageClassifier.classifyForVideo(
|
||||
getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
|
||||
getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
imageClassifier.classifyAsync(
|
||||
getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
|
||||
getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||
}
|
||||
|
||||
|
@ -353,7 +377,7 @@ public class ImageClassifierTest {
|
|||
MediaPipeException.class,
|
||||
() ->
|
||||
imageClassifier.classifyAsync(
|
||||
getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
|
||||
getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||
}
|
||||
|
||||
|
@ -379,7 +403,7 @@ public class ImageClassifierTest {
|
|||
MediaPipeException.class,
|
||||
() ->
|
||||
imageClassifier.classifyForVideo(
|
||||
getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
|
||||
getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||
}
|
||||
|
||||
|
@ -388,7 +412,7 @@ public class ImageClassifierTest {
|
|||
ImageClassifierOptions options =
|
||||
ImageClassifierOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
|
||||
.setMaxResults(1)
|
||||
.build();
|
||||
ImageClassifier imageClassifier =
|
||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
|
@ -405,13 +429,14 @@ public class ImageClassifierTest {
|
|||
ImageClassifierOptions options =
|
||||
ImageClassifierOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
|
||||
.setMaxResults(1)
|
||||
.setRunningMode(RunningMode.VIDEO)
|
||||
.build();
|
||||
ImageClassifier imageClassifier =
|
||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
for (int i = 0; i < 3; i++) {
|
||||
ImageClassifierResult results = imageClassifier.classifyForVideo(image, /*timestampMs=*/ i);
|
||||
ImageClassifierResult results =
|
||||
imageClassifier.classifyForVideo(image, /* timestampMs= */ i);
|
||||
assertHasOneHead(results);
|
||||
assertCategoriesAre(
|
||||
results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
|
||||
|
@ -424,7 +449,7 @@ public class ImageClassifierTest {
|
|||
ImageClassifierOptions options =
|
||||
ImageClassifierOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
|
||||
.setMaxResults(1)
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener(
|
||||
(imageClassificationResult, inputImage) -> {
|
||||
|
@ -436,11 +461,11 @@ public class ImageClassifierTest {
|
|||
.build();
|
||||
try (ImageClassifier imageClassifier =
|
||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||
imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1);
|
||||
imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 1);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() -> imageClassifier.classifyAsync(image, /*timestampMs=*/ 0));
|
||||
() -> imageClassifier.classifyAsync(image, /* timestampMs= */ 0));
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("having a smaller timestamp than the processed timestamp");
|
||||
|
@ -453,7 +478,7 @@ public class ImageClassifierTest {
|
|||
ImageClassifierOptions options =
|
||||
ImageClassifierOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
|
||||
.setMaxResults(1)
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener(
|
||||
(imageClassificationResult, inputImage) -> {
|
||||
|
@ -466,7 +491,7 @@ public class ImageClassifierTest {
|
|||
try (ImageClassifier imageClassifier =
|
||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
imageClassifier.classifyAsync(image, /*timestampMs=*/ i);
|
||||
imageClassifier.classifyAsync(image, /* timestampMs= */ i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user