entries = new ArrayList<>();
+ for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) {
+ entries.add(classificationEntryFromProto(entry));
+ }
+ return Classifications.create(
+ entries, classifications.getHeadIndex(), classifications.getHeadName());
+ }
+}
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java
new file mode 100644
index 000000000..68cae151f
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java
@@ -0,0 +1,456 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.vision.imageclassifier;
+
+import android.content.Context;
+import android.graphics.RectF;
+import android.os.ParcelFileDescriptor;
+import com.google.auto.value.AutoValue;
+import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
+import com.google.mediapipe.framework.AndroidPacketGetter;
+import com.google.mediapipe.framework.MediaPipeException;
+import com.google.mediapipe.framework.Packet;
+import com.google.mediapipe.framework.PacketGetter;
+import com.google.mediapipe.framework.ProtoUtil;
+import com.google.mediapipe.framework.image.BitmapImageBuilder;
+import com.google.mediapipe.framework.image.Image;
+import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto;
+import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
+import com.google.mediapipe.tasks.core.BaseOptions;
+import com.google.mediapipe.tasks.core.ErrorListener;
+import com.google.mediapipe.tasks.core.OutputHandler;
+import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
+import com.google.mediapipe.tasks.core.TaskInfo;
+import com.google.mediapipe.tasks.core.TaskOptions;
+import com.google.mediapipe.tasks.core.TaskRunner;
+import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
+import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
+import com.google.mediapipe.tasks.vision.core.RunningMode;
+import com.google.mediapipe.tasks.vision.imageclassifier.proto.ImageClassifierGraphOptionsProto;
+import com.google.protobuf.InvalidProtocolBufferException;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Performs classification on images.
+ *
+ * <p>The API expects a TFLite model with optional, but strongly recommended, TFLite Model Metadata.
+ *
+ *
+ * <p>The API supports models with one image input tensor and one or more output tensors. To be more
+ * specific, here are the requirements.
+ *
+ *
+ * - Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
+ *
+ * - image input of size {@code [batch x height x width x channels]}.
+ *
- batch inference is not supported ({@code batch} is required to be 1).
+ *
- only RGB inputs are supported ({@code channels} is required to be 3).
+ *
- if type is kTfLiteFloat32, NormalizationOptions are required to be attached to the
+ * metadata for input normalization.
+ *
+ * - At least one output tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) with:
+ *
+ * - {@code N} classes and either 2 or 4 dimensions, i.e. {@code [1 x N]} or {@code [1 x 1
+ * x 1 x N]}
+ *
- optional (but recommended) label map(s) as AssociatedFile-s with type
+ * TENSOR_AXIS_LABELS, containing one label per line. The first such AssociatedFile (if
+ * any) is used to fill the {@code class_name} field of the results. The {@code
+ * display_name} field is filled from the AssociatedFile (if any) whose locale matches
+ * the {@code display_names_locale} field of the {@code ImageClassifierOptions} used at
+ * creation time ("en" by default, i.e. English). If none of these are available, only
+ * the {@code index} field of the results will be filled.
+ *
- optional score calibration can be attached using ScoreCalibrationOptions and an
+ * AssociatedFile with type TENSOR_AXIS_SCORE_CALIBRATION. See
+ * metadata_schema.fbs for more details.
+ *
+ *
+ *
+ * <p>An example of such model can be found on
+ * TensorFlow Hub.
+ */
+public final class ImageClassifier extends BaseVisionTaskApi {
+ private static final String TAG = ImageClassifier.class.getSimpleName();
+ private static final String IMAGE_IN_STREAM_NAME = "image_in";
+ private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
+ private static final List<String> INPUT_STREAMS =
+ Collections.unmodifiableList(
+ Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
+ private static final List<String> OUTPUT_STREAMS =
+ Collections.unmodifiableList(
+ Arrays.asList("CLASSIFICATION_RESULT:classification_result_out", "IMAGE:image_out"));
+ private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0;
+ private static final int IMAGE_OUT_STREAM_INDEX = 1;
+ private static final String TASK_GRAPH_NAME =
+ "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
+
+ static {
+ ProtoUtil.registerTypeName(
+ ClassificationsProto.ClassificationResult.class,
+ "mediapipe.tasks.components.containers.proto.ClassificationResult");
+ }
+
+ /**
+ * Creates an {@link ImageClassifier} instance from a model file and default {@link
+ * ImageClassifierOptions}.
+ *
+ * @param context an Android {@link Context}.
+ * @param modelPath path to the classification model in the assets.
+ * @throws MediaPipeException if there is an error during {@link ImageClassifier} creation.
+ */
+ public static ImageClassifier createFromFile(Context context, String modelPath) {
+ BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
+ return createFromOptions(
+ context, ImageClassifierOptions.builder().setBaseOptions(baseOptions).build());
+ }
+
+ /**
+ * Creates an {@link ImageClassifier} instance from a model file and default {@link
+ * ImageClassifierOptions}.
+ *
+ * @param context an Android {@link Context}.
+ * @param modelFile the classification model {@link File} instance.
+ * @throws IOException if an I/O error occurs when opening the tflite model file.
+ * @throws MediaPipeException if there is an error during {@link ImageClassifier} creation.
+ */
+ public static ImageClassifier createFromFile(Context context, File modelFile) throws IOException {
+ try (ParcelFileDescriptor descriptor =
+ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+ BaseOptions baseOptions =
+ BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
+ return createFromOptions(
+ context, ImageClassifierOptions.builder().setBaseOptions(baseOptions).build());
+ }
+ }
+
+ /**
+ * Creates an {@link ImageClassifier} instance from a model buffer and default {@link
+ * ImageClassifierOptions}.
+ *
+ * @param context an Android {@link Context}.
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
+ * classification model.
+ * @throws MediaPipeException if there is an error during {@link ImageClassifier} creation.
+ */
+ public static ImageClassifier createFromBuffer(Context context, final ByteBuffer modelBuffer) {
+ BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
+ return createFromOptions(
+ context, ImageClassifierOptions.builder().setBaseOptions(baseOptions).build());
+ }
+
+ /**
+ * Creates an {@link ImageClassifier} instance from an {@link ImageClassifierOptions} instance.
+ *
+ * @param context an Android {@link Context}.
+ * @param options an {@link ImageClassifierOptions} instance.
+ * @throws MediaPipeException if there is an error during {@link ImageClassifier} creation.
+ */
+ public static ImageClassifier createFromOptions(Context context, ImageClassifierOptions options) {
+ OutputHandler<ImageClassificationResult, Image> handler = new OutputHandler<>();
+ handler.setOutputPacketConverter(
+ new OutputHandler.OutputPacketConverter<ImageClassificationResult, Image>() {
+ @Override
+ public ImageClassificationResult convertToTaskResult(List<Packet> packets) {
+ try {
+ return ImageClassificationResult.create(
+ PacketGetter.getProto(
+ packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX),
+ ClassificationsProto.ClassificationResult.getDefaultInstance()),
+ packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp());
+ } catch (InvalidProtocolBufferException e) {
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
+ }
+ }
+
+ @Override
+ public Image convertToTaskInput(List<Packet> packets) {
+ return new BitmapImageBuilder(
+ AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
+ .build();
+ }
+ });
+ options.resultListener().ifPresent(handler::setResultListener);
+ options.errorListener().ifPresent(handler::setErrorListener);
+ TaskRunner runner =
+ TaskRunner.create(
+ context,
+ TaskInfo.builder()
+ .setTaskGraphName(TASK_GRAPH_NAME)
+ .setInputStreams(INPUT_STREAMS)
+ .setOutputStreams(OUTPUT_STREAMS)
+ .setTaskOptions(options)
+ .setEnableFlowLimiting(options.runningMode() == RunningMode.LIVE_STREAM)
+ .build(),
+ handler);
+ return new ImageClassifier(runner, options.runningMode());
+ }
+
+ /**
+ * Constructor to initialize an {@link ImageClassifier} from a {@link TaskRunner} and {@link
+ * RunningMode}.
+ *
+ * @param taskRunner a {@link TaskRunner}.
+ * @param runningMode a mediapipe vision task {@link RunningMode}.
+ */
+ private ImageClassifier(TaskRunner taskRunner, RunningMode runningMode) {
+ super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
+ }
+
+ /**
+ * Performs classification on the provided single image. Only use this method when the {@link
+ * ImageClassifier} is created with {@link RunningMode.IMAGE}.
+ *
+ * {@link ImageClassifier} supports the following color space types:
+ *
+ *
+ * - {@link Bitmap.Config.ARGB_8888}
+ *
+ *
+ * @param inputImage a MediaPipe {@link Image} object for processing.
+ * @throws MediaPipeException if there is an internal error.
+ */
+ public ImageClassificationResult classify(Image inputImage) {
+ return (ImageClassificationResult) processImageData(inputImage, buildFullImageRectF());
+ }
+
+ /**
+ * Performs classification on the provided single image and region-of-interest. Only use this
+ * method when the {@link ImageClassifier} is created with {@link RunningMode.IMAGE}.
+ *
+ * {@link ImageClassifier} supports the following color space types:
+ *
+ *
+ * - {@link Bitmap.Config.ARGB_8888}
+ *
+ *
+ * @param inputImage a MediaPipe {@link Image} object for processing.
+ * @param roi a {@link RectF} specifying the region of interest on which to perform
+ * classification. Coordinates are expected to be specified as normalized values in [0,1].
+ * @throws MediaPipeException if there is an internal error.
+ */
+ public ImageClassificationResult classify(Image inputImage, RectF roi) {
+ return (ImageClassificationResult) processImageData(inputImage, roi);
+ }
+
+ /**
+ * Performs classification on the provided video frame. Only use this method when the {@link
+ * ImageClassifier} is created with {@link RunningMode.VIDEO}.
+ *
+ * It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
+ * must be monotonically increasing.
+ *
+ *
+ * <p>{@link ImageClassifier} supports the following color space types:
+ *
+ *
+ * - {@link Bitmap.Config.ARGB_8888}
+ *
+ *
+ * @param inputImage a MediaPipe {@link Image} object for processing.
+ * @param inputTimestampMs the input timestamp (in milliseconds).
+ * @throws MediaPipeException if there is an internal error.
+ */
+ public ImageClassificationResult classifyForVideo(Image inputImage, long inputTimestampMs) {
+ return (ImageClassificationResult)
+ processVideoData(inputImage, buildFullImageRectF(), inputTimestampMs);
+ }
+
+ /**
+ * Performs classification on the provided video frame with additional region-of-interest. Only
+ * use this method when the {@link ImageClassifier} is created with {@link RunningMode.VIDEO}.
+ *
+ * It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
+ * must be monotonically increasing.
+ *
+ *
+ * <p>{@link ImageClassifier} supports the following color space types:
+ *
+ *
+ * - {@link Bitmap.Config.ARGB_8888}
+ *
+ *
+ * @param inputImage a MediaPipe {@link Image} object for processing.
+ * @param roi a {@link RectF} specifying the region of interest on which to perform
+ * classification. Coordinates are expected to be specified as normalized values in [0,1].
+ * @param inputTimestampMs the input timestamp (in milliseconds).
+ * @throws MediaPipeException if there is an internal error.
+ */
+ public ImageClassificationResult classifyForVideo(
+ Image inputImage, RectF roi, long inputTimestampMs) {
+ return (ImageClassificationResult) processVideoData(inputImage, roi, inputTimestampMs);
+ }
+
+ /**
+ * Sends live image data to perform classification, and the results will be available via the
+ * {@link ResultListener} provided in the {@link ImageClassifierOptions}. Only use this method
+ * when the {@link ImageClassifier} is created with {@link RunningMode.LIVE_STREAM}.
+ *
+ * It's required to provide a timestamp (in milliseconds) to indicate when the input image is
+ * sent to the image classifier. The input timestamps must be monotonically increasing.
+ *
+ *
+ * <p>{@link ImageClassifier} supports the following color space types:
+ *
+ *
+ * - {@link Bitmap.Config.ARGB_8888}
+ *
+ *
+ * @param inputImage a MediaPipe {@link Image} object for processing.
+ * @param inputTimestampMs the input timestamp (in milliseconds).
+ * @throws MediaPipeException if there is an internal error.
+ */
+ public void classifyAsync(Image inputImage, long inputTimestampMs) {
+ sendLiveStreamData(inputImage, buildFullImageRectF(), inputTimestampMs);
+ }
+
+ /**
+ * Sends live image data and additional region-of-interest to perform classification, and the
+ * results will be available via the {@link ResultListener} provided in the {@link
+ * ImageClassifierOptions}. Only use this method when the {@link ImageClassifier} is created with
+ * {@link RunningMode.LIVE_STREAM}.
+ *
+ * It's required to provide a timestamp (in milliseconds) to indicate when the input image is
+ * sent to the image classifier. The input timestamps must be monotonically increasing.
+ *
+ *
+ * <p>{@link ImageClassifier} supports the following color space types:
+ *
+ *
+ * - {@link Bitmap.Config.ARGB_8888}
+ *
+ *
+ * @param inputImage a MediaPipe {@link Image} object for processing.
+ * @param roi a {@link RectF} specifying the region of interest on which to perform
+ * classification. Coordinates are expected to be specified as normalized values in [0,1].
+ * @param inputTimestampMs the input timestamp (in milliseconds).
+ * @throws MediaPipeException if there is an internal error.
+ */
+ public void classifyAsync(Image inputImage, RectF roi, long inputTimestampMs) {
+ sendLiveStreamData(inputImage, roi, inputTimestampMs);
+ }
+
+ /** Options for setting up an {@link ImageClassifier}. */
+ @AutoValue
+ public abstract static class ImageClassifierOptions extends TaskOptions {
+
+ /** Builder for {@link ImageClassifierOptions}. */
+ @AutoValue.Builder
+ public abstract static class Builder {
+ /** Sets the {@link BaseOptions} for the image classifier task. */
+ public abstract Builder setBaseOptions(BaseOptions baseOptions);
+
+ /**
+ * Sets the {@link RunningMode} for the image classifier task. Default to the image mode.
+ * Image classifier has three modes:
+ *
+ *
+ * - IMAGE: The mode for performing classification on single image inputs.
+ *
+ * <li>VIDEO: The mode for performing classification on the decoded frames of a video.
+ *
+ * <li>LIVE_STREAM: The mode for performing classification on a live stream of input
+ * data, such as from camera. In this mode, {@code setResultListener} must be called to
+ * set up a listener to receive the classification results asynchronously.
+ *
+ */
+ public abstract Builder setRunningMode(RunningMode runningMode);
+
+ /**
+ * Sets the optional {@link ClassifierOptions} controlling classification behavior, such as
+ * score threshold, number of results, etc.
+ */
+ public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
+
+ /**
+ * Sets the {@link ResultListener} to receive the classification results asynchronously when
+ * the image classifier is in the live stream mode.
+ */
+ public abstract Builder setResultListener(
+ ResultListener<ImageClassificationResult, Image> resultListener);
+
+ /** Sets an optional {@link ErrorListener}. */
+ public abstract Builder setErrorListener(ErrorListener errorListener);
+
+ abstract ImageClassifierOptions autoBuild();
+
+ /**
+ * Validates and builds the {@link ImageClassifierOptions} instance.
+ *
+ * @throws IllegalArgumentException if the result listener and the running mode are not
+ * properly configured. The result listener should only be set when the image classifier
+ * is in the live stream mode.
+ */
+ public final ImageClassifierOptions build() {
+ ImageClassifierOptions options = autoBuild();
+ if (options.runningMode() == RunningMode.LIVE_STREAM) {
+ if (!options.resultListener().isPresent()) {
+ throw new IllegalArgumentException(
+ "The image classifier is in the live stream mode, a user-defined result listener"
+ + " must be provided in the ImageClassifierOptions.");
+ }
+ } else if (options.resultListener().isPresent()) {
+ throw new IllegalArgumentException(
+ "The image classifier is in the image or video mode, a user-defined result listener"
+ + " shouldn't be provided in ImageClassifierOptions.");
+ }
+ return options;
+ }
+ }
+
+ abstract BaseOptions baseOptions();
+
+ abstract RunningMode runningMode();
+
+ abstract Optional<ClassifierOptions> classifierOptions();
+
+ abstract Optional<ResultListener<ImageClassificationResult, Image>> resultListener();
+
+ abstract Optional<ErrorListener> errorListener();
+
+ public static Builder builder() {
+ return new AutoValue_ImageClassifier_ImageClassifierOptions.Builder()
+ .setRunningMode(RunningMode.IMAGE);
+ }
+
+ /**
+ * Converts a {@link ImageClassifierOptions} to a {@link CalculatorOptions} protobuf message.
+ */
+ @Override
+ public CalculatorOptions convertToCalculatorOptionsProto() {
+ BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
+ BaseOptionsProto.BaseOptions.newBuilder();
+ baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE);
+ baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
+ ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.Builder taskOptionsBuilder =
+ ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.newBuilder()
+ .setBaseOptions(baseOptionsBuilder);
+ if (classifierOptions().isPresent()) {
+ taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
+ }
+ return CalculatorOptions.newBuilder()
+ .setExtension(
+ ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.ext,
+ taskOptionsBuilder.build())
+ .build();
+ }
+ }
+
+ /** Creates a RectF covering the full image. */
+ private static RectF buildFullImageRectF() {
+ return new RectF(0, 0, 1, 1);
+ }
+}
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/AndroidManifest.xml
new file mode 100644
index 000000000..66fa20509
--- /dev/null
+++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/AndroidManifest.xml
@@ -0,0 +1,24 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/BUILD
new file mode 100644
index 000000000..a7f804c64
--- /dev/null
+++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/BUILD
@@ -0,0 +1,19 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+# TODO: Enable this in OSS
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java
new file mode 100644
index 000000000..e02e8ebe7
--- /dev/null
+++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java
@@ -0,0 +1,445 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.vision.imageclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertThrows;
+
+import android.content.res.AssetManager;
+import android.graphics.BitmapFactory;
+import android.graphics.RectF;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import com.google.mediapipe.framework.MediaPipeException;
+import com.google.mediapipe.framework.image.BitmapImageBuilder;
+import com.google.mediapipe.framework.image.Image;
+import com.google.mediapipe.tasks.components.containers.Category;
+import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
+import com.google.mediapipe.tasks.core.BaseOptions;
+import com.google.mediapipe.tasks.core.TestUtils;
+import com.google.mediapipe.tasks.vision.core.RunningMode;
+import com.google.mediapipe.tasks.vision.imageclassifier.ImageClassifier.ImageClassifierOptions;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.List;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Suite;
+import org.junit.runners.Suite.SuiteClasses;
+
+/** Test for {@link ImageClassifier}. */
+@RunWith(Suite.class)
+@SuiteClasses({ImageClassifierTest.General.class, ImageClassifierTest.RunningModeTest.class})
+public class ImageClassifierTest {
+ private static final String FLOAT_MODEL_FILE = "mobilenet_v2_1.0_224.tflite";
+ private static final String QUANTIZED_MODEL_FILE = "mobilenet_v1_0.25_224_quant.tflite";
+ private static final String BURGER_IMAGE = "burger.jpg";
+ private static final String MULTI_OBJECTS_IMAGE = "multi_objects.jpg";
+
+ @RunWith(AndroidJUnit4.class)
+ public static final class General extends ImageClassifierTest {
+
+ @Test
+ public void create_failsWithMissingModel() throws Exception {
+ String nonExistentFile = "/path/to/non/existent/file";
+ MediaPipeException exception =
+ assertThrows(
+ MediaPipeException.class,
+ () ->
+ ImageClassifier.createFromFile(
+ ApplicationProvider.getApplicationContext(), nonExistentFile));
+ assertThat(exception).hasMessageThat().contains(nonExistentFile);
+ }
+
+ @Test
+ public void create_failsWithInvalidModelBuffer() throws Exception {
+ // Create a non-direct model ByteBuffer.
+ ByteBuffer modelBuffer =
+ TestUtils.loadToNonDirectByteBuffer(
+ ApplicationProvider.getApplicationContext(), FLOAT_MODEL_FILE);
+
+ IllegalArgumentException exception =
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ ImageClassifier.createFromBuffer(
+ ApplicationProvider.getApplicationContext(), modelBuffer));
+
+ assertThat(exception)
+ .hasMessageThat()
+ .contains("The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
+ }
+
+ @Test
+ public void classify_succeedsWithNoOptions() throws Exception {
+ ImageClassifier imageClassifier =
+ ImageClassifier.createFromFile(
+ ApplicationProvider.getApplicationContext(), FLOAT_MODEL_FILE);
+ ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
+
+ assertHasOneHeadAndOneTimestamp(results, 0);
+ assertThat(results.classifications().get(0).entries().get(0).categories()).hasSize(1001);
+ assertThat(results.classifications().get(0).entries().get(0).categories().get(0))
+ .isEqualTo(Category.create(0.7952058f, 934, "cheeseburger", ""));
+ }
+
+ @Test
+ public void classify_succeedsWithFloatModel() throws Exception {
+ ImageClassifierOptions options =
+ ImageClassifierOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
+ .setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build())
+ .build();
+ ImageClassifier imageClassifier =
+ ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
+
+ assertHasOneHeadAndOneTimestamp(results, 0);
+ assertCategoriesAre(
+ results,
+ Arrays.asList(
+ Category.create(0.7952058f, 934, "cheeseburger", ""),
+ Category.create(0.027329788f, 932, "bagel", ""),
+ Category.create(0.019334773f, 925, "guacamole", "")));
+ }
+
+ @Test
+ public void classify_succeedsWithQuantizedModel() throws Exception {
+ ImageClassifierOptions options =
+ ImageClassifierOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(QUANTIZED_MODEL_FILE).build())
+ .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
+ .build();
+ ImageClassifier imageClassifier =
+ ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
+
+ assertHasOneHeadAndOneTimestamp(results, 0);
+ assertCategoriesAre(
+ results, Arrays.asList(Category.create(0.97265625f, 934, "cheeseburger", "")));
+ }
+
+ @Test
+ public void classify_succeedsWithScoreThreshold() throws Exception {
+ ImageClassifierOptions options =
+ ImageClassifierOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
+ .setClassifierOptions(ClassifierOptions.builder().setScoreThreshold(0.02f).build())
+ .build();
+ ImageClassifier imageClassifier =
+ ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
+
+ assertHasOneHeadAndOneTimestamp(results, 0);
+ assertCategoriesAre(
+ results,
+ Arrays.asList(
+ Category.create(0.7952058f, 934, "cheeseburger", ""),
+ Category.create(0.027329788f, 932, "bagel", "")));
+ }
+
+ @Test
+ public void classify_succeedsWithAllowlist() throws Exception {
+ ImageClassifierOptions options =
+ ImageClassifierOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
+ .setClassifierOptions(
+ ClassifierOptions.builder()
+ .setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf"))
+ .build())
+ .build();
+ ImageClassifier imageClassifier =
+ ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
+
+ assertHasOneHeadAndOneTimestamp(results, 0);
+ assertCategoriesAre(
+ results,
+ Arrays.asList(
+ Category.create(0.7952058f, 934, "cheeseburger", ""),
+ Category.create(0.019334773f, 925, "guacamole", ""),
+ Category.create(0.006279315f, 963, "meat loaf", "")));
+ }
+
+ @Test
+ public void classify_succeedsWithDenylist() throws Exception {
+ ImageClassifierOptions options =
+ ImageClassifierOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
+ .setClassifierOptions(
+ ClassifierOptions.builder()
+ .setMaxResults(3)
+ .setCategoryDenylist(Arrays.asList("bagel"))
+ .build())
+ .build();
+ ImageClassifier imageClassifier =
+ ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
+
+ assertHasOneHeadAndOneTimestamp(results, 0);
+ assertCategoriesAre(
+ results,
+ Arrays.asList(
+ Category.create(0.7952058f, 934, "cheeseburger", ""),
+ Category.create(0.019334773f, 925, "guacamole", ""),
+ Category.create(0.006279315f, 963, "meat loaf", "")));
+ }
+
+ @Test
+ public void classify_succeedsWithRegionOfInterest() throws Exception {
+ ImageClassifierOptions options =
+ ImageClassifierOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
+ .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
+ .build();
+ ImageClassifier imageClassifier =
+ ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ // RectF around the soccer ball.
+ RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f);
+ ImageClassificationResult results =
+ imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), roi);
+
+ assertHasOneHeadAndOneTimestamp(results, 0);
+ assertCategoriesAre(
+ results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", "")));
+ }
+ }
+
+ @RunWith(AndroidJUnit4.class)
+ public static final class RunningModeTest extends ImageClassifierTest {
+
+ @Test
+ public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception {
+ for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) {
+ IllegalArgumentException exception =
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ ImageClassifierOptions.builder()
+ .setBaseOptions(
+ BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
+ .setRunningMode(mode)
+ .setResultListener((imageClassificationResult, inputImage) -> {})
+ .build());
+ assertThat(exception)
+ .hasMessageThat()
+ .contains("a user-defined result listener shouldn't be provided");
+ }
+ }
+
+ @Test
+ public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception {
+ IllegalArgumentException exception =
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ ImageClassifierOptions.builder()
+ .setBaseOptions(
+ BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
+ .setRunningMode(RunningMode.LIVE_STREAM)
+ .build());
+ assertThat(exception)
+ .hasMessageThat()
+ .contains("a user-defined result listener must be provided");
+ }
+
+ @Test
+ public void classify_failsWithCallingWrongApiInImageMode() throws Exception {
+ ImageClassifierOptions options =
+ ImageClassifierOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
+ .setRunningMode(RunningMode.IMAGE)
+ .build();
+
+ ImageClassifier imageClassifier =
+ ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ MediaPipeException exception =
+ assertThrows(
+ MediaPipeException.class,
+ () -> imageClassifier.classifyForVideo(getImageFromAsset(BURGER_IMAGE), 0));
+ assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
+ exception =
+ assertThrows(
+ MediaPipeException.class,
+ () -> imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 0));
+ assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
+ }
+
+ @Test
+ public void classify_failsWithCallingWrongApiInVideoMode() throws Exception {
+ ImageClassifierOptions options =
+ ImageClassifierOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
+ .setRunningMode(RunningMode.VIDEO)
+ .build();
+
+ ImageClassifier imageClassifier =
+ ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ MediaPipeException exception =
+ assertThrows(
+ MediaPipeException.class,
+ () -> imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)));
+ assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
+ exception =
+ assertThrows(
+ MediaPipeException.class,
+ () -> imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 0));
+ assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
+ }
+
+ @Test
+ public void classify_failsWithCallingWrongApiInLiveSteamMode() throws Exception {
+ ImageClassifierOptions options =
+ ImageClassifierOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
+ .setRunningMode(RunningMode.LIVE_STREAM)
+ .setResultListener((imageClassificationResult, inputImage) -> {})
+ .build();
+
+ ImageClassifier imageClassifier =
+ ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+
+ MediaPipeException exception =
+ assertThrows(
+ MediaPipeException.class,
+ () -> imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)));
+ assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
+ exception =
+ assertThrows(
+ MediaPipeException.class,
+ () -> imageClassifier.classifyForVideo(getImageFromAsset(BURGER_IMAGE), 0));
+ assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
+ }
+
+ @Test
+ public void classify_succeedsWithImageMode() throws Exception {
+ ImageClassifierOptions options =
+ ImageClassifierOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
+ .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
+ .build();
+ ImageClassifier imageClassifier =
+ ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
+
+ assertHasOneHeadAndOneTimestamp(results, 0);
+ assertCategoriesAre(
+ results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
+ }
+
+ @Test
+ public void classify_succeedsWithVideoMode() throws Exception {
+ Image image = getImageFromAsset(BURGER_IMAGE);
+ ImageClassifierOptions options =
+ ImageClassifierOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
+ .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
+ .setRunningMode(RunningMode.VIDEO)
+ .build();
+ ImageClassifier imageClassifier =
+ ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ for (int i = 0; i < 3; i++) {
+ ImageClassificationResult results = imageClassifier.classifyForVideo(image, i);
+ assertHasOneHeadAndOneTimestamp(results, i);
+ assertCategoriesAre(
+ results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
+ }
+ }
+
+ @Test
+ public void classify_failsWithOutOfOrderInputTimestamps() throws Exception {
+ Image image = getImageFromAsset(BURGER_IMAGE);
+ ImageClassifierOptions options =
+ ImageClassifierOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
+ .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
+ .setRunningMode(RunningMode.LIVE_STREAM)
+ .setResultListener(
+ (imageClassificationResult, inputImage) -> {
+ assertCategoriesAre(
+ imageClassificationResult,
+ Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
+ assertImageSizeIsExpected(inputImage);
+ })
+ .build();
+ try (ImageClassifier imageClassifier =
+ ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
+ imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 1);
+ MediaPipeException exception =
+ assertThrows(MediaPipeException.class, () -> imageClassifier.classifyAsync(image, 0));
+ assertThat(exception)
+ .hasMessageThat()
+ .contains("having a smaller timestamp than the processed timestamp");
+ }
+ }
+
+ @Test
+ public void classify_succeedsWithLiveStreamMode() throws Exception {
+ Image image = getImageFromAsset(BURGER_IMAGE);
+ ImageClassifierOptions options =
+ ImageClassifierOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
+ .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
+ .setRunningMode(RunningMode.LIVE_STREAM)
+ .setResultListener(
+ (imageClassificationResult, inputImage) -> {
+ assertCategoriesAre(
+ imageClassificationResult,
+ Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
+ assertImageSizeIsExpected(inputImage);
+ })
+ .build();
+ try (ImageClassifier imageClassifier =
+ ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
+ for (int i = 0; i < 3; ++i) {
+ imageClassifier.classifyAsync(image, i);
+ }
+ }
+ }
+ }
+
+ private static Image getImageFromAsset(String filePath) throws Exception {
+ AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
+ InputStream istr = assetManager.open(filePath);
+ return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
+ }
+
  // Verifies that the result has exactly one classification head — index 0, named "probability" —
  // with a single entry whose timestamp equals the provided value. The head name "probability"
  // presumably comes from the test model's output tensor metadata — TODO confirm against the model.
  private static void assertHasOneHeadAndOneTimestamp(
      ImageClassificationResult results, long timestampMs) {
    assertThat(results.classifications()).hasSize(1);
    assertThat(results.classifications().get(0).headIndex()).isEqualTo(0);
    assertThat(results.classifications().get(0).headName()).isEqualTo("probability");
    assertThat(results.classifications().get(0).entries()).hasSize(1);
    assertThat(results.classifications().get(0).entries().get(0).timestampMs())
        .isEqualTo(timestampMs);
  }
+
+ private static void assertCategoriesAre(
+ ImageClassificationResult results, List categories) {
+ assertThat(results.classifications().get(0).entries().get(0).categories())
+ .hasSize(categories.size());
+ for (int i = 0; i < categories.size(); i++) {
+ assertThat(results.classifications().get(0).entries().get(0).categories().get(i))
+ .isEqualTo(categories.get(i));
+ }
+ }
+
+ private static void assertImageSizeIsExpected(Image inputImage) {
+ assertThat(inputImage).isNotNull();
+ assertThat(inputImage.getWidth()).isEqualTo(480);
+ assertThat(inputImage.getHeight()).isEqualTo(325);
+ }
+}