From 51a760608380f895570304b6ad4014ce01a94046 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 12 Oct 2022 10:19:45 -0700 Subject: [PATCH] Add Java ImageClassifier API. PiperOrigin-RevId: 480656683 --- .../imageclassifier/AndroidManifest.xml | 8 + .../tasks/vision/imageclassifier/BUILD | 46 ++ .../ImageClassificationResult.java | 102 ++++ .../imageclassifier/ImageClassifier.java | 456 ++++++++++++++++++ .../imageclassifier/AndroidManifest.xml | 24 + .../tasks/vision/imageclassifier/BUILD | 19 + .../imageclassifier/ImageClassifierTest.java | 445 +++++++++++++++++ 7 files changed, 1100 insertions(+) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/AndroidManifest.xml create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/BUILD create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/AndroidManifest.xml create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/BUILD create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/AndroidManifest.xml new file mode 100644 index 000000000..e257ddc42 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/BUILD new file mode 100644 index 000000000..cecd9f521 --- 
/dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/BUILD @@ -0,0 +1,46 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +android_library( + name = "imageclassifier", + srcs = [ + "ImageClassificationResult.java", + "ImageClassifier.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = ":AndroidManifest.xml", + deps = [ + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite", + "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", + 
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java new file mode 100644 index 000000000..09f854caa --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java @@ -0,0 +1,102 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.imageclassifier; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.tasks.components.container.proto.CategoryProto; +import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto; +import com.google.mediapipe.tasks.components.containers.Category; +import com.google.mediapipe.tasks.components.containers.ClassificationEntry; +import com.google.mediapipe.tasks.components.containers.Classifications; +import com.google.mediapipe.tasks.core.TaskResult; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** Represents the classification results generated by {@link ImageClassifier}. 
*/ +@AutoValue +public abstract class ImageClassificationResult implements TaskResult { + + /** + * Creates an {@link ImageClassificationResult} instance from a {@link + * ClassificationsProto.ClassificationResult} protobuf message. + * + * @param classificationResult a {@link ClassificationsProto.ClassificationResult} protobuf + * message. + * @param timestampMs a timestamp for this result. + */ + // TODO: consolidate output formats across platforms. + static ImageClassificationResult create( + ClassificationsProto.ClassificationResult classificationResult, long timestampMs) { + List classifications = new ArrayList<>(); + for (ClassificationsProto.Classifications classificationsProto : + classificationResult.getClassificationsList()) { + classifications.add(classificationsFromProto(classificationsProto)); + } + return new AutoValue_ImageClassificationResult( + timestampMs, Collections.unmodifiableList(classifications)); + } + + @Override + public abstract long timestampMs(); + + /** Contains one set of results per classifier head. */ + public abstract List classifications(); + + /** + * Converts a {@link CategoryProto.Category} protobuf message to a {@link Category} object. + * + * @param category the {@link CategoryProto.Category} protobuf message to convert. + */ + static Category categoryFromProto(CategoryProto.Category category) { + return Category.create( + category.getScore(), + category.getIndex(), + category.getCategoryName(), + category.getDisplayName()); + } + + /** + * Converts a {@link ClassificationsProto.ClassificationEntry} protobuf message to a {@link + * ClassificationEntry} object. + * + * @param entry the {@link ClassificationsProto.ClassificationEntry} protobuf message to convert. 
+ */ + static ClassificationEntry classificationEntryFromProto( + ClassificationsProto.ClassificationEntry entry) { + List categories = new ArrayList<>(); + for (CategoryProto.Category category : entry.getCategoriesList()) { + categories.add(categoryFromProto(category)); + } + return ClassificationEntry.create(categories, entry.getTimestampMs()); + } + + /** + * Converts a {@link ClassificationsProto.Classifications} protobuf message to a {@link + * Classifications} object. + * + * @param classifications the {@link ClassificationsProto.Classifications} protobuf message to + * convert. + */ + static Classifications classificationsFromProto( + ClassificationsProto.Classifications classifications) { + List entries = new ArrayList<>(); + for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) { + entries.add(classificationEntryFromProto(entry)); + } + return Classifications.create( + entries, classifications.getHeadIndex(), classifications.getHeadName()); + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java new file mode 100644 index 000000000..68cae151f --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java @@ -0,0 +1,456 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.imageclassifier; + +import android.content.Context; +import android.graphics.RectF; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.ProtoUtil; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto; +import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.ErrorListener; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.imageclassifier.proto.ImageClassifierGraphOptionsProto; +import com.google.protobuf.InvalidProtocolBufferException; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Performs classification on images. + * + *

The API expects a TFLite model with optional, but strongly recommended, TFLite Model Metadata. + * + *

The API supports models with one image input tensor and one or more output tensors. To be more + * specific, here are the requirements. + * + *

    + *
  • Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) + *
      + *
    • image input of size {@code [batch x height x width x channels]}. + *
    • batch inference is not supported ({@code batch} is required to be 1). + *
    • only RGB inputs are supported ({@code channels} is required to be 3). + *
    • if type is kTfLiteFloat32, NormalizationOptions are required to be attached to the + * metadata for input normalization. + *
    + *
  • At least one output tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) with: + *
      + *
    • {@code N} classes and either 2 or 4 dimensions, i.e. {@code [1 x N]} or {@code [1 x 1 + * x 1 x N]} + *
    • optional (but recommended) label map(s) as AssociatedFile-s with type + * TENSOR_AXIS_LABELS, containing one label per line. The first such AssociatedFile (if + * any) is used to fill the {@code category_name} field of the results. The {@code + * display_name} field is filled from the AssociatedFile (if any) whose locale matches + * the {@code display_names_locale} field of the {@code ImageClassifierOptions} used at + * creation time ("en" by default, i.e. English). If none of these are available, only + * the {@code index} field of the results will be filled. + *
    • optional score calibration can be attached using ScoreCalibrationOptions and an + * AssociatedFile with type TENSOR_AXIS_SCORE_CALIBRATION. See + * metadata_schema.fbs for more details. + *
    + *
+ * + *

An example of such model can be found + * TensorFlow Hub. + */ +public final class ImageClassifier extends BaseVisionTaskApi { + private static final String TAG = ImageClassifier.class.getSimpleName(); + private static final String IMAGE_IN_STREAM_NAME = "image_in"; + private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in"; + private static final List INPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); + private static final List OUTPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList("CLASSIFICATION_RESULT:classification_result_out", "IMAGE:image_out")); + private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0; + private static final int IMAGE_OUT_STREAM_INDEX = 1; + private static final String TASK_GRAPH_NAME = + "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; + + static { + ProtoUtil.registerTypeName( + ClassificationsProto.ClassificationResult.class, + "mediapipe.tasks.components.containers.proto.ClassificationResult"); + } + + /** + * Creates an {@link ImageClassifier} instance from a model file and default {@link + * ImageClassifierOptions}. + * + * @param context an Android {@link Context}. + * @param modelPath path to the classification model in the assets. + * @throws MediaPipeException if there is an error during {@link ImageClassifier} creation. + */ + public static ImageClassifier createFromFile(Context context, String modelPath) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build(); + return createFromOptions( + context, ImageClassifierOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates an {@link ImageClassifier} instance from a model file and default {@link + * ImageClassifierOptions}. + * + * @param context an Android {@link Context}. + * @param modelFile the classification model {@link File} instance. 
+ * @throws IOException if an I/O error occurs when opening the tflite model file. + * @throws MediaPipeException if there is an error during {@link ImageClassifier} creation. + */ + public static ImageClassifier createFromFile(Context context, File modelFile) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + BaseOptions baseOptions = + BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build(); + return createFromOptions( + context, ImageClassifierOptions.builder().setBaseOptions(baseOptions).build()); + } + } + + /** + * Creates an {@link ImageClassifier} instance from a model buffer and default {@link + * ImageClassifierOptions}. + * + * @param context an Android {@link Context}. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the + * classification model. + * @throws MediaPipeException if there is an error during {@link ImageClassifier} creation. + */ + public static ImageClassifier createFromBuffer(Context context, final ByteBuffer modelBuffer) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build(); + return createFromOptions( + context, ImageClassifierOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates an {@link ImageClassifier} instance from an {@link ImageClassifierOptions} instance. + * + * @param context an Android {@link Context}. + * @param options an {@link ImageClassifierOptions} instance. + * @throws MediaPipeException if there is an error during {@link ImageClassifier} creation. 
+ */ + public static ImageClassifier createFromOptions(Context context, ImageClassifierOptions options) { + OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public ImageClassificationResult convertToTaskResult(List packets) { + try { + return ImageClassificationResult.create( + PacketGetter.getProto( + packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX), + ClassificationsProto.ClassificationResult.getDefaultInstance()), + packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp()); + } catch (InvalidProtocolBufferException e) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage()); + } + } + + @Override + public Image convertToTaskInput(List packets) { + return new BitmapImageBuilder( + AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) + .build(); + } + }); + options.resultListener().ifPresent(handler::setResultListener); + options.errorListener().ifPresent(handler::setErrorListener); + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(INPUT_STREAMS) + .setOutputStreams(OUTPUT_STREAMS) + .setTaskOptions(options) + .setEnableFlowLimiting(options.runningMode() == RunningMode.LIVE_STREAM) + .build(), + handler); + return new ImageClassifier(runner, options.runningMode()); + } + + /** + * Constructor to initialize an {@link ImageClassifier} from a {@link TaskRunner} and {@link + * RunningMode}. + * + * @param taskRunner a {@link TaskRunner}. + * @param runningMode a mediapipe vision task {@link RunningMode}. + */ + private ImageClassifier(TaskRunner taskRunner, RunningMode runningMode) { + super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); + } + + /** + * Performs classification on the provided single image. 
Only use this method when the {@link + * ImageClassifier} is created with {@link RunningMode.IMAGE}. + * + *

{@link ImageClassifier} supports the following color space types: + * + *

    + *
  • {@link android.graphics.Bitmap.Config#ARGB_8888} + *
+ * + * @param inputImage a MediaPipe {@link Image} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public ImageClassificationResult classify(Image inputImage) { + return (ImageClassificationResult) processImageData(inputImage, buildFullImageRectF()); + } + + /** + * Performs classification on the provided single image and region-of-interest. Only use this + * method when the {@link ImageClassifier} is created with {@link RunningMode.IMAGE}. + * + *

{@link ImageClassifier} supports the following color space types: + * + *

    + *
  • {@link android.graphics.Bitmap.Config#ARGB_8888} + *
+ * + * @param inputImage a MediaPipe {@link Image} object for processing. + * @param roi a {@link RectF} specifying the region of interest on which to perform + * classification. Coordinates are expected to be specified as normalized values in [0,1]. + * @throws MediaPipeException if there is an internal error. + */ + public ImageClassificationResult classify(Image inputImage, RectF roi) { + return (ImageClassificationResult) processImageData(inputImage, roi); + } + + /** + * Performs classification on the provided video frame. Only use this method when the {@link + * ImageClassifier} is created with {@link RunningMode.VIDEO}. + * + *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link ImageClassifier} supports the following color space types: + * + *

    + *
  • {@link android.graphics.Bitmap.Config#ARGB_8888} + *
+ * + * @param inputImage a MediaPipe {@link Image} object for processing. + * @param inputTimestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public ImageClassificationResult classifyForVideo(Image inputImage, long inputTimestampMs) { + return (ImageClassificationResult) + processVideoData(inputImage, buildFullImageRectF(), inputTimestampMs); + } + + /** + * Performs classification on the provided video frame with additional region-of-interest. Only + * use this method when the {@link ImageClassifier} is created with {@link RunningMode.VIDEO}. + * + *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link ImageClassifier} supports the following color space types: + * + *

    + *
  • {@link android.graphics.Bitmap.Config#ARGB_8888} + *
+ * + * @param inputImage a MediaPipe {@link Image} object for processing. + * @param roi a {@link RectF} specifying the region of interest on which to perform + * classification. Coordinates are expected to be specified as normalized values in [0,1]. + * @param inputTimestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public ImageClassificationResult classifyForVideo( + Image inputImage, RectF roi, long inputTimestampMs) { + return (ImageClassificationResult) processVideoData(inputImage, roi, inputTimestampMs); + } + + /** + * Sends live image data to perform classification, and the results will be available via the + * {@link ResultListener} provided in the {@link ImageClassifierOptions}. Only use this method + * when the {@link ImageClassifier} is created with {@link RunningMode.LIVE_STREAM}. + * + *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the image classifier. The input timestamps must be monotonically increasing. + * + *

{@link ImageClassifier} supports the following color space types: + * + *

    + *
  • {@link android.graphics.Bitmap.Config#ARGB_8888} + *
+ * + * @param inputImage a MediaPipe {@link Image} object for processing. + * @param inputTimestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void classifyAsync(Image inputImage, long inputTimestampMs) { + sendLiveStreamData(inputImage, buildFullImageRectF(), inputTimestampMs); + } + + /** + * Sends live image data and additional region-of-interest to perform classification, and the + * results will be available via the {@link ResultListener} provided in the {@link + * ImageClassifierOptions}. Only use this method when the {@link ImageClassifier} is created with + * {@link RunningMode.LIVE_STREAM}. + * + *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the image classifier. The input timestamps must be monotonically increasing. + * + *

{@link ImageClassifier} supports the following color space types: + * + *

    + *
  • {@link android.graphics.Bitmap.Config#ARGB_8888} + *
+ * + * @param inputImage a MediaPipe {@link Image} object for processing. + * @param roi a {@link RectF} specifying the region of interest on which to perform + * classification. Coordinates are expected to be specified as normalized values in [0,1]. + * @param inputTimestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void classifyAsync(Image inputImage, RectF roi, long inputTimestampMs) { + sendLiveStreamData(inputImage, roi, inputTimestampMs); + } + + /** Options for setting up and {@link ImageClassifier}. */ + @AutoValue + public abstract static class ImageClassifierOptions extends TaskOptions { + + /** Builder for {@link ImageClassifierOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the {@link BaseOptions} for the image classifier task. */ + public abstract Builder setBaseOptions(BaseOptions baseOptions); + + /** + * Sets the {@link RunningMode} for the image classifier task. Default to the image mode. + * Image classifier has three modes: + * + *
    + *
  • IMAGE: The mode for performing classification on single image inputs. + *
  • VIDEO: The mode for performing classification on the decoded frames of a video. + *
  • LIVE_STREAM: The mode for performing classification on a live stream of input + * data, such as from camera. In this mode, {@code setResultListener} must be called to + * set up a listener to receive the classification results asynchronously. + *
+ */ + public abstract Builder setRunningMode(RunningMode runningMode); + + /** + * Sets the optional {@link ClassifierOptions} controling classification behavior, such as + * score threshold, number of results, etc. + */ + public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); + + /** + * Sets the {@link ResultListener} to receive the classification results asynchronously when + * the image classifier is in the live stream mode. + */ + public abstract Builder setResultListener( + ResultListener resultListener); + + /** Sets an optional {@link ErrorListener}. */ + public abstract Builder setErrorListener(ErrorListener errorListener); + + abstract ImageClassifierOptions autoBuild(); + + /** + * Validates and builds the {@link ImageClassifierOptions} instance. * + * + * @throws IllegalArgumentException if the result listener and the running mode are not + * properly configured. The result listener should only be set when the image classifier + * is in the live stream mode. 
+ */ + public final ImageClassifierOptions build() { + ImageClassifierOptions options = autoBuild(); + if (options.runningMode() == RunningMode.LIVE_STREAM) { + if (!options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The image classifier is in the live stream mode, a user-defined result listener" + + " must be provided in the ImageClassifierOptions."); + } + } else if (options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The image classifier is in the image or video mode, a user-defined result listener" + + " shouldn't be provided in ImageClassifierOptions."); + } + return options; + } + } + + abstract BaseOptions baseOptions(); + + abstract RunningMode runningMode(); + + abstract Optional classifierOptions(); + + abstract Optional> resultListener(); + + abstract Optional errorListener(); + + public static Builder builder() { + return new AutoValue_ImageClassifier_ImageClassifierOptions.Builder() + .setRunningMode(RunningMode.IMAGE); + } + + /** + * Converts a {@link ImageClassifierOptions} to a {@link CalculatorOptions} protobuf message. 
+ */ + @Override + public CalculatorOptions convertToCalculatorOptionsProto() { + BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = + BaseOptionsProto.BaseOptions.newBuilder(); + baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE); + baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.Builder taskOptionsBuilder = + ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.newBuilder() + .setBaseOptions(baseOptionsBuilder); + if (classifierOptions().isPresent()) { + taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto()); + } + return CalculatorOptions.newBuilder() + .setExtension( + ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.ext, + taskOptionsBuilder.build()) + .build(); + } + } + + /** Creates a RectF covering the full image. */ + private static RectF buildFullImageRectF() { + return new RectF(0, 0, 1, 1); + } +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/AndroidManifest.xml new file mode 100644 index 000000000..66fa20509 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/BUILD new file mode 100644 index 000000000..a7f804c64 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/BUILD @@ -0,0 +1,19 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java new file mode 100644 index 000000000..e02e8ebe7 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java @@ -0,0 +1,445 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.vision.imageclassifier; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import android.content.res.AssetManager; +import android.graphics.BitmapFactory; +import android.graphics.RectF; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.tasks.components.containers.Category; +import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.TestUtils; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.imageclassifier.ImageClassifier.ImageClassifierOptions; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** Test for {@link ImageClassifier}/ */ +@RunWith(Suite.class) +@SuiteClasses({ImageClassifierTest.General.class, ImageClassifierTest.RunningModeTest.class}) +public class ImageClassifierTest { + private static final String FLOAT_MODEL_FILE = "mobilenet_v2_1.0_224.tflite"; + private static final String QUANTIZED_MODEL_FILE = "mobilenet_v1_0.25_224_quant.tflite"; + private static final String BURGER_IMAGE = "burger.jpg"; + private static final String MULTI_OBJECTS_IMAGE = "multi_objects.jpg"; + + @RunWith(AndroidJUnit4.class) + public static final class General extends ImageClassifierTest { + + @Test + public void create_failsWithMissingModel() throws Exception { + String nonExistentFile = "/path/to/non/existent/file"; + MediaPipeException exception = + assertThrows( + 
MediaPipeException.class, + () -> + ImageClassifier.createFromFile( + ApplicationProvider.getApplicationContext(), nonExistentFile)); + assertThat(exception).hasMessageThat().contains(nonExistentFile); + } + + @Test + public void create_failsWithInvalidModelBuffer() throws Exception { + // Create a non-direct model ByteBuffer. + ByteBuffer modelBuffer = + TestUtils.loadToNonDirectByteBuffer( + ApplicationProvider.getApplicationContext(), FLOAT_MODEL_FILE); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageClassifier.createFromBuffer( + ApplicationProvider.getApplicationContext(), modelBuffer)); + + assertThat(exception) + .hasMessageThat() + .contains("The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + + @Test + public void classify_succeedsWithNoOptions() throws Exception { + ImageClassifier imageClassifier = + ImageClassifier.createFromFile( + ApplicationProvider.getApplicationContext(), FLOAT_MODEL_FILE); + ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + + assertHasOneHeadAndOneTimestamp(results, 0); + assertThat(results.classifications().get(0).entries().get(0).categories()).hasSize(1001); + assertThat(results.classifications().get(0).entries().get(0).categories().get(0)) + .isEqualTo(Category.create(0.7952058f, 934, "cheeseburger", "")); + } + + @Test + public void classify_succeedsWithFloatModel() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build()) + .build(); + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + + assertHasOneHeadAndOneTimestamp(results, 
0); + assertCategoriesAre( + results, + Arrays.asList( + Category.create(0.7952058f, 934, "cheeseburger", ""), + Category.create(0.027329788f, 932, "bagel", ""), + Category.create(0.019334773f, 925, "guacamole", ""))); + } + + @Test + public void classify_succeedsWithQuantizedModel() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(QUANTIZED_MODEL_FILE).build()) + .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .build(); + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + + assertHasOneHeadAndOneTimestamp(results, 0); + assertCategoriesAre( + results, Arrays.asList(Category.create(0.97265625f, 934, "cheeseburger", ""))); + } + + @Test + public void classify_succeedsWithScoreThreshold() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions(ClassifierOptions.builder().setScoreThreshold(0.02f).build()) + .build(); + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + + assertHasOneHeadAndOneTimestamp(results, 0); + assertCategoriesAre( + results, + Arrays.asList( + Category.create(0.7952058f, 934, "cheeseburger", ""), + Category.create(0.027329788f, 932, "bagel", ""))); + } + + @Test + public void classify_succeedsWithAllowlist() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions( + ClassifierOptions.builder() + 
.setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf")) + .build()) + .build(); + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + + assertHasOneHeadAndOneTimestamp(results, 0); + assertCategoriesAre( + results, + Arrays.asList( + Category.create(0.7952058f, 934, "cheeseburger", ""), + Category.create(0.019334773f, 925, "guacamole", ""), + Category.create(0.006279315f, 963, "meat loaf", ""))); + } + + @Test + public void classify_succeedsWithDenylist() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions( + ClassifierOptions.builder() + .setMaxResults(3) + .setCategoryDenylist(Arrays.asList("bagel")) + .build()) + .build(); + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + + assertHasOneHeadAndOneTimestamp(results, 0); + assertCategoriesAre( + results, + Arrays.asList( + Category.create(0.7952058f, 934, "cheeseburger", ""), + Category.create(0.019334773f, 925, "guacamole", ""), + Category.create(0.006279315f, 963, "meat loaf", ""))); + } + + @Test + public void classify_succeedsWithRegionOfInterest() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .build(); + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + // RectF around the soccer ball. 
+ RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f); + ImageClassificationResult results = + imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), roi); + + assertHasOneHeadAndOneTimestamp(results, 0); + assertCategoriesAre( + results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", ""))); + } + } + + @RunWith(AndroidJUnit4.class) + public static final class RunningModeTest extends ImageClassifierTest { + + @Test + public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception { + for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setRunningMode(mode) + .setResultListener((imageClassificationResult, inputImage) -> {}) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener shouldn't be provided"); + } + } + + @Test + public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener must be provided"); + } + + @Test + public void classify_failsWithCallingWrongApiInImageMode() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException 
exception = + assertThrows( + MediaPipeException.class, + () -> imageClassifier.classifyForVideo(getImageFromAsset(BURGER_IMAGE), 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void classify_failsWithCallingWrongApiInVideoMode() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> imageClassifier.classify(getImageFromAsset(BURGER_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void classify_failsWithCallingWrongApiInLiveSteamMode() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener((imageClassificationResult, inputImage) -> {}) + .build(); + + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> imageClassifier.classify(getImageFromAsset(BURGER_IMAGE))); 
+ assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> imageClassifier.classifyForVideo(getImageFromAsset(BURGER_IMAGE), 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + } + + @Test + public void classify_succeedsWithImageMode() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .build(); + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + + assertHasOneHeadAndOneTimestamp(results, 0); + assertCategoriesAre( + results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); + } + + @Test + public void classify_succeedsWithVideoMode() throws Exception { + Image image = getImageFromAsset(BURGER_IMAGE); + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + for (int i = 0; i < 3; i++) { + ImageClassificationResult results = imageClassifier.classifyForVideo(image, i); + assertHasOneHeadAndOneTimestamp(results, i); + assertCategoriesAre( + results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); + } + } + + @Test + public void classify_failsWithOutOfOrderInputTimestamps() throws Exception { + Image image = getImageFromAsset(BURGER_IMAGE); + ImageClassifierOptions 
options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (imageClassificationResult, inputImage) -> { + assertCategoriesAre( + imageClassificationResult, + Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 1); + MediaPipeException exception = + assertThrows(MediaPipeException.class, () -> imageClassifier.classifyAsync(image, 0)); + assertThat(exception) + .hasMessageThat() + .contains("having a smaller timestamp than the processed timestamp"); + } + } + + @Test + public void classify_succeedsWithLiveStreamMode() throws Exception { + Image image = getImageFromAsset(BURGER_IMAGE); + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (imageClassificationResult, inputImage) -> { + assertCategoriesAre( + imageClassificationResult, + Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + for (int i = 0; i < 3; ++i) { + imageClassifier.classifyAsync(image, i); + } + } + } + } + + private static Image getImageFromAsset(String filePath) throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); 
+ InputStream istr = assetManager.open(filePath); + return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); + } + + private static void assertHasOneHeadAndOneTimestamp( + ImageClassificationResult results, long timestampMs) { + assertThat(results.classifications()).hasSize(1); + assertThat(results.classifications().get(0).headIndex()).isEqualTo(0); + assertThat(results.classifications().get(0).headName()).isEqualTo("probability"); + assertThat(results.classifications().get(0).entries()).hasSize(1); + assertThat(results.classifications().get(0).entries().get(0).timestampMs()) + .isEqualTo(timestampMs); + } + + private static void assertCategoriesAre( + ImageClassificationResult results, List categories) { + assertThat(results.classifications().get(0).entries().get(0).categories()) + .hasSize(categories.size()); + for (int i = 0; i < categories.size(); i++) { + assertThat(results.classifications().get(0).entries().get(0).categories().get(i)) + .isEqualTo(categories.get(i)); + } + } + + private static void assertImageSizeIsExpected(Image inputImage) { + assertThat(inputImage).isNotNull(); + assertThat(inputImage.getWidth()).isEqualTo(480); + assertThat(inputImage.getHeight()).isEqualTo(325); + } +}