Inline formerly nested 'ClassifierOptions' in Java classifier APIs.
PiperOrigin-RevId: 492173060
This commit is contained in:
		
							parent
							
								
									460aee7933
								
							
						
					
					
						commit
						29c7702984
					
				| 
						 | 
					@ -66,10 +66,10 @@ android_library(
 | 
				
			||||||
        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
 | 
					        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
 | 
				
			||||||
        "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite",
 | 
					        "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite",
 | 
				
			||||||
        "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
 | 
					        "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
 | 
				
			||||||
        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
 | 
					        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
 | 
				
			||||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata",
 | 
					        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata",
 | 
				
			||||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
 | 
					        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
 | 
				
			||||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
 | 
					 | 
				
			||||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
 | 
					        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
 | 
				
			||||||
        "//third_party:autovalue",
 | 
					        "//third_party:autovalue",
 | 
				
			||||||
        "@maven//:com_google_guava_guava",
 | 
					        "@maven//:com_google_guava_guava",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -27,7 +27,7 @@ import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi;
 | 
				
			||||||
import com.google.mediapipe.tasks.audio.core.RunningMode;
 | 
					import com.google.mediapipe.tasks.audio.core.RunningMode;
 | 
				
			||||||
import com.google.mediapipe.tasks.components.containers.AudioData;
 | 
					import com.google.mediapipe.tasks.components.containers.AudioData;
 | 
				
			||||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
 | 
					import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
 | 
				
			||||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
 | 
					import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
 | 
				
			||||||
import com.google.mediapipe.tasks.core.BaseOptions;
 | 
					import com.google.mediapipe.tasks.core.BaseOptions;
 | 
				
			||||||
import com.google.mediapipe.tasks.core.ErrorListener;
 | 
					import com.google.mediapipe.tasks.core.ErrorListener;
 | 
				
			||||||
import com.google.mediapipe.tasks.core.OutputHandler;
 | 
					import com.google.mediapipe.tasks.core.OutputHandler;
 | 
				
			||||||
| 
						 | 
					@ -320,10 +320,42 @@ public final class AudioClassifier extends BaseAudioTaskApi {
 | 
				
			||||||
      public abstract Builder setRunningMode(RunningMode runningMode);
 | 
					      public abstract Builder setRunningMode(RunningMode runningMode);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      /**
 | 
					      /**
 | 
				
			||||||
       * Sets the optional {@link ClassifierOptions} controling classification behavior, such as
 | 
					       * Sets the optional locale to use for display names specified through the TFLite Model
 | 
				
			||||||
       * score threshold, number of results, etc.
 | 
					       * Metadata, if any.
 | 
				
			||||||
       */
 | 
					       */
 | 
				
			||||||
      public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
 | 
					      public abstract Builder setDisplayNamesLocale(String locale);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      /**
 | 
				
			||||||
 | 
					       * Sets the optional maximum number of top-scored classification results to return.
 | 
				
			||||||
 | 
					       *
 | 
				
			||||||
 | 
					       * <p>If not set, all available results are returned. If set, must be > 0.
 | 
				
			||||||
 | 
					       */
 | 
				
			||||||
 | 
					      public abstract Builder setMaxResults(Integer maxResults);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      /**
 | 
				
			||||||
 | 
					       * Sets the optional score threshold. Results with score below this value are rejected.
 | 
				
			||||||
 | 
					       *
 | 
				
			||||||
 | 
					       * <p>Overrides the score threshold specified in the TFLite Model Metadata, if any.
 | 
				
			||||||
 | 
					       */
 | 
				
			||||||
 | 
					      public abstract Builder setScoreThreshold(Float scoreThreshold);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      /**
 | 
				
			||||||
 | 
					       * Sets the optional allowlist of category names.
 | 
				
			||||||
 | 
					       *
 | 
				
			||||||
 | 
					       * <p>If non-empty, detection results whose category name is not in this set will be filtered
 | 
				
			||||||
 | 
					       * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
 | 
				
			||||||
 | 
					       * categoryDenylist}.
 | 
				
			||||||
 | 
					       */
 | 
				
			||||||
 | 
					      public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      /**
 | 
				
			||||||
 | 
					       * Sets the optional denylist of category names.
 | 
				
			||||||
 | 
					       *
 | 
				
			||||||
 | 
					       * <p>If non-empty, detection results whose category name is in this set will be filtered out.
 | 
				
			||||||
 | 
					       * Duplicate or unknown category names are ignored. Mutually exclusive with {@code
 | 
				
			||||||
 | 
					       * categoryAllowlist}.
 | 
				
			||||||
 | 
					       */
 | 
				
			||||||
 | 
					      public abstract Builder setCategoryDenylist(List<String> categoryDenylist);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      /**
 | 
					      /**
 | 
				
			||||||
       * Sets the {@link ResultListener} to receive the classification results asynchronously when
 | 
					       * Sets the {@link ResultListener} to receive the classification results asynchronously when
 | 
				
			||||||
| 
						 | 
					@ -340,9 +372,7 @@ public final class AudioClassifier extends BaseAudioTaskApi {
 | 
				
			||||||
      /**
 | 
					      /**
 | 
				
			||||||
       * Validates and builds the {@link AudioClassifierOptions} instance.
 | 
					       * Validates and builds the {@link AudioClassifierOptions} instance.
 | 
				
			||||||
       *
 | 
					       *
 | 
				
			||||||
       * @throws IllegalArgumentException if the result listener and the running mode are not
 | 
					       * @throws IllegalArgumentException if any of the set options are invalid.
 | 
				
			||||||
       *     properly configured. The result listener should only be set when the audio classifier
 | 
					 | 
				
			||||||
       *     is in the audio stream mode.
 | 
					 | 
				
			||||||
       */
 | 
					       */
 | 
				
			||||||
      public final AudioClassifierOptions build() {
 | 
					      public final AudioClassifierOptions build() {
 | 
				
			||||||
        AudioClassifierOptions options = autoBuild();
 | 
					        AudioClassifierOptions options = autoBuild();
 | 
				
			||||||
| 
						 | 
					@ -357,6 +387,13 @@ public final class AudioClassifier extends BaseAudioTaskApi {
 | 
				
			||||||
              "The audio classifier is in the audio clips mode, a user-defined result listener"
 | 
					              "The audio classifier is in the audio clips mode, a user-defined result listener"
 | 
				
			||||||
                  + " shouldn't be provided in AudioClassifierOptions.");
 | 
					                  + " shouldn't be provided in AudioClassifierOptions.");
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					        if (options.maxResults().isPresent() && options.maxResults().get() <= 0) {
 | 
				
			||||||
 | 
					          throw new IllegalArgumentException("If specified, maxResults must be > 0.");
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) {
 | 
				
			||||||
 | 
					          throw new IllegalArgumentException(
 | 
				
			||||||
 | 
					              "Category allowlist and denylist are mutually exclusive.");
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
        return options;
 | 
					        return options;
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
| 
						 | 
					@ -365,7 +402,15 @@ public final class AudioClassifier extends BaseAudioTaskApi {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    abstract RunningMode runningMode();
 | 
					    abstract RunningMode runningMode();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    abstract Optional<ClassifierOptions> classifierOptions();
 | 
					    abstract Optional<String> displayNamesLocale();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    abstract Optional<Integer> maxResults();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    abstract Optional<Float> scoreThreshold();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    abstract List<String> categoryAllowlist();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    abstract List<String> categoryDenylist();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    abstract Optional<PureResultListener<AudioClassifierResult>> resultListener();
 | 
					    abstract Optional<PureResultListener<AudioClassifierResult>> resultListener();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -373,7 +418,9 @@ public final class AudioClassifier extends BaseAudioTaskApi {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public static Builder builder() {
 | 
					    public static Builder builder() {
 | 
				
			||||||
      return new AutoValue_AudioClassifier_AudioClassifierOptions.Builder()
 | 
					      return new AutoValue_AudioClassifier_AudioClassifierOptions.Builder()
 | 
				
			||||||
          .setRunningMode(RunningMode.AUDIO_CLIPS);
 | 
					          .setRunningMode(RunningMode.AUDIO_CLIPS)
 | 
				
			||||||
 | 
					          .setCategoryAllowlist(Collections.emptyList())
 | 
				
			||||||
 | 
					          .setCategoryDenylist(Collections.emptyList());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
| 
						 | 
					@ -385,12 +432,21 @@ public final class AudioClassifier extends BaseAudioTaskApi {
 | 
				
			||||||
          BaseOptionsProto.BaseOptions.newBuilder();
 | 
					          BaseOptionsProto.BaseOptions.newBuilder();
 | 
				
			||||||
      baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM);
 | 
					      baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM);
 | 
				
			||||||
      baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
 | 
					      baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
 | 
				
			||||||
 | 
					      ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder =
 | 
				
			||||||
 | 
					          ClassifierOptionsProto.ClassifierOptions.newBuilder();
 | 
				
			||||||
 | 
					      displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale);
 | 
				
			||||||
 | 
					      maxResults().ifPresent(classifierOptionsBuilder::setMaxResults);
 | 
				
			||||||
 | 
					      scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold);
 | 
				
			||||||
 | 
					      if (!categoryAllowlist().isEmpty()) {
 | 
				
			||||||
 | 
					        classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist());
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      if (!categoryDenylist().isEmpty()) {
 | 
				
			||||||
 | 
					        classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist());
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
      AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.Builder taskOptionsBuilder =
 | 
					      AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.Builder taskOptionsBuilder =
 | 
				
			||||||
          AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.newBuilder()
 | 
					          AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.newBuilder()
 | 
				
			||||||
              .setBaseOptions(baseOptionsBuilder);
 | 
					              .setBaseOptions(baseOptionsBuilder)
 | 
				
			||||||
      if (classifierOptions().isPresent()) {
 | 
					              .setClassifierOptions(classifierOptionsBuilder);
 | 
				
			||||||
        taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
      return CalculatorOptions.newBuilder()
 | 
					      return CalculatorOptions.newBuilder()
 | 
				
			||||||
          .setExtension(
 | 
					          .setExtension(
 | 
				
			||||||
              AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.ext,
 | 
					              AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.ext,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -49,10 +49,10 @@ android_library(
 | 
				
			||||||
        "//mediapipe/framework:calculator_options_java_proto_lite",
 | 
					        "//mediapipe/framework:calculator_options_java_proto_lite",
 | 
				
			||||||
        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
 | 
					        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
 | 
				
			||||||
        "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
 | 
					        "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
 | 
				
			||||||
        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
 | 
					        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
 | 
				
			||||||
        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
 | 
					        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
 | 
				
			||||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
 | 
					        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
 | 
				
			||||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
 | 
					 | 
				
			||||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
 | 
					        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
 | 
				
			||||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",
 | 
					        "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",
 | 
				
			||||||
        "//third_party:autovalue",
 | 
					        "//third_party:autovalue",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -24,7 +24,7 @@ import com.google.mediapipe.framework.PacketGetter;
 | 
				
			||||||
import com.google.mediapipe.framework.ProtoUtil;
 | 
					import com.google.mediapipe.framework.ProtoUtil;
 | 
				
			||||||
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
 | 
					import com.google.mediapipe.tasks.components.containers.ClassificationResult;
 | 
				
			||||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
 | 
					import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
 | 
				
			||||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
 | 
					import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
 | 
				
			||||||
import com.google.mediapipe.tasks.core.BaseOptions;
 | 
					import com.google.mediapipe.tasks.core.BaseOptions;
 | 
				
			||||||
import com.google.mediapipe.tasks.core.OutputHandler;
 | 
					import com.google.mediapipe.tasks.core.OutputHandler;
 | 
				
			||||||
import com.google.mediapipe.tasks.core.TaskInfo;
 | 
					import com.google.mediapipe.tasks.core.TaskInfo;
 | 
				
			||||||
| 
						 | 
					@ -216,20 +216,79 @@ public final class TextClassifier implements AutoCloseable {
 | 
				
			||||||
      public abstract Builder setBaseOptions(BaseOptions value);
 | 
					      public abstract Builder setBaseOptions(BaseOptions value);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      /**
 | 
					      /**
 | 
				
			||||||
       * Sets the optional {@link ClassifierOptions} controling classification behavior, such as
 | 
					       * Sets the optional locale to use for display names specified through the TFLite Model
 | 
				
			||||||
       * score threshold, number of results, etc.
 | 
					       * Metadata, if any.
 | 
				
			||||||
       */
 | 
					       */
 | 
				
			||||||
      public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
 | 
					      public abstract Builder setDisplayNamesLocale(String locale);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      public abstract TextClassifierOptions build();
 | 
					      /**
 | 
				
			||||||
 | 
					       * Sets the optional maximum number of top-scored classification results to return.
 | 
				
			||||||
 | 
					       *
 | 
				
			||||||
 | 
					       * <p>If not set, all available results are returned. If set, must be > 0.
 | 
				
			||||||
 | 
					       */
 | 
				
			||||||
 | 
					      public abstract Builder setMaxResults(Integer maxResults);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      /**
 | 
				
			||||||
 | 
					       * Sets the optional score threshold. Results with score below this value are rejected.
 | 
				
			||||||
 | 
					       *
 | 
				
			||||||
 | 
					       * <p>Overrides the score threshold specified in the TFLite Model Metadata, if any.
 | 
				
			||||||
 | 
					       */
 | 
				
			||||||
 | 
					      public abstract Builder setScoreThreshold(Float scoreThreshold);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      /**
 | 
				
			||||||
 | 
					       * Sets the optional allowlist of category names.
 | 
				
			||||||
 | 
					       *
 | 
				
			||||||
 | 
					       * <p>If non-empty, detection results whose category name is not in this set will be filtered
 | 
				
			||||||
 | 
					       * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
 | 
				
			||||||
 | 
					       * categoryDenylist}.
 | 
				
			||||||
 | 
					       */
 | 
				
			||||||
 | 
					      public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      /**
 | 
				
			||||||
 | 
					       * Sets the optional denylist of category names.
 | 
				
			||||||
 | 
					       *
 | 
				
			||||||
 | 
					       * <p>If non-empty, detection results whose category name is in this set will be filtered out.
 | 
				
			||||||
 | 
					       * Duplicate or unknown category names are ignored. Mutually exclusive with {@code
 | 
				
			||||||
 | 
					       * categoryAllowlist}.
 | 
				
			||||||
 | 
					       */
 | 
				
			||||||
 | 
					      public abstract Builder setCategoryDenylist(List<String> categoryDenylist);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      abstract TextClassifierOptions autoBuild();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      /**
 | 
				
			||||||
 | 
					       * Validates and builds the {@link TextClassifierOptions} instance.
 | 
				
			||||||
 | 
					       *
 | 
				
			||||||
 | 
					       * @throws IllegalArgumentException if any of the set options are invalid.
 | 
				
			||||||
 | 
					       */
 | 
				
			||||||
 | 
					      public final TextClassifierOptions build() {
 | 
				
			||||||
 | 
					        TextClassifierOptions options = autoBuild();
 | 
				
			||||||
 | 
					        if (options.maxResults().isPresent() && options.maxResults().get() <= 0) {
 | 
				
			||||||
 | 
					          throw new IllegalArgumentException("If specified, maxResults must be > 0.");
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) {
 | 
				
			||||||
 | 
					          throw new IllegalArgumentException(
 | 
				
			||||||
 | 
					              "Category allowlist and denylist are mutually exclusive.");
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        return options;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    abstract BaseOptions baseOptions();
 | 
					    abstract BaseOptions baseOptions();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    abstract Optional<ClassifierOptions> classifierOptions();
 | 
					    abstract Optional<String> displayNamesLocale();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    abstract Optional<Integer> maxResults();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    abstract Optional<Float> scoreThreshold();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    abstract List<String> categoryAllowlist();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    abstract List<String> categoryDenylist();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public static Builder builder() {
 | 
					    public static Builder builder() {
 | 
				
			||||||
      return new AutoValue_TextClassifier_TextClassifierOptions.Builder();
 | 
					      return new AutoValue_TextClassifier_TextClassifierOptions.Builder()
 | 
				
			||||||
 | 
					          .setCategoryAllowlist(Collections.emptyList())
 | 
				
			||||||
 | 
					          .setCategoryDenylist(Collections.emptyList());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /** Converts a {@link TextClassifierOptions} to a {@link CalculatorOptions} protobuf message. */
 | 
					    /** Converts a {@link TextClassifierOptions} to a {@link CalculatorOptions} protobuf message. */
 | 
				
			||||||
| 
						 | 
					@ -238,12 +297,21 @@ public final class TextClassifier implements AutoCloseable {
 | 
				
			||||||
      BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
 | 
					      BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
 | 
				
			||||||
          BaseOptionsProto.BaseOptions.newBuilder();
 | 
					          BaseOptionsProto.BaseOptions.newBuilder();
 | 
				
			||||||
      baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
 | 
					      baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
 | 
				
			||||||
 | 
					      ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder =
 | 
				
			||||||
 | 
					          ClassifierOptionsProto.ClassifierOptions.newBuilder();
 | 
				
			||||||
 | 
					      displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale);
 | 
				
			||||||
 | 
					      maxResults().ifPresent(classifierOptionsBuilder::setMaxResults);
 | 
				
			||||||
 | 
					      scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold);
 | 
				
			||||||
 | 
					      if (!categoryAllowlist().isEmpty()) {
 | 
				
			||||||
 | 
					        classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist());
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      if (!categoryDenylist().isEmpty()) {
 | 
				
			||||||
 | 
					        classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist());
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
      TextClassifierGraphOptionsProto.TextClassifierGraphOptions.Builder taskOptionsBuilder =
 | 
					      TextClassifierGraphOptionsProto.TextClassifierGraphOptions.Builder taskOptionsBuilder =
 | 
				
			||||||
          TextClassifierGraphOptionsProto.TextClassifierGraphOptions.newBuilder()
 | 
					          TextClassifierGraphOptionsProto.TextClassifierGraphOptions.newBuilder()
 | 
				
			||||||
              .setBaseOptions(baseOptionsBuilder);
 | 
					              .setBaseOptions(baseOptionsBuilder)
 | 
				
			||||||
      if (classifierOptions().isPresent()) {
 | 
					              .setClassifierOptions(classifierOptionsBuilder);
 | 
				
			||||||
        taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
      return CalculatorOptions.newBuilder()
 | 
					      return CalculatorOptions.newBuilder()
 | 
				
			||||||
          .setExtension(
 | 
					          .setExtension(
 | 
				
			||||||
              TextClassifierGraphOptionsProto.TextClassifierGraphOptions.ext,
 | 
					              TextClassifierGraphOptionsProto.TextClassifierGraphOptions.ext,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -98,10 +98,10 @@ android_library(
 | 
				
			||||||
        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
 | 
					        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
 | 
				
			||||||
        "//mediapipe/java/com/google/mediapipe/framework/image",
 | 
					        "//mediapipe/java/com/google/mediapipe/framework/image",
 | 
				
			||||||
        "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
 | 
					        "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
 | 
				
			||||||
        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
 | 
					        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
 | 
				
			||||||
        "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
 | 
					        "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
 | 
				
			||||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
 | 
					        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
 | 
				
			||||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
 | 
					 | 
				
			||||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
 | 
					        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
 | 
				
			||||||
        "//third_party:autovalue",
 | 
					        "//third_party:autovalue",
 | 
				
			||||||
        "@maven//:com_google_guava_guava",
 | 
					        "@maven//:com_google_guava_guava",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -27,7 +27,7 @@ import com.google.mediapipe.framework.image.BitmapImageBuilder;
 | 
				
			||||||
import com.google.mediapipe.framework.image.MPImage;
 | 
					import com.google.mediapipe.framework.image.MPImage;
 | 
				
			||||||
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
 | 
					import com.google.mediapipe.tasks.components.containers.ClassificationResult;
 | 
				
			||||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
 | 
					import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
 | 
				
			||||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
 | 
					import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
 | 
				
			||||||
import com.google.mediapipe.tasks.core.BaseOptions;
 | 
					import com.google.mediapipe.tasks.core.BaseOptions;
 | 
				
			||||||
import com.google.mediapipe.tasks.core.ErrorListener;
 | 
					import com.google.mediapipe.tasks.core.ErrorListener;
 | 
				
			||||||
import com.google.mediapipe.tasks.core.OutputHandler;
 | 
					import com.google.mediapipe.tasks.core.OutputHandler;
 | 
				
			||||||
| 
						 | 
					@ -376,10 +376,42 @@ public final class ImageClassifier extends BaseVisionTaskApi {
 | 
				
			||||||
      public abstract Builder setRunningMode(RunningMode runningMode);
 | 
					      public abstract Builder setRunningMode(RunningMode runningMode);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      /**
 | 
					      /**
 | 
				
			||||||
       * Sets the optional {@link ClassifierOptions} controling classification behavior, such as
 | 
					       * Sets the optional locale to use for display names specified through the TFLite Model
 | 
				
			||||||
       * score threshold, number of results, etc.
 | 
					       * Metadata, if any.
 | 
				
			||||||
       */
 | 
					       */
 | 
				
			||||||
      public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
 | 
					      public abstract Builder setDisplayNamesLocale(String locale);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      /**
 | 
				
			||||||
 | 
					       * Sets the optional maximum number of top-scored classification results to return.
 | 
				
			||||||
 | 
					       *
 | 
				
			||||||
 | 
					       * <p>If not set, all available results are returned. If set, must be > 0.
 | 
				
			||||||
 | 
					       */
 | 
				
			||||||
 | 
					      public abstract Builder setMaxResults(Integer maxResults);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      /**
 | 
				
			||||||
 | 
					       * Sets the optional score threshold. Results with score below this value are rejected.
 | 
				
			||||||
 | 
					       *
 | 
				
			||||||
 | 
					       * <p>Overrides the score threshold specified in the TFLite Model Metadata, if any.
 | 
				
			||||||
 | 
					       */
 | 
				
			||||||
 | 
					      public abstract Builder setScoreThreshold(Float scoreThreshold);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      /**
 | 
				
			||||||
 | 
					       * Sets the optional allowlist of category names.
 | 
				
			||||||
 | 
					       *
 | 
				
			||||||
 | 
					       * <p>If non-empty, detection results whose category name is not in this set will be filtered
 | 
				
			||||||
 | 
					       * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
 | 
				
			||||||
 | 
					       * categoryDenylist}.
 | 
				
			||||||
 | 
					       */
 | 
				
			||||||
 | 
					      public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      /**
 | 
				
			||||||
 | 
					       * Sets the optional denylist of category names.
 | 
				
			||||||
 | 
					       *
 | 
				
			||||||
 | 
					       * <p>If non-empty, detection results whose category name is in this set will be filtered out.
 | 
				
			||||||
 | 
					       * Duplicate or unknown category names are ignored. Mutually exclusive with {@code
 | 
				
			||||||
 | 
					       * categoryAllowlist}.
 | 
				
			||||||
 | 
					       */
 | 
				
			||||||
 | 
					      public abstract Builder setCategoryDenylist(List<String> categoryDenylist);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      /**
 | 
					      /**
 | 
				
			||||||
       * Sets the {@link ResultListener} to receive the classification results asynchronously when
 | 
					       * Sets the {@link ResultListener} to receive the classification results asynchronously when
 | 
				
			||||||
| 
						 | 
					@ -396,9 +428,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
 | 
				
			||||||
      /**
 | 
					      /**
 | 
				
			||||||
       * Validates and builds the {@link ImageClassifierOptions} instance. *
 | 
					       * Validates and builds the {@link ImageClassifierOptions} instance. *
 | 
				
			||||||
       *
 | 
					       *
 | 
				
			||||||
       * @throws IllegalArgumentException if the result listener and the running mode are not
 | 
					       * @throws IllegalArgumentException if any of the set options are invalid.
 | 
				
			||||||
       *     properly configured. The result listener should only be set when the image classifier
 | 
					 | 
				
			||||||
       *     is in the live stream mode.
 | 
					 | 
				
			||||||
       */
 | 
					       */
 | 
				
			||||||
      public final ImageClassifierOptions build() {
 | 
					      public final ImageClassifierOptions build() {
 | 
				
			||||||
        ImageClassifierOptions options = autoBuild();
 | 
					        ImageClassifierOptions options = autoBuild();
 | 
				
			||||||
| 
						 | 
					@ -413,6 +443,13 @@ public final class ImageClassifier extends BaseVisionTaskApi {
 | 
				
			||||||
              "The image classifier is in the image or video mode, a user-defined result listener"
 | 
					              "The image classifier is in the image or video mode, a user-defined result listener"
 | 
				
			||||||
                  + " shouldn't be provided in ImageClassifierOptions.");
 | 
					                  + " shouldn't be provided in ImageClassifierOptions.");
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					        if (options.maxResults().isPresent() && options.maxResults().get() <= 0) {
 | 
				
			||||||
 | 
					          throw new IllegalArgumentException("If specified, maxResults must be > 0.");
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) {
 | 
				
			||||||
 | 
					          throw new IllegalArgumentException(
 | 
				
			||||||
 | 
					              "Category allowlist and denylist are mutually exclusive.");
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
        return options;
 | 
					        return options;
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
| 
						 | 
					@ -421,7 +458,15 @@ public final class ImageClassifier extends BaseVisionTaskApi {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    abstract RunningMode runningMode();
 | 
					    abstract RunningMode runningMode();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    abstract Optional<ClassifierOptions> classifierOptions();
 | 
					    abstract Optional<String> displayNamesLocale();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    abstract Optional<Integer> maxResults();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    abstract Optional<Float> scoreThreshold();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    abstract List<String> categoryAllowlist();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    abstract List<String> categoryDenylist();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    abstract Optional<ResultListener<ImageClassifierResult, MPImage>> resultListener();
 | 
					    abstract Optional<ResultListener<ImageClassifierResult, MPImage>> resultListener();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -429,7 +474,9 @@ public final class ImageClassifier extends BaseVisionTaskApi {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public static Builder builder() {
 | 
					    public static Builder builder() {
 | 
				
			||||||
      return new AutoValue_ImageClassifier_ImageClassifierOptions.Builder()
 | 
					      return new AutoValue_ImageClassifier_ImageClassifierOptions.Builder()
 | 
				
			||||||
          .setRunningMode(RunningMode.IMAGE);
 | 
					          .setRunningMode(RunningMode.IMAGE)
 | 
				
			||||||
 | 
					          .setCategoryAllowlist(Collections.emptyList())
 | 
				
			||||||
 | 
					          .setCategoryDenylist(Collections.emptyList());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
| 
						 | 
					@ -441,12 +488,21 @@ public final class ImageClassifier extends BaseVisionTaskApi {
 | 
				
			||||||
          BaseOptionsProto.BaseOptions.newBuilder();
 | 
					          BaseOptionsProto.BaseOptions.newBuilder();
 | 
				
			||||||
      baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE);
 | 
					      baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE);
 | 
				
			||||||
      baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
 | 
					      baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
 | 
				
			||||||
 | 
					      ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder =
 | 
				
			||||||
 | 
					          ClassifierOptionsProto.ClassifierOptions.newBuilder();
 | 
				
			||||||
 | 
					      displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale);
 | 
				
			||||||
 | 
					      maxResults().ifPresent(classifierOptionsBuilder::setMaxResults);
 | 
				
			||||||
 | 
					      scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold);
 | 
				
			||||||
 | 
					      if (!categoryAllowlist().isEmpty()) {
 | 
				
			||||||
 | 
					        classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist());
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      if (!categoryDenylist().isEmpty()) {
 | 
				
			||||||
 | 
					        classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist());
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
      ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.Builder taskOptionsBuilder =
 | 
					      ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.Builder taskOptionsBuilder =
 | 
				
			||||||
          ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.newBuilder()
 | 
					          ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.newBuilder()
 | 
				
			||||||
              .setBaseOptions(baseOptionsBuilder);
 | 
					              .setBaseOptions(baseOptionsBuilder)
 | 
				
			||||||
      if (classifierOptions().isPresent()) {
 | 
					              .setClassifierOptions(classifierOptionsBuilder);
 | 
				
			||||||
        taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
      return CalculatorOptions.newBuilder()
 | 
					      return CalculatorOptions.newBuilder()
 | 
				
			||||||
          .setExtension(
 | 
					          .setExtension(
 | 
				
			||||||
              ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.ext,
 | 
					              ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.ext,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -40,6 +40,37 @@ public class TextClassifierTest {
 | 
				
			||||||
  private static final String NEGATIVE_TEXT = "unflinchingly bleak and desperate";
 | 
					  private static final String NEGATIVE_TEXT = "unflinchingly bleak and desperate";
 | 
				
			||||||
  private static final String POSITIVE_TEXT = "it's a charming and often affecting journey";
 | 
					  private static final String POSITIVE_TEXT = "it's a charming and often affecting journey";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @Test
 | 
				
			||||||
 | 
					  public void options_failsWithNegativeMaxResults() throws Exception {
 | 
				
			||||||
 | 
					    IllegalArgumentException exception =
 | 
				
			||||||
 | 
					        assertThrows(
 | 
				
			||||||
 | 
					            IllegalArgumentException.class,
 | 
				
			||||||
 | 
					            () ->
 | 
				
			||||||
 | 
					                TextClassifierOptions.builder()
 | 
				
			||||||
 | 
					                    .setBaseOptions(
 | 
				
			||||||
 | 
					                        BaseOptions.builder().setModelAssetPath(BERT_MODEL_FILE).build())
 | 
				
			||||||
 | 
					                    .setMaxResults(-1)
 | 
				
			||||||
 | 
					                    .build());
 | 
				
			||||||
 | 
					    assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0");
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @Test
 | 
				
			||||||
 | 
					  public void options_failsWithBothAllowlistAndDenylist() throws Exception {
 | 
				
			||||||
 | 
					    IllegalArgumentException exception =
 | 
				
			||||||
 | 
					        assertThrows(
 | 
				
			||||||
 | 
					            IllegalArgumentException.class,
 | 
				
			||||||
 | 
					            () ->
 | 
				
			||||||
 | 
					                TextClassifierOptions.builder()
 | 
				
			||||||
 | 
					                    .setBaseOptions(
 | 
				
			||||||
 | 
					                        BaseOptions.builder().setModelAssetPath(BERT_MODEL_FILE).build())
 | 
				
			||||||
 | 
					                    .setCategoryAllowlist(Arrays.asList("foo"))
 | 
				
			||||||
 | 
					                    .setCategoryDenylist(Arrays.asList("bar"))
 | 
				
			||||||
 | 
					                    .build());
 | 
				
			||||||
 | 
					    assertThat(exception)
 | 
				
			||||||
 | 
					        .hasMessageThat()
 | 
				
			||||||
 | 
					        .contains("Category allowlist and denylist are mutually exclusive");
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @Test
 | 
					  @Test
 | 
				
			||||||
  public void create_failsWithMissingModel() throws Exception {
 | 
					  public void create_failsWithMissingModel() throws Exception {
 | 
				
			||||||
    String nonExistentFile = "/path/to/non/existent/file";
 | 
					    String nonExistentFile = "/path/to/non/existent/file";
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -26,7 +26,6 @@ import com.google.mediapipe.framework.MediaPipeException;
 | 
				
			||||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
 | 
					import com.google.mediapipe.framework.image.BitmapImageBuilder;
 | 
				
			||||||
import com.google.mediapipe.framework.image.MPImage;
 | 
					import com.google.mediapipe.framework.image.MPImage;
 | 
				
			||||||
import com.google.mediapipe.tasks.components.containers.Category;
 | 
					import com.google.mediapipe.tasks.components.containers.Category;
 | 
				
			||||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
 | 
					 | 
				
			||||||
import com.google.mediapipe.tasks.core.BaseOptions;
 | 
					import com.google.mediapipe.tasks.core.BaseOptions;
 | 
				
			||||||
import com.google.mediapipe.tasks.core.TestUtils;
 | 
					import com.google.mediapipe.tasks.core.TestUtils;
 | 
				
			||||||
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
 | 
					import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
 | 
				
			||||||
| 
						 | 
					@ -55,6 +54,37 @@ public class ImageClassifierTest {
 | 
				
			||||||
  @RunWith(AndroidJUnit4.class)
 | 
					  @RunWith(AndroidJUnit4.class)
 | 
				
			||||||
  public static final class General extends ImageClassifierTest {
 | 
					  public static final class General extends ImageClassifierTest {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Test
 | 
				
			||||||
 | 
					    public void options_failsWithNegativeMaxResults() throws Exception {
 | 
				
			||||||
 | 
					      IllegalArgumentException exception =
 | 
				
			||||||
 | 
					          assertThrows(
 | 
				
			||||||
 | 
					              IllegalArgumentException.class,
 | 
				
			||||||
 | 
					              () ->
 | 
				
			||||||
 | 
					                  ImageClassifierOptions.builder()
 | 
				
			||||||
 | 
					                      .setBaseOptions(
 | 
				
			||||||
 | 
					                          BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
				
			||||||
 | 
					                      .setMaxResults(-1)
 | 
				
			||||||
 | 
					                      .build());
 | 
				
			||||||
 | 
					      assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0");
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Test
 | 
				
			||||||
 | 
					    public void options_failsWithBothAllowlistAndDenylist() throws Exception {
 | 
				
			||||||
 | 
					      IllegalArgumentException exception =
 | 
				
			||||||
 | 
					          assertThrows(
 | 
				
			||||||
 | 
					              IllegalArgumentException.class,
 | 
				
			||||||
 | 
					              () ->
 | 
				
			||||||
 | 
					                  ImageClassifierOptions.builder()
 | 
				
			||||||
 | 
					                      .setBaseOptions(
 | 
				
			||||||
 | 
					                          BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
				
			||||||
 | 
					                      .setCategoryAllowlist(Arrays.asList("foo"))
 | 
				
			||||||
 | 
					                      .setCategoryDenylist(Arrays.asList("bar"))
 | 
				
			||||||
 | 
					                      .build());
 | 
				
			||||||
 | 
					      assertThat(exception)
 | 
				
			||||||
 | 
					          .hasMessageThat()
 | 
				
			||||||
 | 
					          .contains("Category allowlist and denylist are mutually exclusive");
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void create_failsWithMissingModel() throws Exception {
 | 
					    public void create_failsWithMissingModel() throws Exception {
 | 
				
			||||||
      String nonExistentFile = "/path/to/non/existent/file";
 | 
					      String nonExistentFile = "/path/to/non/existent/file";
 | 
				
			||||||
| 
						 | 
					@ -105,7 +135,7 @@ public class ImageClassifierTest {
 | 
				
			||||||
      ImageClassifierOptions options =
 | 
					      ImageClassifierOptions options =
 | 
				
			||||||
          ImageClassifierOptions.builder()
 | 
					          ImageClassifierOptions.builder()
 | 
				
			||||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
					              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
				
			||||||
              .setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build())
 | 
					              .setMaxResults(3)
 | 
				
			||||||
              .build();
 | 
					              .build();
 | 
				
			||||||
      ImageClassifier imageClassifier =
 | 
					      ImageClassifier imageClassifier =
 | 
				
			||||||
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
| 
						 | 
					@ -125,7 +155,7 @@ public class ImageClassifierTest {
 | 
				
			||||||
      ImageClassifierOptions options =
 | 
					      ImageClassifierOptions options =
 | 
				
			||||||
          ImageClassifierOptions.builder()
 | 
					          ImageClassifierOptions.builder()
 | 
				
			||||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(QUANTIZED_MODEL_FILE).build())
 | 
					              .setBaseOptions(BaseOptions.builder().setModelAssetPath(QUANTIZED_MODEL_FILE).build())
 | 
				
			||||||
              .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
 | 
					              .setMaxResults(1)
 | 
				
			||||||
              .build();
 | 
					              .build();
 | 
				
			||||||
      ImageClassifier imageClassifier =
 | 
					      ImageClassifier imageClassifier =
 | 
				
			||||||
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
| 
						 | 
					@ -141,7 +171,7 @@ public class ImageClassifierTest {
 | 
				
			||||||
      ImageClassifierOptions options =
 | 
					      ImageClassifierOptions options =
 | 
				
			||||||
          ImageClassifierOptions.builder()
 | 
					          ImageClassifierOptions.builder()
 | 
				
			||||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
					              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
				
			||||||
              .setClassifierOptions(ClassifierOptions.builder().setScoreThreshold(0.02f).build())
 | 
					              .setScoreThreshold(0.02f)
 | 
				
			||||||
              .build();
 | 
					              .build();
 | 
				
			||||||
      ImageClassifier imageClassifier =
 | 
					      ImageClassifier imageClassifier =
 | 
				
			||||||
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
| 
						 | 
					@ -160,10 +190,7 @@ public class ImageClassifierTest {
 | 
				
			||||||
      ImageClassifierOptions options =
 | 
					      ImageClassifierOptions options =
 | 
				
			||||||
          ImageClassifierOptions.builder()
 | 
					          ImageClassifierOptions.builder()
 | 
				
			||||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
					              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
				
			||||||
              .setClassifierOptions(
 | 
					 | 
				
			||||||
                  ClassifierOptions.builder()
 | 
					 | 
				
			||||||
              .setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf"))
 | 
					              .setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf"))
 | 
				
			||||||
                      .build())
 | 
					 | 
				
			||||||
              .build();
 | 
					              .build();
 | 
				
			||||||
      ImageClassifier imageClassifier =
 | 
					      ImageClassifier imageClassifier =
 | 
				
			||||||
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
| 
						 | 
					@ -183,11 +210,8 @@ public class ImageClassifierTest {
 | 
				
			||||||
      ImageClassifierOptions options =
 | 
					      ImageClassifierOptions options =
 | 
				
			||||||
          ImageClassifierOptions.builder()
 | 
					          ImageClassifierOptions.builder()
 | 
				
			||||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
					              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
				
			||||||
              .setClassifierOptions(
 | 
					 | 
				
			||||||
                  ClassifierOptions.builder()
 | 
					 | 
				
			||||||
              .setMaxResults(3)
 | 
					              .setMaxResults(3)
 | 
				
			||||||
              .setCategoryDenylist(Arrays.asList("bagel"))
 | 
					              .setCategoryDenylist(Arrays.asList("bagel"))
 | 
				
			||||||
                      .build())
 | 
					 | 
				
			||||||
              .build();
 | 
					              .build();
 | 
				
			||||||
      ImageClassifier imageClassifier =
 | 
					      ImageClassifier imageClassifier =
 | 
				
			||||||
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
| 
						 | 
					@ -207,7 +231,7 @@ public class ImageClassifierTest {
 | 
				
			||||||
      ImageClassifierOptions options =
 | 
					      ImageClassifierOptions options =
 | 
				
			||||||
          ImageClassifierOptions.builder()
 | 
					          ImageClassifierOptions.builder()
 | 
				
			||||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
					              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
				
			||||||
              .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
 | 
					              .setMaxResults(1)
 | 
				
			||||||
              .build();
 | 
					              .build();
 | 
				
			||||||
      ImageClassifier imageClassifier =
 | 
					      ImageClassifier imageClassifier =
 | 
				
			||||||
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
| 
						 | 
					@ -228,7 +252,7 @@ public class ImageClassifierTest {
 | 
				
			||||||
      ImageClassifierOptions options =
 | 
					      ImageClassifierOptions options =
 | 
				
			||||||
          ImageClassifierOptions.builder()
 | 
					          ImageClassifierOptions.builder()
 | 
				
			||||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
					              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
				
			||||||
              .setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build())
 | 
					              .setMaxResults(3)
 | 
				
			||||||
              .build();
 | 
					              .build();
 | 
				
			||||||
      ImageClassifier imageClassifier =
 | 
					      ImageClassifier imageClassifier =
 | 
				
			||||||
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
| 
						 | 
					@ -251,7 +275,7 @@ public class ImageClassifierTest {
 | 
				
			||||||
      ImageClassifierOptions options =
 | 
					      ImageClassifierOptions options =
 | 
				
			||||||
          ImageClassifierOptions.builder()
 | 
					          ImageClassifierOptions.builder()
 | 
				
			||||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
					              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
				
			||||||
              .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
 | 
					              .setMaxResults(1)
 | 
				
			||||||
              .build();
 | 
					              .build();
 | 
				
			||||||
      ImageClassifier imageClassifier =
 | 
					      ImageClassifier imageClassifier =
 | 
				
			||||||
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
| 
						 | 
					@ -388,7 +412,7 @@ public class ImageClassifierTest {
 | 
				
			||||||
      ImageClassifierOptions options =
 | 
					      ImageClassifierOptions options =
 | 
				
			||||||
          ImageClassifierOptions.builder()
 | 
					          ImageClassifierOptions.builder()
 | 
				
			||||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
					              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
				
			||||||
              .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
 | 
					              .setMaxResults(1)
 | 
				
			||||||
              .build();
 | 
					              .build();
 | 
				
			||||||
      ImageClassifier imageClassifier =
 | 
					      ImageClassifier imageClassifier =
 | 
				
			||||||
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
| 
						 | 
					@ -405,13 +429,14 @@ public class ImageClassifierTest {
 | 
				
			||||||
      ImageClassifierOptions options =
 | 
					      ImageClassifierOptions options =
 | 
				
			||||||
          ImageClassifierOptions.builder()
 | 
					          ImageClassifierOptions.builder()
 | 
				
			||||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
					              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
				
			||||||
              .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
 | 
					              .setMaxResults(1)
 | 
				
			||||||
              .setRunningMode(RunningMode.VIDEO)
 | 
					              .setRunningMode(RunningMode.VIDEO)
 | 
				
			||||||
              .build();
 | 
					              .build();
 | 
				
			||||||
      ImageClassifier imageClassifier =
 | 
					      ImageClassifier imageClassifier =
 | 
				
			||||||
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
      for (int i = 0; i < 3; i++) {
 | 
					      for (int i = 0; i < 3; i++) {
 | 
				
			||||||
        ImageClassifierResult results = imageClassifier.classifyForVideo(image, /*timestampMs=*/ i);
 | 
					        ImageClassifierResult results =
 | 
				
			||||||
 | 
					            imageClassifier.classifyForVideo(image, /* timestampMs= */ i);
 | 
				
			||||||
        assertHasOneHead(results);
 | 
					        assertHasOneHead(results);
 | 
				
			||||||
        assertCategoriesAre(
 | 
					        assertCategoriesAre(
 | 
				
			||||||
            results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
 | 
					            results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
 | 
				
			||||||
| 
						 | 
					@ -424,7 +449,7 @@ public class ImageClassifierTest {
 | 
				
			||||||
      ImageClassifierOptions options =
 | 
					      ImageClassifierOptions options =
 | 
				
			||||||
          ImageClassifierOptions.builder()
 | 
					          ImageClassifierOptions.builder()
 | 
				
			||||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
					              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
				
			||||||
              .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
 | 
					              .setMaxResults(1)
 | 
				
			||||||
              .setRunningMode(RunningMode.LIVE_STREAM)
 | 
					              .setRunningMode(RunningMode.LIVE_STREAM)
 | 
				
			||||||
              .setResultListener(
 | 
					              .setResultListener(
 | 
				
			||||||
                  (imageClassificationResult, inputImage) -> {
 | 
					                  (imageClassificationResult, inputImage) -> {
 | 
				
			||||||
| 
						 | 
					@ -453,7 +478,7 @@ public class ImageClassifierTest {
 | 
				
			||||||
      ImageClassifierOptions options =
 | 
					      ImageClassifierOptions options =
 | 
				
			||||||
          ImageClassifierOptions.builder()
 | 
					          ImageClassifierOptions.builder()
 | 
				
			||||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
					              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
 | 
				
			||||||
              .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
 | 
					              .setMaxResults(1)
 | 
				
			||||||
              .setRunningMode(RunningMode.LIVE_STREAM)
 | 
					              .setRunningMode(RunningMode.LIVE_STREAM)
 | 
				
			||||||
              .setResultListener(
 | 
					              .setResultListener(
 | 
				
			||||||
                  (imageClassificationResult, inputImage) -> {
 | 
					                  (imageClassificationResult, inputImage) -> {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user