diff --git a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto index 4adba5ab7..72b3e7ee3 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto @@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; +option java_package = "com.google.mediapipe.tasks.vision.imageembedder.proto"; +option java_outer_classname = "ImageEmbedderGraphOptionsProto"; + message ImageEmbedderGraphOptions { extend mediapipe.CalculatorOptions { optional ImageEmbedderGraphOptions ext = 476348187; diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index 2b648bc43..8b09260bd 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -42,6 +42,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", diff --git 
a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 4dc4a547e..289e3000d 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -43,6 +43,7 @@ cc_binary( "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", + "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", ], @@ -172,6 +173,34 @@ android_library( ], ) +android_library( + name = "imageembedder", + srcs = [ + "imageembedder/ImageEmbedder.java", + "imageembedder/ImageEmbedderResult.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = "imageembedder/AndroidManifest.xml", + deps = [ + ":core", + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:autovalue", + 
"@maven//:com_google_guava_guava", + ], +) + load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_vision_aar") mediapipe_tasks_vision_aar( diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml new file mode 100644 index 000000000..ebdb037d6 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java new file mode 100644 index 000000000..0d8ecd5c3 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java @@ -0,0 +1,448 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.vision.imageembedder; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.ProtoUtil; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.components.containers.Embedding; +import com.google.mediapipe.tasks.components.containers.EmbeddingResult; +import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; +import com.google.mediapipe.tasks.components.processors.EmbedderOptions; +import com.google.mediapipe.tasks.components.utils.CosineSimilarity; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.ErrorListener; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.imageembedder.proto.ImageEmbedderGraphOptionsProto; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Performs embedding extraction on images. + * + *

The API expects a TFLite model with optional, but strongly recommended, TFLite Model Metadata. + * + *

The API supports models with one image input tensor and one or more output tensors. To be more + * specific, here are the requirements. + * + *

+ */ +public final class ImageEmbedder extends BaseVisionTaskApi { + private static final String TAG = ImageEmbedder.class.getSimpleName(); + private static final String IMAGE_IN_STREAM_NAME = "image_in"; + private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in"; + private static final List INPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); + private static final List OUTPUT_STREAMS = + Collections.unmodifiableList(Arrays.asList("EMBEDDINGS:embeddings_out", "IMAGE:image_out")); + private static final int EMBEDDINGS_OUT_STREAM_INDEX = 0; + private static final int IMAGE_OUT_STREAM_INDEX = 1; + private static final String TASK_GRAPH_NAME = + "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph"; + + static { + ProtoUtil.registerTypeName( + EmbeddingsProto.EmbeddingResult.class, + "mediapipe.tasks.components.containers.proto.EmbeddingResult"); + } + + /** + * Creates an {@link ImageEmbedder} instance from a model file and default {@link + * ImageEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param modelPath path to the embedding model in the assets. + * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation. + */ + public static ImageEmbedder createFromFile(Context context, String modelPath) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build(); + return createFromOptions( + context, ImageEmbedderOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates an {@link ImageEmbedder} instance from a model file and default {@link + * ImageEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param modelFile the embedding model {@link File} instance. + * @throws IOException if an I/O error occurs when opening the tflite model file. + * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation. 
+ */ + public static ImageEmbedder createFromFile(Context context, File modelFile) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + BaseOptions baseOptions = + BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build(); + return createFromOptions( + context, ImageEmbedderOptions.builder().setBaseOptions(baseOptions).build()); + } + } + + /** + * Creates an {@link ImageEmbedder} instance from a model buffer and default {@link + * ImageEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the embedding + * model. + * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation. + */ + public static ImageEmbedder createFromBuffer(Context context, final ByteBuffer modelBuffer) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build(); + return createFromOptions( + context, ImageEmbedderOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates an {@link ImageEmbedder} instance from an {@link ImageEmbedderOptions} instance. + * + * @param context an Android {@link Context}. + * @param options an {@link ImageEmbedderOptions} instance. + * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation. 
+ */ + public static ImageEmbedder createFromOptions(Context context, ImageEmbedderOptions options) { + OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public ImageEmbedderResult convertToTaskResult(List packets) { + try { + return ImageEmbedderResult.create( + EmbeddingResult.createFromProto( + PacketGetter.getProto( + packets.get(EMBEDDINGS_OUT_STREAM_INDEX), + EmbeddingsProto.EmbeddingResult.getDefaultInstance())), + BaseVisionTaskApi.generateResultTimestampMs( + options.runningMode(), packets.get(EMBEDDINGS_OUT_STREAM_INDEX))); + } catch (IOException e) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage()); + } + } + + @Override + public MPImage convertToTaskInput(List packets) { + return new BitmapImageBuilder( + AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) + .build(); + } + }); + options.resultListener().ifPresent(handler::setResultListener); + options.errorListener().ifPresent(handler::setErrorListener); + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(INPUT_STREAMS) + .setOutputStreams(OUTPUT_STREAMS) + .setTaskOptions(options) + .setEnableFlowLimiting(options.runningMode() == RunningMode.LIVE_STREAM) + .build(), + handler); + return new ImageEmbedder(runner, options.runningMode()); + } + + /** + * Constructor to initialize an {@link ImageEmbedder} from a {@link TaskRunner} and {@link + * RunningMode}. + * + * @param taskRunner a {@link TaskRunner}. + * @param runningMode a mediapipe vision task {@link RunningMode}. + */ + private ImageEmbedder(TaskRunner taskRunner, RunningMode runningMode) { + super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); + } + + /** + * Performs embedding extraction on the provided single image with default image processing + * options, i.e. 
using the whole image as region-of-interest and without any rotation applied. + * Only use this method when the {@link ImageEmbedder} is created with {@link RunningMode.IMAGE}. + * + *

{@link ImageEmbedder} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public ImageEmbedderResult embed(MPImage image) { + return embed(image, ImageProcessingOptions.builder().build()); + } + + /** + * Performs embedding extraction on the provided single image. Only use this method when the + * {@link ImageEmbedder} is created with {@link RunningMode.IMAGE}. + * + *

{@link ImageEmbedder} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @throws MediaPipeException if there is an internal error. + */ + public ImageEmbedderResult embed(MPImage image, ImageProcessingOptions imageProcessingOptions) { + return (ImageEmbedderResult) processImageData(image, imageProcessingOptions); + } + + /** + * Performs embedding extraction on the provided video frame with default image processing + * options, i.e. using the whole image as region-of-interest and without any rotation applied. + * Only use this method when the {@link ImageEmbedder} is created with {@link RunningMode.VIDEO}. + * + *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link ImageEmbedder} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public ImageEmbedderResult embedForVideo(MPImage image, long timestampMs) { + return embedForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Performs embedding extraction on the provided video frame. Only use this method when the {@link + * ImageEmbedder} is created with {@link RunningMode.VIDEO}. + * + *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link ImageEmbedder} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public ImageEmbedderResult embedForVideo( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + return (ImageEmbedderResult) processVideoData(image, imageProcessingOptions, timestampMs); + } + + /** + * Sends live image data to perform embedding extraction with default image processing options, + * i.e. using the whole image as region-of-interest and without any rotation applied, and the + * results will be available via the {@link ResultListener} provided in the {@link + * ImageEmbedderOptions}. Only use this method when the {@link ImageEmbedder} is created with + * {@link RunningMode.LIVE_STREAM}. + * + *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the image embedder. The input timestamps must be monotonically increasing. + * + *

{@link ImageEmbedder} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void embedAsync(MPImage image, long timestampMs) { + embedAsync(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Sends live image data to perform embedding extraction, and the results will be available via + * the {@link ResultListener} provided in the {@link ImageEmbedderOptions}. Only use this method + * when the {@link ImageEmbedder} is created with {@link RunningMode.LIVE_STREAM}. + * + *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the image embedder. The input timestamps must be monotonically increasing. + * + *

{@link ImageEmbedder} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void embedAsync( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + sendLiveStreamData(image, imageProcessingOptions, timestampMs); + } + + /** + * Utility function to compute cosine + * similarity between two {@link Embedding} objects. + * + * @throws IllegalArgumentException if the embeddings are of different types (float vs. + * quantized), have different sizes, or have an L2-norm of 0. + */ + public static double cosineSimilarity(Embedding u, Embedding v) { + return CosineSimilarity.compute(u, v); + } + + /** Options for setting up and {@link ImageEmbedder}. */ + @AutoValue + public abstract static class ImageEmbedderOptions extends TaskOptions { + + /** Builder for {@link ImageEmbedderOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the {@link BaseOptions} for the image embedder task. */ + public abstract Builder setBaseOptions(BaseOptions baseOptions); + + /** + * Sets the {@link RunningMode} for the image embedder task. Default to the image mode. Image + * embedder has three modes: + * + *
    + *
  • IMAGE: The mode for performing embedding extraction on single image inputs. + *
  • VIDEO: The mode for performing embedding extraction on the decoded frames of a video. + *
  • LIVE_STREAM: The mode for performing embedding extraction on a live stream of + * input data, such as from camera. In this mode, {@code setResultListener} must be + * called to set up a listener to receive the embedding results asynchronously. + *
+ */ + public abstract Builder setRunningMode(RunningMode runningMode); + + /** + * Sets the optional {@link EmbedderOptions} controling embedder behavior, such as + * L2-normalization and scalar quantization. + */ + public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions); + + /** + * Sets the {@link ResultListener} to receive the embedding results asynchronously when the + * image embedder is in the live stream mode. + */ + public abstract Builder setResultListener( + ResultListener resultListener); + + /** Sets an optional {@link ErrorListener}. */ + public abstract Builder setErrorListener(ErrorListener errorListener); + + abstract ImageEmbedderOptions autoBuild(); + + /** + * Validates and builds the {@link ImageEmbedderOptions} instance. * + * + * @throws IllegalArgumentException if the result listener and the running mode are not + * properly configured. The result listener should only be set when the image embedder is + * in the live stream mode. + */ + public final ImageEmbedderOptions build() { + ImageEmbedderOptions options = autoBuild(); + if (options.runningMode() == RunningMode.LIVE_STREAM) { + if (!options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The image embedder is in the live stream mode, a user-defined result listener" + + " must be provided in the ImageEmbedderOptions."); + } + } else if (options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The image embedder is in the image or video mode, a user-defined result listener" + + " shouldn't be provided in ImageEmbedderOptions."); + } + return options; + } + } + + abstract BaseOptions baseOptions(); + + abstract RunningMode runningMode(); + + abstract Optional embedderOptions(); + + abstract Optional> resultListener(); + + abstract Optional errorListener(); + + public static Builder builder() { + return new AutoValue_ImageEmbedder_ImageEmbedderOptions.Builder() + .setRunningMode(RunningMode.IMAGE); + } + + /** Converts a 
{@link ImageEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */ + @Override + public CalculatorOptions convertToCalculatorOptionsProto() { + BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = + BaseOptionsProto.BaseOptions.newBuilder(); + baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE); + baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.Builder taskOptionsBuilder = + ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.newBuilder() + .setBaseOptions(baseOptionsBuilder); + if (embedderOptions().isPresent()) { + taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto()); + } + return CalculatorOptions.newBuilder() + .setExtension( + ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.ext, + taskOptionsBuilder.build()) + .build(); + } + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderResult.java new file mode 100644 index 000000000..ee3f4abc9 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderResult.java @@ -0,0 +1,54 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.vision.imageembedder; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.tasks.components.containers.EmbeddingResult; +import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; +import com.google.mediapipe.tasks.core.TaskResult; + +/** Represents the embedding results generated by {@link ImageEmbedder}. */ +@AutoValue +public abstract class ImageEmbedderResult implements TaskResult { + + /** + * Creates an {@link ImageEmbedderResult} instance. + * + * @param embeddingResult the {@link EmbeddingResult} object containing one embedding per embedder + * head. + * @param timestampMs a timestamp for this result. + */ + static ImageEmbedderResult create(EmbeddingResult embeddingResult, long timestampMs) { + return new AutoValue_ImageEmbedderResult(embeddingResult, timestampMs); + } + + /** + * Creates an {@link ImageEmbedderResult} instance from a {@link EmbeddingsProto.EmbeddingResult} + * protobuf message. + * + * @param proto the {@link EmbeddingsProto.EmbeddingResult} protobuf message to convert. + * @param timestampMs a timestamp for this result. + */ + static ImageEmbedderResult createFromProto( + EmbeddingsProto.EmbeddingResult proto, long timestampMs) { + return create(EmbeddingResult.createFromProto(proto), timestampMs); + } + + /** Contains one embedding per embedder head. 
*/ + public abstract EmbeddingResult embeddingResult(); + + @Override + public abstract long timestampMs(); +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml new file mode 100644 index 000000000..db303a439 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/BUILD new file mode 100644 index 000000000..a7f804c64 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/BUILD @@ -0,0 +1,19 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java new file mode 100644 index 000000000..56249ead9 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java @@ -0,0 +1,444 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.vision.imageembedder; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import android.content.res.AssetManager; +import android.graphics.BitmapFactory; +import android.graphics.RectF; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.components.processors.EmbedderOptions; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.TestUtils; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.imageembedder.ImageEmbedder.ImageEmbedderOptions; +import java.io.InputStream; +import java.nio.ByteBuffer; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** Test for {@link ImageEmbedder}/ */ +@RunWith(Suite.class) +@SuiteClasses({ImageEmbedderTest.General.class, ImageEmbedderTest.RunningModeTest.class}) +public class ImageEmbedderTest { + private static final String MOBILENET_EMBEDDER = "mobilenet_v3_small_100_224_embedder.tflite"; + private static final String BURGER_IMAGE = "burger.jpg"; + private static final String BURGER_CROP_IMAGE = "burger_crop.jpg"; + private static final String BURGER_ROTATED_IMAGE = "burger_rotated.jpg"; + + private static final double DOUBLE_DIFF_TOLERANCE = 1e-4; + + @RunWith(AndroidJUnit4.class) + public static final class General extends ImageEmbedderTest { + + @Test + public void create_failsWithMissingModel() throws Exception { + String nonExistentFile = "/path/to/non/existent/file"; + MediaPipeException exception = + assertThrows( + 
MediaPipeException.class, + () -> + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), nonExistentFile)); + assertThat(exception).hasMessageThat().contains(nonExistentFile); + } + + @Test + public void create_failsWithInvalidModelBuffer() throws Exception { + // Create a non-direct model ByteBuffer. + ByteBuffer modelBuffer = + TestUtils.loadToNonDirectByteBuffer( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageEmbedder.createFromBuffer( + ApplicationProvider.getApplicationContext(), modelBuffer)); + + assertThat(exception) + .hasMessageThat() + .contains("The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + + @Test + public void embed_succeedsWithNoOptions() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + // Check similarity. 
+ double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.925272); + } + + @Test + public void embed_succeedsWithL2Normalization() throws Exception { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + EmbedderOptions embedderOptions = EmbedderOptions.builder().setL2Normalize(true).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(baseOptions) + .setEmbedderOptions(embedderOptions) + .build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + // Check similarity. 
+ double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.925272); + } + + @Test + public void embed_succeedsWithQuantization() throws Exception { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + EmbedderOptions embedderOptions = EmbedderOptions.builder().setQuantize(true).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(baseOptions) + .setEmbedderOptions(embedderOptions) + .build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ true); + assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ true); + // Check similarity. + double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.926776); + } + + @Test + public void embed_succeedsWithRegionOfInterest() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + // RectF around the region in "burger.jpg" corresponding to "burger_crop.jpg". 
+ RectF roi = new RectF(0.0f, 0.0f, 0.833333f, 1.0f); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(roi).build(); + ImageEmbedderResult resultRoi = + imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE), imageProcessingOptions); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(resultRoi, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + // Check similarity. + double similarity = + ImageEmbedder.cosineSimilarity( + resultRoi.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.999931f); + } + + @Test + public void embed_succeedsWithRotation() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRotationDegrees(-90).build(); + ImageEmbedderResult resultRotated = + imageEmbedder.embed(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultRotated, /*quantized=*/ false); + // Check similarity. 
+ double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultRotated.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.571648426f); + } + + @Test + public void embed_succeedsWithRegionOfInterestAndRotation() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + // RectF around the region in "burger_rotated.jpg" corresponding to "burger_crop.jpg". + RectF roi = new RectF(0.0f, 0.0f, 1.0f, 0.833333f); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build(); + ImageEmbedderResult resultRoiRotated = + imageEmbedder.embed(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(resultRoiRotated, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + // Check similarity. 
+ double similarity = + ImageEmbedder.cosineSimilarity( + resultRoiRotated.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.62780395f); + } + } + + @RunWith(AndroidJUnit4.class) + public static final class RunningModeTest extends ImageEmbedderTest { + + @Test + public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception { + for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageEmbedderOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(mode) + .setResultListener((result, inputImage) -> {}) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener shouldn't be provided"); + } + } + + @Test + public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageEmbedderOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener must be provided"); + } + + @Test + public void embed_failsWithCallingWrongApiInImageMode() throws Exception { + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + 
imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void embed_failsWithCallingWrongApiInVideoMode() throws Exception { + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, () -> imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void embed_failsWithCallingWrongApiInLiveSteamMode() throws Exception { + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener((imageClassificationResult, inputImage) -> {}) + .build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + + MediaPipeException exception = + assertThrows( + MediaPipeException.class, () -> imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE))); + assertThat(exception).hasMessageThat().contains("not 
initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + } + + @Test + public void embed_succeedsWithImageMode() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + // Check similarity. + double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.925272); + } + + @Test + public void embed_succeedsWithVideoMode() throws Exception { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(baseOptions) + .setRunningMode(RunningMode.VIDEO) + .build(); + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + + for (int i = 0; i < 3; ++i) { + ImageEmbedderResult result = + imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ i); + assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); + } + } + + @Test + public void embed_failsWithOutOfOrderInputTimestamps() throws Exception { + MPImage image = getImageFromAsset(BURGER_IMAGE); + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + 
ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(baseOptions) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (imageEmbedderResult, inputImage) -> { + assertHasOneHeadAndCorrectDimension(imageEmbedderResult, /*quantized=*/ false); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> imageEmbedder.embedAsync(image, /*timestampMs=*/ 0)); + assertThat(exception) + .hasMessageThat() + .contains("having a smaller timestamp than the processed timestamp"); + } + } + + @Test + public void embed_succeedsWithLiveStreamMode() throws Exception { + MPImage image = getImageFromAsset(BURGER_IMAGE); + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(baseOptions) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (imageEmbedderResult, inputImage) -> { + assertHasOneHeadAndCorrectDimension(imageEmbedderResult, /*quantized=*/ false); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + for (int i = 0; i < 3; ++i) { + imageEmbedder.embedAsync(image, /*timestampMs=*/ i); + } + } + } + } + + private static MPImage getImageFromAsset(String filePath) throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + InputStream istr = assetManager.open(filePath); + return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); + } + + private static void assertHasOneHeadAndCorrectDimension( + ImageEmbedderResult 
result, boolean quantized) { + assertThat(result.embeddingResult().embeddings()).hasSize(1); + assertThat(result.embeddingResult().embeddings().get(0).headIndex()).isEqualTo(0); + assertThat(result.embeddingResult().embeddings().get(0).headName().get()).isEqualTo("feature"); + if (quantized) { + assertThat(result.embeddingResult().embeddings().get(0).quantizedEmbedding()).hasLength(1024); + } else { + assertThat(result.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(1024); + } + } + + private static void assertImageSizeIsExpected(MPImage inputImage) { + assertThat(inputImage).isNotNull(); + assertThat(inputImage.getWidth()).isEqualTo(480); + assertThat(inputImage.getHeight()).isEqualTo(325); + } +}