Inline formerly nested 'ClassifierOptions' in Java classifier APIs.
PiperOrigin-RevId: 492173060
This commit is contained in:
parent
460aee7933
commit
29c7702984
|
@ -66,10 +66,10 @@ android_library(
|
||||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||||
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata",
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
|
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||||
"//third_party:autovalue",
|
"//third_party:autovalue",
|
||||||
"@maven//:com_google_guava_guava",
|
"@maven//:com_google_guava_guava",
|
||||||
|
|
|
@ -27,7 +27,7 @@ import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi;
|
||||||
import com.google.mediapipe.tasks.audio.core.RunningMode;
|
import com.google.mediapipe.tasks.audio.core.RunningMode;
|
||||||
import com.google.mediapipe.tasks.components.containers.AudioData;
|
import com.google.mediapipe.tasks.components.containers.AudioData;
|
||||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
|
||||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
import com.google.mediapipe.tasks.core.ErrorListener;
|
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||||
|
@ -320,10 +320,42 @@ public final class AudioClassifier extends BaseAudioTaskApi {
|
||||||
public abstract Builder setRunningMode(RunningMode runningMode);
|
public abstract Builder setRunningMode(RunningMode runningMode);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sets the optional {@link ClassifierOptions} controling classification behavior, such as
|
* Sets the optional locale to use for display names specified through the TFLite Model
|
||||||
* score threshold, number of results, etc.
|
* Metadata, if any.
|
||||||
*/
|
*/
|
||||||
public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
|
public abstract Builder setDisplayNamesLocale(String locale);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional maximum number of top-scored classification results to return.
|
||||||
|
*
|
||||||
|
* <p>If not set, all available results are returned. If set, must be > 0.
|
||||||
|
*/
|
||||||
|
public abstract Builder setMaxResults(Integer maxResults);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional score threshold. Results with score below this value are rejected.
|
||||||
|
*
|
||||||
|
* <p>Overrides the score threshold specified in the TFLite Model Metadata, if any.
|
||||||
|
*/
|
||||||
|
public abstract Builder setScoreThreshold(Float scoreThreshold);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional allowlist of category names.
|
||||||
|
*
|
||||||
|
* <p>If non-empty, detection results whose category name is not in this set will be filtered
|
||||||
|
* out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
|
||||||
|
* categoryDenylist}.
|
||||||
|
*/
|
||||||
|
public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional denylist of category names.
|
||||||
|
*
|
||||||
|
* <p>If non-empty, detection results whose category name is in this set will be filtered out.
|
||||||
|
* Duplicate or unknown category names are ignored. Mutually exclusive with {@code
|
||||||
|
* categoryAllowlist}.
|
||||||
|
*/
|
||||||
|
public abstract Builder setCategoryDenylist(List<String> categoryDenylist);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sets the {@link ResultListener} to receive the classification results asynchronously when
|
* Sets the {@link ResultListener} to receive the classification results asynchronously when
|
||||||
|
@ -340,9 +372,7 @@ public final class AudioClassifier extends BaseAudioTaskApi {
|
||||||
/**
|
/**
|
||||||
* Validates and builds the {@link AudioClassifierOptions} instance.
|
* Validates and builds the {@link AudioClassifierOptions} instance.
|
||||||
*
|
*
|
||||||
* @throws IllegalArgumentException if the result listener and the running mode are not
|
* @throws IllegalArgumentException if any of the set options are invalid.
|
||||||
* properly configured. The result listener should only be set when the audio classifier
|
|
||||||
* is in the audio stream mode.
|
|
||||||
*/
|
*/
|
||||||
public final AudioClassifierOptions build() {
|
public final AudioClassifierOptions build() {
|
||||||
AudioClassifierOptions options = autoBuild();
|
AudioClassifierOptions options = autoBuild();
|
||||||
|
@ -357,6 +387,13 @@ public final class AudioClassifier extends BaseAudioTaskApi {
|
||||||
"The audio classifier is in the audio clips mode, a user-defined result listener"
|
"The audio classifier is in the audio clips mode, a user-defined result listener"
|
||||||
+ " shouldn't be provided in AudioClassifierOptions.");
|
+ " shouldn't be provided in AudioClassifierOptions.");
|
||||||
}
|
}
|
||||||
|
if (options.maxResults().isPresent() && options.maxResults().get() <= 0) {
|
||||||
|
throw new IllegalArgumentException("If specified, maxResults must be > 0.");
|
||||||
|
}
|
||||||
|
if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Category allowlist and denylist are mutually exclusive.");
|
||||||
|
}
|
||||||
return options;
|
return options;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -365,7 +402,15 @@ public final class AudioClassifier extends BaseAudioTaskApi {
|
||||||
|
|
||||||
abstract RunningMode runningMode();
|
abstract RunningMode runningMode();
|
||||||
|
|
||||||
abstract Optional<ClassifierOptions> classifierOptions();
|
abstract Optional<String> displayNamesLocale();
|
||||||
|
|
||||||
|
abstract Optional<Integer> maxResults();
|
||||||
|
|
||||||
|
abstract Optional<Float> scoreThreshold();
|
||||||
|
|
||||||
|
abstract List<String> categoryAllowlist();
|
||||||
|
|
||||||
|
abstract List<String> categoryDenylist();
|
||||||
|
|
||||||
abstract Optional<PureResultListener<AudioClassifierResult>> resultListener();
|
abstract Optional<PureResultListener<AudioClassifierResult>> resultListener();
|
||||||
|
|
||||||
|
@ -373,7 +418,9 @@ public final class AudioClassifier extends BaseAudioTaskApi {
|
||||||
|
|
||||||
public static Builder builder() {
|
public static Builder builder() {
|
||||||
return new AutoValue_AudioClassifier_AudioClassifierOptions.Builder()
|
return new AutoValue_AudioClassifier_AudioClassifierOptions.Builder()
|
||||||
.setRunningMode(RunningMode.AUDIO_CLIPS);
|
.setRunningMode(RunningMode.AUDIO_CLIPS)
|
||||||
|
.setCategoryAllowlist(Collections.emptyList())
|
||||||
|
.setCategoryDenylist(Collections.emptyList());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -385,12 +432,21 @@ public final class AudioClassifier extends BaseAudioTaskApi {
|
||||||
BaseOptionsProto.BaseOptions.newBuilder();
|
BaseOptionsProto.BaseOptions.newBuilder();
|
||||||
baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM);
|
baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM);
|
||||||
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
|
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
|
||||||
|
ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder =
|
||||||
|
ClassifierOptionsProto.ClassifierOptions.newBuilder();
|
||||||
|
displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale);
|
||||||
|
maxResults().ifPresent(classifierOptionsBuilder::setMaxResults);
|
||||||
|
scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold);
|
||||||
|
if (!categoryAllowlist().isEmpty()) {
|
||||||
|
classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist());
|
||||||
|
}
|
||||||
|
if (!categoryDenylist().isEmpty()) {
|
||||||
|
classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist());
|
||||||
|
}
|
||||||
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.Builder taskOptionsBuilder =
|
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.Builder taskOptionsBuilder =
|
||||||
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.newBuilder()
|
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.newBuilder()
|
||||||
.setBaseOptions(baseOptionsBuilder);
|
.setBaseOptions(baseOptionsBuilder)
|
||||||
if (classifierOptions().isPresent()) {
|
.setClassifierOptions(classifierOptionsBuilder);
|
||||||
taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
|
|
||||||
}
|
|
||||||
return CalculatorOptions.newBuilder()
|
return CalculatorOptions.newBuilder()
|
||||||
.setExtension(
|
.setExtension(
|
||||||
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.ext,
|
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.ext,
|
||||||
|
|
|
@ -49,10 +49,10 @@ android_library(
|
||||||
"//mediapipe/framework:calculator_options_java_proto_lite",
|
"//mediapipe/framework:calculator_options_java_proto_lite",
|
||||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
|
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",
|
||||||
"//third_party:autovalue",
|
"//third_party:autovalue",
|
||||||
|
|
|
@ -24,7 +24,7 @@ import com.google.mediapipe.framework.PacketGetter;
|
||||||
import com.google.mediapipe.framework.ProtoUtil;
|
import com.google.mediapipe.framework.ProtoUtil;
|
||||||
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
|
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
|
||||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
|
||||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||||
import com.google.mediapipe.tasks.core.TaskInfo;
|
import com.google.mediapipe.tasks.core.TaskInfo;
|
||||||
|
@ -216,20 +216,79 @@ public final class TextClassifier implements AutoCloseable {
|
||||||
public abstract Builder setBaseOptions(BaseOptions value);
|
public abstract Builder setBaseOptions(BaseOptions value);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sets the optional {@link ClassifierOptions} controling classification behavior, such as
|
* Sets the optional locale to use for display names specified through the TFLite Model
|
||||||
* score threshold, number of results, etc.
|
* Metadata, if any.
|
||||||
*/
|
*/
|
||||||
public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
|
public abstract Builder setDisplayNamesLocale(String locale);
|
||||||
|
|
||||||
public abstract TextClassifierOptions build();
|
/**
|
||||||
|
* Sets the optional maximum number of top-scored classification results to return.
|
||||||
|
*
|
||||||
|
* <p>If not set, all available results are returned. If set, must be > 0.
|
||||||
|
*/
|
||||||
|
public abstract Builder setMaxResults(Integer maxResults);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional score threshold. Results with score below this value are rejected.
|
||||||
|
*
|
||||||
|
* <p>Overrides the score threshold specified in the TFLite Model Metadata, if any.
|
||||||
|
*/
|
||||||
|
public abstract Builder setScoreThreshold(Float scoreThreshold);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional allowlist of category names.
|
||||||
|
*
|
||||||
|
* <p>If non-empty, detection results whose category name is not in this set will be filtered
|
||||||
|
* out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
|
||||||
|
* categoryDenylist}.
|
||||||
|
*/
|
||||||
|
public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional denylist of category names.
|
||||||
|
*
|
||||||
|
* <p>If non-empty, detection results whose category name is in this set will be filtered out.
|
||||||
|
* Duplicate or unknown category names are ignored. Mutually exclusive with {@code
|
||||||
|
* categoryAllowlist}.
|
||||||
|
*/
|
||||||
|
public abstract Builder setCategoryDenylist(List<String> categoryDenylist);
|
||||||
|
|
||||||
|
abstract TextClassifierOptions autoBuild();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validates and builds the {@link TextClassifierOptions} instance.
|
||||||
|
*
|
||||||
|
* @throws IllegalArgumentException if any of the set options are invalid.
|
||||||
|
*/
|
||||||
|
public final TextClassifierOptions build() {
|
||||||
|
TextClassifierOptions options = autoBuild();
|
||||||
|
if (options.maxResults().isPresent() && options.maxResults().get() <= 0) {
|
||||||
|
throw new IllegalArgumentException("If specified, maxResults must be > 0.");
|
||||||
|
}
|
||||||
|
if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Category allowlist and denylist are mutually exclusive.");
|
||||||
|
}
|
||||||
|
return options;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract BaseOptions baseOptions();
|
abstract BaseOptions baseOptions();
|
||||||
|
|
||||||
abstract Optional<ClassifierOptions> classifierOptions();
|
abstract Optional<String> displayNamesLocale();
|
||||||
|
|
||||||
|
abstract Optional<Integer> maxResults();
|
||||||
|
|
||||||
|
abstract Optional<Float> scoreThreshold();
|
||||||
|
|
||||||
|
abstract List<String> categoryAllowlist();
|
||||||
|
|
||||||
|
abstract List<String> categoryDenylist();
|
||||||
|
|
||||||
public static Builder builder() {
|
public static Builder builder() {
|
||||||
return new AutoValue_TextClassifier_TextClassifierOptions.Builder();
|
return new AutoValue_TextClassifier_TextClassifierOptions.Builder()
|
||||||
|
.setCategoryAllowlist(Collections.emptyList())
|
||||||
|
.setCategoryDenylist(Collections.emptyList());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Converts a {@link TextClassifierOptions} to a {@link CalculatorOptions} protobuf message. */
|
/** Converts a {@link TextClassifierOptions} to a {@link CalculatorOptions} protobuf message. */
|
||||||
|
@ -238,12 +297,21 @@ public final class TextClassifier implements AutoCloseable {
|
||||||
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
|
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
|
||||||
BaseOptionsProto.BaseOptions.newBuilder();
|
BaseOptionsProto.BaseOptions.newBuilder();
|
||||||
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
|
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
|
||||||
|
ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder =
|
||||||
|
ClassifierOptionsProto.ClassifierOptions.newBuilder();
|
||||||
|
displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale);
|
||||||
|
maxResults().ifPresent(classifierOptionsBuilder::setMaxResults);
|
||||||
|
scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold);
|
||||||
|
if (!categoryAllowlist().isEmpty()) {
|
||||||
|
classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist());
|
||||||
|
}
|
||||||
|
if (!categoryDenylist().isEmpty()) {
|
||||||
|
classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist());
|
||||||
|
}
|
||||||
TextClassifierGraphOptionsProto.TextClassifierGraphOptions.Builder taskOptionsBuilder =
|
TextClassifierGraphOptionsProto.TextClassifierGraphOptions.Builder taskOptionsBuilder =
|
||||||
TextClassifierGraphOptionsProto.TextClassifierGraphOptions.newBuilder()
|
TextClassifierGraphOptionsProto.TextClassifierGraphOptions.newBuilder()
|
||||||
.setBaseOptions(baseOptionsBuilder);
|
.setBaseOptions(baseOptionsBuilder)
|
||||||
if (classifierOptions().isPresent()) {
|
.setClassifierOptions(classifierOptionsBuilder);
|
||||||
taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
|
|
||||||
}
|
|
||||||
return CalculatorOptions.newBuilder()
|
return CalculatorOptions.newBuilder()
|
||||||
.setExtension(
|
.setExtension(
|
||||||
TextClassifierGraphOptionsProto.TextClassifierGraphOptions.ext,
|
TextClassifierGraphOptionsProto.TextClassifierGraphOptions.ext,
|
||||||
|
|
|
@ -98,10 +98,10 @@ android_library(
|
||||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
|
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||||
"//third_party:autovalue",
|
"//third_party:autovalue",
|
||||||
"@maven//:com_google_guava_guava",
|
"@maven//:com_google_guava_guava",
|
||||||
|
|
|
@ -27,7 +27,7 @@ import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||||
import com.google.mediapipe.framework.image.MPImage;
|
import com.google.mediapipe.framework.image.MPImage;
|
||||||
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
|
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
|
||||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
|
||||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
import com.google.mediapipe.tasks.core.ErrorListener;
|
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||||
|
@ -376,10 +376,42 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
||||||
public abstract Builder setRunningMode(RunningMode runningMode);
|
public abstract Builder setRunningMode(RunningMode runningMode);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sets the optional {@link ClassifierOptions} controling classification behavior, such as
|
* Sets the optional locale to use for display names specified through the TFLite Model
|
||||||
* score threshold, number of results, etc.
|
* Metadata, if any.
|
||||||
*/
|
*/
|
||||||
public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
|
public abstract Builder setDisplayNamesLocale(String locale);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional maximum number of top-scored classification results to return.
|
||||||
|
*
|
||||||
|
* <p>If not set, all available results are returned. If set, must be > 0.
|
||||||
|
*/
|
||||||
|
public abstract Builder setMaxResults(Integer maxResults);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional score threshold. Results with score below this value are rejected.
|
||||||
|
*
|
||||||
|
* <p>Overrides the score threshold specified in the TFLite Model Metadata, if any.
|
||||||
|
*/
|
||||||
|
public abstract Builder setScoreThreshold(Float scoreThreshold);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional allowlist of category names.
|
||||||
|
*
|
||||||
|
* <p>If non-empty, detection results whose category name is not in this set will be filtered
|
||||||
|
* out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
|
||||||
|
* categoryDenylist}.
|
||||||
|
*/
|
||||||
|
public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional denylist of category names.
|
||||||
|
*
|
||||||
|
* <p>If non-empty, detection results whose category name is in this set will be filtered out.
|
||||||
|
* Duplicate or unknown category names are ignored. Mutually exclusive with {@code
|
||||||
|
* categoryAllowlist}.
|
||||||
|
*/
|
||||||
|
public abstract Builder setCategoryDenylist(List<String> categoryDenylist);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sets the {@link ResultListener} to receive the classification results asynchronously when
|
* Sets the {@link ResultListener} to receive the classification results asynchronously when
|
||||||
|
@ -396,9 +428,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
||||||
/**
|
/**
|
||||||
* Validates and builds the {@link ImageClassifierOptions} instance. *
|
* Validates and builds the {@link ImageClassifierOptions} instance. *
|
||||||
*
|
*
|
||||||
* @throws IllegalArgumentException if the result listener and the running mode are not
|
* @throws IllegalArgumentException if any of the set options are invalid.
|
||||||
* properly configured. The result listener should only be set when the image classifier
|
|
||||||
* is in the live stream mode.
|
|
||||||
*/
|
*/
|
||||||
public final ImageClassifierOptions build() {
|
public final ImageClassifierOptions build() {
|
||||||
ImageClassifierOptions options = autoBuild();
|
ImageClassifierOptions options = autoBuild();
|
||||||
|
@ -413,6 +443,13 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
||||||
"The image classifier is in the image or video mode, a user-defined result listener"
|
"The image classifier is in the image or video mode, a user-defined result listener"
|
||||||
+ " shouldn't be provided in ImageClassifierOptions.");
|
+ " shouldn't be provided in ImageClassifierOptions.");
|
||||||
}
|
}
|
||||||
|
if (options.maxResults().isPresent() && options.maxResults().get() <= 0) {
|
||||||
|
throw new IllegalArgumentException("If specified, maxResults must be > 0.");
|
||||||
|
}
|
||||||
|
if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Category allowlist and denylist are mutually exclusive.");
|
||||||
|
}
|
||||||
return options;
|
return options;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -421,7 +458,15 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
||||||
|
|
||||||
abstract RunningMode runningMode();
|
abstract RunningMode runningMode();
|
||||||
|
|
||||||
abstract Optional<ClassifierOptions> classifierOptions();
|
abstract Optional<String> displayNamesLocale();
|
||||||
|
|
||||||
|
abstract Optional<Integer> maxResults();
|
||||||
|
|
||||||
|
abstract Optional<Float> scoreThreshold();
|
||||||
|
|
||||||
|
abstract List<String> categoryAllowlist();
|
||||||
|
|
||||||
|
abstract List<String> categoryDenylist();
|
||||||
|
|
||||||
abstract Optional<ResultListener<ImageClassifierResult, MPImage>> resultListener();
|
abstract Optional<ResultListener<ImageClassifierResult, MPImage>> resultListener();
|
||||||
|
|
||||||
|
@ -429,7 +474,9 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
||||||
|
|
||||||
public static Builder builder() {
|
public static Builder builder() {
|
||||||
return new AutoValue_ImageClassifier_ImageClassifierOptions.Builder()
|
return new AutoValue_ImageClassifier_ImageClassifierOptions.Builder()
|
||||||
.setRunningMode(RunningMode.IMAGE);
|
.setRunningMode(RunningMode.IMAGE)
|
||||||
|
.setCategoryAllowlist(Collections.emptyList())
|
||||||
|
.setCategoryDenylist(Collections.emptyList());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -441,12 +488,21 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
||||||
BaseOptionsProto.BaseOptions.newBuilder();
|
BaseOptionsProto.BaseOptions.newBuilder();
|
||||||
baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE);
|
baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE);
|
||||||
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
|
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
|
||||||
|
ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder =
|
||||||
|
ClassifierOptionsProto.ClassifierOptions.newBuilder();
|
||||||
|
displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale);
|
||||||
|
maxResults().ifPresent(classifierOptionsBuilder::setMaxResults);
|
||||||
|
scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold);
|
||||||
|
if (!categoryAllowlist().isEmpty()) {
|
||||||
|
classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist());
|
||||||
|
}
|
||||||
|
if (!categoryDenylist().isEmpty()) {
|
||||||
|
classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist());
|
||||||
|
}
|
||||||
ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.Builder taskOptionsBuilder =
|
ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.Builder taskOptionsBuilder =
|
||||||
ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.newBuilder()
|
ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.newBuilder()
|
||||||
.setBaseOptions(baseOptionsBuilder);
|
.setBaseOptions(baseOptionsBuilder)
|
||||||
if (classifierOptions().isPresent()) {
|
.setClassifierOptions(classifierOptionsBuilder);
|
||||||
taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
|
|
||||||
}
|
|
||||||
return CalculatorOptions.newBuilder()
|
return CalculatorOptions.newBuilder()
|
||||||
.setExtension(
|
.setExtension(
|
||||||
ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.ext,
|
ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.ext,
|
||||||
|
|
|
@ -40,6 +40,37 @@ public class TextClassifierTest {
|
||||||
private static final String NEGATIVE_TEXT = "unflinchingly bleak and desperate";
|
private static final String NEGATIVE_TEXT = "unflinchingly bleak and desperate";
|
||||||
private static final String POSITIVE_TEXT = "it's a charming and often affecting journey";
|
private static final String POSITIVE_TEXT = "it's a charming and often affecting journey";
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void options_failsWithNegativeMaxResults() throws Exception {
|
||||||
|
IllegalArgumentException exception =
|
||||||
|
assertThrows(
|
||||||
|
IllegalArgumentException.class,
|
||||||
|
() ->
|
||||||
|
TextClassifierOptions.builder()
|
||||||
|
.setBaseOptions(
|
||||||
|
BaseOptions.builder().setModelAssetPath(BERT_MODEL_FILE).build())
|
||||||
|
.setMaxResults(-1)
|
||||||
|
.build());
|
||||||
|
assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void options_failsWithBothAllowlistAndDenylist() throws Exception {
|
||||||
|
IllegalArgumentException exception =
|
||||||
|
assertThrows(
|
||||||
|
IllegalArgumentException.class,
|
||||||
|
() ->
|
||||||
|
TextClassifierOptions.builder()
|
||||||
|
.setBaseOptions(
|
||||||
|
BaseOptions.builder().setModelAssetPath(BERT_MODEL_FILE).build())
|
||||||
|
.setCategoryAllowlist(Arrays.asList("foo"))
|
||||||
|
.setCategoryDenylist(Arrays.asList("bar"))
|
||||||
|
.build());
|
||||||
|
assertThat(exception)
|
||||||
|
.hasMessageThat()
|
||||||
|
.contains("Category allowlist and denylist are mutually exclusive");
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void create_failsWithMissingModel() throws Exception {
|
public void create_failsWithMissingModel() throws Exception {
|
||||||
String nonExistentFile = "/path/to/non/existent/file";
|
String nonExistentFile = "/path/to/non/existent/file";
|
||||||
|
|
|
@ -26,7 +26,6 @@ import com.google.mediapipe.framework.MediaPipeException;
|
||||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||||
import com.google.mediapipe.framework.image.MPImage;
|
import com.google.mediapipe.framework.image.MPImage;
|
||||||
import com.google.mediapipe.tasks.components.containers.Category;
|
import com.google.mediapipe.tasks.components.containers.Category;
|
||||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
|
||||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
import com.google.mediapipe.tasks.core.TestUtils;
|
import com.google.mediapipe.tasks.core.TestUtils;
|
||||||
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
|
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
|
||||||
|
@ -55,6 +54,37 @@ public class ImageClassifierTest {
|
||||||
@RunWith(AndroidJUnit4.class)
|
@RunWith(AndroidJUnit4.class)
|
||||||
public static final class General extends ImageClassifierTest {
|
public static final class General extends ImageClassifierTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void options_failsWithNegativeMaxResults() throws Exception {
|
||||||
|
IllegalArgumentException exception =
|
||||||
|
assertThrows(
|
||||||
|
IllegalArgumentException.class,
|
||||||
|
() ->
|
||||||
|
ImageClassifierOptions.builder()
|
||||||
|
.setBaseOptions(
|
||||||
|
BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
|
.setMaxResults(-1)
|
||||||
|
.build());
|
||||||
|
assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void options_failsWithBothAllowlistAndDenylist() throws Exception {
|
||||||
|
IllegalArgumentException exception =
|
||||||
|
assertThrows(
|
||||||
|
IllegalArgumentException.class,
|
||||||
|
() ->
|
||||||
|
ImageClassifierOptions.builder()
|
||||||
|
.setBaseOptions(
|
||||||
|
BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
|
.setCategoryAllowlist(Arrays.asList("foo"))
|
||||||
|
.setCategoryDenylist(Arrays.asList("bar"))
|
||||||
|
.build());
|
||||||
|
assertThat(exception)
|
||||||
|
.hasMessageThat()
|
||||||
|
.contains("Category allowlist and denylist are mutually exclusive");
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void create_failsWithMissingModel() throws Exception {
|
public void create_failsWithMissingModel() throws Exception {
|
||||||
String nonExistentFile = "/path/to/non/existent/file";
|
String nonExistentFile = "/path/to/non/existent/file";
|
||||||
|
@ -105,7 +135,7 @@ public class ImageClassifierTest {
|
||||||
ImageClassifierOptions options =
|
ImageClassifierOptions options =
|
||||||
ImageClassifierOptions.builder()
|
ImageClassifierOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build())
|
.setMaxResults(3)
|
||||||
.build();
|
.build();
|
||||||
ImageClassifier imageClassifier =
|
ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
@ -125,7 +155,7 @@ public class ImageClassifierTest {
|
||||||
ImageClassifierOptions options =
|
ImageClassifierOptions options =
|
||||||
ImageClassifierOptions.builder()
|
ImageClassifierOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(QUANTIZED_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(QUANTIZED_MODEL_FILE).build())
|
||||||
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
|
.setMaxResults(1)
|
||||||
.build();
|
.build();
|
||||||
ImageClassifier imageClassifier =
|
ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
@ -141,7 +171,7 @@ public class ImageClassifierTest {
|
||||||
ImageClassifierOptions options =
|
ImageClassifierOptions options =
|
||||||
ImageClassifierOptions.builder()
|
ImageClassifierOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
.setClassifierOptions(ClassifierOptions.builder().setScoreThreshold(0.02f).build())
|
.setScoreThreshold(0.02f)
|
||||||
.build();
|
.build();
|
||||||
ImageClassifier imageClassifier =
|
ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
@ -160,10 +190,7 @@ public class ImageClassifierTest {
|
||||||
ImageClassifierOptions options =
|
ImageClassifierOptions options =
|
||||||
ImageClassifierOptions.builder()
|
ImageClassifierOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
.setClassifierOptions(
|
.setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf"))
|
||||||
ClassifierOptions.builder()
|
|
||||||
.setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf"))
|
|
||||||
.build())
|
|
||||||
.build();
|
.build();
|
||||||
ImageClassifier imageClassifier =
|
ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
@ -183,11 +210,8 @@ public class ImageClassifierTest {
|
||||||
ImageClassifierOptions options =
|
ImageClassifierOptions options =
|
||||||
ImageClassifierOptions.builder()
|
ImageClassifierOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
.setClassifierOptions(
|
.setMaxResults(3)
|
||||||
ClassifierOptions.builder()
|
.setCategoryDenylist(Arrays.asList("bagel"))
|
||||||
.setMaxResults(3)
|
|
||||||
.setCategoryDenylist(Arrays.asList("bagel"))
|
|
||||||
.build())
|
|
||||||
.build();
|
.build();
|
||||||
ImageClassifier imageClassifier =
|
ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
@ -207,7 +231,7 @@ public class ImageClassifierTest {
|
||||||
ImageClassifierOptions options =
|
ImageClassifierOptions options =
|
||||||
ImageClassifierOptions.builder()
|
ImageClassifierOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
|
.setMaxResults(1)
|
||||||
.build();
|
.build();
|
||||||
ImageClassifier imageClassifier =
|
ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
@ -228,7 +252,7 @@ public class ImageClassifierTest {
|
||||||
ImageClassifierOptions options =
|
ImageClassifierOptions options =
|
||||||
ImageClassifierOptions.builder()
|
ImageClassifierOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build())
|
.setMaxResults(3)
|
||||||
.build();
|
.build();
|
||||||
ImageClassifier imageClassifier =
|
ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
@ -251,7 +275,7 @@ public class ImageClassifierTest {
|
||||||
ImageClassifierOptions options =
|
ImageClassifierOptions options =
|
||||||
ImageClassifierOptions.builder()
|
ImageClassifierOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
|
.setMaxResults(1)
|
||||||
.build();
|
.build();
|
||||||
ImageClassifier imageClassifier =
|
ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
@ -322,14 +346,14 @@ public class ImageClassifierTest {
|
||||||
MediaPipeException.class,
|
MediaPipeException.class,
|
||||||
() ->
|
() ->
|
||||||
imageClassifier.classifyForVideo(
|
imageClassifier.classifyForVideo(
|
||||||
getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
|
getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0));
|
||||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||||
exception =
|
exception =
|
||||||
assertThrows(
|
assertThrows(
|
||||||
MediaPipeException.class,
|
MediaPipeException.class,
|
||||||
() ->
|
() ->
|
||||||
imageClassifier.classifyAsync(
|
imageClassifier.classifyAsync(
|
||||||
getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
|
getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0));
|
||||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -353,7 +377,7 @@ public class ImageClassifierTest {
|
||||||
MediaPipeException.class,
|
MediaPipeException.class,
|
||||||
() ->
|
() ->
|
||||||
imageClassifier.classifyAsync(
|
imageClassifier.classifyAsync(
|
||||||
getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
|
getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0));
|
||||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -379,7 +403,7 @@ public class ImageClassifierTest {
|
||||||
MediaPipeException.class,
|
MediaPipeException.class,
|
||||||
() ->
|
() ->
|
||||||
imageClassifier.classifyForVideo(
|
imageClassifier.classifyForVideo(
|
||||||
getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
|
getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0));
|
||||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -388,7 +412,7 @@ public class ImageClassifierTest {
|
||||||
ImageClassifierOptions options =
|
ImageClassifierOptions options =
|
||||||
ImageClassifierOptions.builder()
|
ImageClassifierOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
|
.setMaxResults(1)
|
||||||
.build();
|
.build();
|
||||||
ImageClassifier imageClassifier =
|
ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
@ -405,13 +429,14 @@ public class ImageClassifierTest {
|
||||||
ImageClassifierOptions options =
|
ImageClassifierOptions options =
|
||||||
ImageClassifierOptions.builder()
|
ImageClassifierOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
|
.setMaxResults(1)
|
||||||
.setRunningMode(RunningMode.VIDEO)
|
.setRunningMode(RunningMode.VIDEO)
|
||||||
.build();
|
.build();
|
||||||
ImageClassifier imageClassifier =
|
ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
for (int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
ImageClassifierResult results = imageClassifier.classifyForVideo(image, /*timestampMs=*/ i);
|
ImageClassifierResult results =
|
||||||
|
imageClassifier.classifyForVideo(image, /* timestampMs= */ i);
|
||||||
assertHasOneHead(results);
|
assertHasOneHead(results);
|
||||||
assertCategoriesAre(
|
assertCategoriesAre(
|
||||||
results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
|
results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
|
||||||
|
@ -424,7 +449,7 @@ public class ImageClassifierTest {
|
||||||
ImageClassifierOptions options =
|
ImageClassifierOptions options =
|
||||||
ImageClassifierOptions.builder()
|
ImageClassifierOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
|
.setMaxResults(1)
|
||||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||||
.setResultListener(
|
.setResultListener(
|
||||||
(imageClassificationResult, inputImage) -> {
|
(imageClassificationResult, inputImage) -> {
|
||||||
|
@ -436,11 +461,11 @@ public class ImageClassifierTest {
|
||||||
.build();
|
.build();
|
||||||
try (ImageClassifier imageClassifier =
|
try (ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||||
imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1);
|
imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 1);
|
||||||
MediaPipeException exception =
|
MediaPipeException exception =
|
||||||
assertThrows(
|
assertThrows(
|
||||||
MediaPipeException.class,
|
MediaPipeException.class,
|
||||||
() -> imageClassifier.classifyAsync(image, /*timestampMs=*/ 0));
|
() -> imageClassifier.classifyAsync(image, /* timestampMs= */ 0));
|
||||||
assertThat(exception)
|
assertThat(exception)
|
||||||
.hasMessageThat()
|
.hasMessageThat()
|
||||||
.contains("having a smaller timestamp than the processed timestamp");
|
.contains("having a smaller timestamp than the processed timestamp");
|
||||||
|
@ -453,7 +478,7 @@ public class ImageClassifierTest {
|
||||||
ImageClassifierOptions options =
|
ImageClassifierOptions options =
|
||||||
ImageClassifierOptions.builder()
|
ImageClassifierOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
|
.setMaxResults(1)
|
||||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||||
.setResultListener(
|
.setResultListener(
|
||||||
(imageClassificationResult, inputImage) -> {
|
(imageClassificationResult, inputImage) -> {
|
||||||
|
@ -466,7 +491,7 @@ public class ImageClassifierTest {
|
||||||
try (ImageClassifier imageClassifier =
|
try (ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||||
for (int i = 0; i < 3; ++i) {
|
for (int i = 0; i < 3; ++i) {
|
||||||
imageClassifier.classifyAsync(image, /*timestampMs=*/ i);
|
imageClassifier.classifyAsync(image, /* timestampMs= */ i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user