diff --git a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto
index 4adba5ab7..72b3e7ee3 100644
--- a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto
+++ b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto
@@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
+option java_package = "com.google.mediapipe.tasks.vision.imageembedder.proto";
+option java_outer_classname = "ImageEmbedderGraphOptionsProto";
+
message ImageEmbedderGraphOptions {
extend mediapipe.CalculatorOptions {
optional ImageEmbedderGraphOptions ext = 476348187;
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
index 2b648bc43..8b09260bd 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
@@ -42,6 +42,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
+ "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD
index 4dc4a547e..289e3000d 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD
@@ -43,6 +43,7 @@ cc_binary(
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
"//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
+ "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
],
@@ -172,6 +173,34 @@ android_library(
],
)
+android_library(
+ name = "imageembedder",
+ srcs = [
+ "imageembedder/ImageEmbedder.java",
+ "imageembedder/ImageEmbedderResult.java",
+ ],
+ javacopts = [
+ "-Xep:AndroidJdkLibsChecker:OFF",
+ ],
+ manifest = "imageembedder/AndroidManifest.xml",
+ deps = [
+ ":core",
+ "//mediapipe/framework:calculator_options_java_proto_lite",
+ "//mediapipe/java/com/google/mediapipe/framework:android_framework",
+ "//mediapipe/java/com/google/mediapipe/framework/image",
+ "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
+ "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
+ "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite",
+ "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding",
+ "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult",
+ "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions",
+ "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity",
+ "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
+ "//third_party:autovalue",
+ "@maven//:com_google_guava_guava",
+ ],
+)
+
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_vision_aar")
mediapipe_tasks_vision_aar(
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml
new file mode 100644
index 000000000..ebdb037d6
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java
new file mode 100644
index 000000000..0d8ecd5c3
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java
@@ -0,0 +1,448 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.vision.imageembedder;
+
+import android.content.Context;
+import android.os.ParcelFileDescriptor;
+import com.google.auto.value.AutoValue;
+import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
+import com.google.mediapipe.framework.AndroidPacketGetter;
+import com.google.mediapipe.framework.MediaPipeException;
+import com.google.mediapipe.framework.Packet;
+import com.google.mediapipe.framework.PacketGetter;
+import com.google.mediapipe.framework.ProtoUtil;
+import com.google.mediapipe.framework.image.BitmapImageBuilder;
+import com.google.mediapipe.framework.image.MPImage;
+import com.google.mediapipe.tasks.components.containers.Embedding;
+import com.google.mediapipe.tasks.components.containers.EmbeddingResult;
+import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
+import com.google.mediapipe.tasks.components.processors.EmbedderOptions;
+import com.google.mediapipe.tasks.components.utils.CosineSimilarity;
+import com.google.mediapipe.tasks.core.BaseOptions;
+import com.google.mediapipe.tasks.core.ErrorListener;
+import com.google.mediapipe.tasks.core.OutputHandler;
+import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
+import com.google.mediapipe.tasks.core.TaskInfo;
+import com.google.mediapipe.tasks.core.TaskOptions;
+import com.google.mediapipe.tasks.core.TaskRunner;
+import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
+import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
+import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
+import com.google.mediapipe.tasks.vision.core.RunningMode;
+import com.google.mediapipe.tasks.vision.imageembedder.proto.ImageEmbedderGraphOptionsProto;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Performs embedding extraction on images.
+ *
+ *
The API expects a TFLite model with optional, but strongly recommended, TFLite Model Metadata..
+ *
+ *
The API supports models with one image input tensor and one or more output tensors. To be more
+ * specific, here are the requirements.
+ *
+ *
+ * - Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
+ *
+ * - image input of size {@code [batch x height x width x channels]}.
+ *
- batch inference is not supported ({@code batch} is required to be 1).
+ *
- only RGB inputs are supported ({@code channels} is required to be 3).
+ *
- if type is kTfLiteFloat32, NormalizationOptions are required to be attached to the
+ * metadata for input normalization.
+ *
+ * - At least one output tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) with shape {@code
+ * [1 x N]} where N is the number of dimensions in the produced embeddings.
+ *
+ */
+public final class ImageEmbedder extends BaseVisionTaskApi {
+ private static final String TAG = ImageEmbedder.class.getSimpleName();
+ private static final String IMAGE_IN_STREAM_NAME = "image_in";
+ private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
+ private static final List INPUT_STREAMS =
+ Collections.unmodifiableList(
+ Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
+ private static final List OUTPUT_STREAMS =
+ Collections.unmodifiableList(Arrays.asList("EMBEDDINGS:embeddings_out", "IMAGE:image_out"));
+ private static final int EMBEDDINGS_OUT_STREAM_INDEX = 0;
+ private static final int IMAGE_OUT_STREAM_INDEX = 1;
+ private static final String TASK_GRAPH_NAME =
+ "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph";
+
+ static {
+ ProtoUtil.registerTypeName(
+ EmbeddingsProto.EmbeddingResult.class,
+ "mediapipe.tasks.components.containers.proto.EmbeddingResult");
+ }
+
+ /**
+ * Creates an {@link ImageEmbedder} instance from a model file and default {@link
+ * ImageEmbedderOptions}.
+ *
+ * @param context an Android {@link Context}.
+ * @param modelPath path to the embedding model in the assets.
+ * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation.
+ */
+ public static ImageEmbedder createFromFile(Context context, String modelPath) {
+ BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
+ return createFromOptions(
+ context, ImageEmbedderOptions.builder().setBaseOptions(baseOptions).build());
+ }
+
+ /**
+ * Creates an {@link ImageEmbedder} instance from a model file and default {@link
+ * ImageEmbedderOptions}.
+ *
+ * @param context an Android {@link Context}.
+ * @param modelFile the embedding model {@link File} instance.
+ * @throws IOException if an I/O error occurs when opening the tflite model file.
+ * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation.
+ */
+ public static ImageEmbedder createFromFile(Context context, File modelFile) throws IOException {
+ try (ParcelFileDescriptor descriptor =
+ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+ BaseOptions baseOptions =
+ BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
+ return createFromOptions(
+ context, ImageEmbedderOptions.builder().setBaseOptions(baseOptions).build());
+ }
+ }
+
+ /**
+ * Creates an {@link ImageEmbedder} instance from a model buffer and default {@link
+ * ImageEmbedderOptions}.
+ *
+ * @param context an Android {@link Context}.
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the embedding
+ * model.
+ * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation.
+ */
+ public static ImageEmbedder createFromBuffer(Context context, final ByteBuffer modelBuffer) {
+ BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
+ return createFromOptions(
+ context, ImageEmbedderOptions.builder().setBaseOptions(baseOptions).build());
+ }
+
+ /**
+ * Creates an {@link ImageEmbedder} instance from an {@link ImageEmbedderOptions} instance.
+ *
+ * @param context an Android {@link Context}.
+ * @param options an {@link ImageEmbedderOptions} instance.
+ * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation.
+ */
+ public static ImageEmbedder createFromOptions(Context context, ImageEmbedderOptions options) {
+ OutputHandler handler = new OutputHandler<>();
+ handler.setOutputPacketConverter(
+ new OutputHandler.OutputPacketConverter() {
+ @Override
+ public ImageEmbedderResult convertToTaskResult(List packets) {
+ try {
+ return ImageEmbedderResult.create(
+ EmbeddingResult.createFromProto(
+ PacketGetter.getProto(
+ packets.get(EMBEDDINGS_OUT_STREAM_INDEX),
+ EmbeddingsProto.EmbeddingResult.getDefaultInstance())),
+ BaseVisionTaskApi.generateResultTimestampMs(
+ options.runningMode(), packets.get(EMBEDDINGS_OUT_STREAM_INDEX)));
+ } catch (IOException e) {
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
+ }
+ }
+
+ @Override
+ public MPImage convertToTaskInput(List packets) {
+ return new BitmapImageBuilder(
+ AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
+ .build();
+ }
+ });
+ options.resultListener().ifPresent(handler::setResultListener);
+ options.errorListener().ifPresent(handler::setErrorListener);
+ TaskRunner runner =
+ TaskRunner.create(
+ context,
+ TaskInfo.builder()
+ .setTaskGraphName(TASK_GRAPH_NAME)
+ .setInputStreams(INPUT_STREAMS)
+ .setOutputStreams(OUTPUT_STREAMS)
+ .setTaskOptions(options)
+ .setEnableFlowLimiting(options.runningMode() == RunningMode.LIVE_STREAM)
+ .build(),
+ handler);
+ return new ImageEmbedder(runner, options.runningMode());
+ }
+
+ /**
+ * Constructor to initialize an {@link ImageEmbedder} from a {@link TaskRunner} and {@link
+ * RunningMode}.
+ *
+ * @param taskRunner a {@link TaskRunner}.
+ * @param runningMode a mediapipe vision task {@link RunningMode}.
+ */
+ private ImageEmbedder(TaskRunner taskRunner, RunningMode runningMode) {
+ super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
+ }
+
+ /**
+ * Performs embedding extraction on the provided single image with default image processing
+ * options, i.e. using the whole image as region-of-interest and without any rotation applied.
+ * Only use this method when the {@link ImageEmbedder} is created with {@link RunningMode.IMAGE}.
+ *
+ * {@link ImageEmbedder} supports the following color space types:
+ *
+ *
+ * - {@link Bitmap.Config.ARGB_8888}
+ *
+ *
+ * @param image a MediaPipe {@link MPImage} object for processing.
+ * @throws MediaPipeException if there is an internal error.
+ */
+ public ImageEmbedderResult embed(MPImage image) {
+ return embed(image, ImageProcessingOptions.builder().build());
+ }
+
+ /**
+ * Performs embedding extraction on the provided single image. Only use this method when the
+ * {@link ImageEmbedder} is created with {@link RunningMode.IMAGE}.
+ *
+ * {@link ImageEmbedder} supports the following color space types:
+ *
+ *
+ * - {@link Bitmap.Config.ARGB_8888}
+ *
+ *
+ * @param image a MediaPipe {@link MPImage} object for processing.
+ * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
+ * input image before running inference.
+ * @throws MediaPipeException if there is an internal error.
+ */
+ public ImageEmbedderResult embed(MPImage image, ImageProcessingOptions imageProcessingOptions) {
+ return (ImageEmbedderResult) processImageData(image, imageProcessingOptions);
+ }
+
+ /**
+ * Performs embedding extraction on the provided video frame with default image processing
+ * options, i.e. using the whole image as region-of-interest and without any rotation applied.
+ * Only use this method when the {@link ImageEmbedder} is created with {@link RunningMode.VIDEO}.
+ *
+ * It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
+ * must be monotonically increasing.
+ *
+ *
{@link ImageEmbedder} supports the following color space types:
+ *
+ *
+ * - {@link Bitmap.Config.ARGB_8888}
+ *
+ *
+ * @param image a MediaPipe {@link MPImage} object for processing.
+ * @param timestampMs the input timestamp (in milliseconds).
+ * @throws MediaPipeException if there is an internal error.
+ */
+ public ImageEmbedderResult embedForVideo(MPImage image, long timestampMs) {
+ return embedForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
+ }
+
+ /**
+ * Performs embedding extraction on the provided video frame. Only use this method when the {@link
+ * ImageEmbedder} is created with {@link RunningMode.VIDEO}.
+ *
+ * It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
+ * must be monotonically increasing.
+ *
+ *
{@link ImageEmbedder} supports the following color space types:
+ *
+ *
+ * - {@link Bitmap.Config.ARGB_8888}
+ *
+ *
+ * @param image a MediaPipe {@link MPImage} object for processing.
+ * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
+ * input image before running inference.
+ * @param timestampMs the input timestamp (in milliseconds).
+ * @throws MediaPipeException if there is an internal error.
+ */
+ public ImageEmbedderResult embedForVideo(
+ MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
+ return (ImageEmbedderResult) processVideoData(image, imageProcessingOptions, timestampMs);
+ }
+
+ /**
+ * Sends live image data to perform embedding extraction with default image processing options,
+ * i.e. using the whole image as region-of-interest and without any rotation applied, and the
+ * results will be available via the {@link ResultListener} provided in the {@link
+ * ImageEmbedderOptions}. Only use this method when the {@link ImageEmbedder} is created with
+ * {@link RunningMode.LIVE_STREAM}.
+ *
+ * It's required to provide a timestamp (in milliseconds) to indicate when the input image is
+ * sent to the object detector. The input timestamps must be monotonically increasing.
+ *
+ *
{@link ImageEmbedder} supports the following color space types:
+ *
+ *
+ * - {@link Bitmap.Config.ARGB_8888}
+ *
+ *
+ * @param image a MediaPipe {@link MPImage} object for processing.
+ * @param timestampMs the input timestamp (in milliseconds).
+ * @throws MediaPipeException if there is an internal error.
+ */
+ public void embedAsync(MPImage image, long timestampMs) {
+ embedAsync(image, ImageProcessingOptions.builder().build(), timestampMs);
+ }
+
+ /**
+ * Sends live image data to perform embedding extraction, and the results will be available via
+ * the {@link ResultListener} provided in the {@link ImageEmbedderOptions}. Only use this method
+ * when the {@link ImageEmbedder} is created with {@link RunningMode.LIVE_STREAM}.
+ *
+ * It's required to provide a timestamp (in milliseconds) to indicate when the input image is
+ * sent to the object detector. The input timestamps must be monotonically increasing.
+ *
+ *
{@link ImageEmbedder} supports the following color space types:
+ *
+ *
+ * - {@link Bitmap.Config.ARGB_8888}
+ *
+ *
+ * @param image a MediaPipe {@link MPImage} object for processing.
+ * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
+ * input image before running inference.
+ * @param timestampMs the input timestamp (in milliseconds).
+ * @throws MediaPipeException if there is an internal error.
+ */
+ public void embedAsync(
+ MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
+ sendLiveStreamData(image, imageProcessingOptions, timestampMs);
+ }
+
+ /**
+ * Utility function to compute cosine
+ * similarity between two {@link Embedding} objects.
+ *
+ * @throws IllegalArgumentException if the embeddings are of different types (float vs.
+ * quantized), have different sizes, or have an L2-norm of 0.
+ */
+ public static double cosineSimilarity(Embedding u, Embedding v) {
+ return CosineSimilarity.compute(u, v);
+ }
+
+ /** Options for setting up and {@link ImageEmbedder}. */
+ @AutoValue
+ public abstract static class ImageEmbedderOptions extends TaskOptions {
+
+ /** Builder for {@link ImageEmbedderOptions}. */
+ @AutoValue.Builder
+ public abstract static class Builder {
+ /** Sets the {@link BaseOptions} for the image embedder task. */
+ public abstract Builder setBaseOptions(BaseOptions baseOptions);
+
+ /**
+ * Sets the {@link RunningMode} for the image embedder task. Default to the image mode. Image
+ * embedder has three modes:
+ *
+ *
+ * - IMAGE: The mode for performing embedding extraction on single image inputs.
+ *
- VIDEO: The mode for performing embedding extraction on the decoded frames of a video.
+ *
- LIVE_STREAM: The mode for for performing embedding extraction on a live stream of
+ * input data, such as from camera. In this mode, {@code setResultListener} must be
+ * called to set up a listener to receive the embedding results asynchronously.
+ *
+ */
+ public abstract Builder setRunningMode(RunningMode runningMode);
+
+ /**
+ * Sets the optional {@link EmbedderOptions} controling embedder behavior, such as
+ * L2-normalization and scalar quantization.
+ */
+ public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions);
+
+ /**
+ * Sets the {@link ResultListener} to receive the embedding results asynchronously when the
+ * image embedder is in the live stream mode.
+ */
+ public abstract Builder setResultListener(
+ ResultListener resultListener);
+
+ /** Sets an optional {@link ErrorListener}. */
+ public abstract Builder setErrorListener(ErrorListener errorListener);
+
+ abstract ImageEmbedderOptions autoBuild();
+
+ /**
+ * Validates and builds the {@link ImageEmbedderOptions} instance. *
+ *
+ * @throws IllegalArgumentException if the result listener and the running mode are not
+ * properly configured. The result listener should only be set when the image embedder is
+ * in the live stream mode.
+ */
+ public final ImageEmbedderOptions build() {
+ ImageEmbedderOptions options = autoBuild();
+ if (options.runningMode() == RunningMode.LIVE_STREAM) {
+ if (!options.resultListener().isPresent()) {
+ throw new IllegalArgumentException(
+ "The image embedder is in the live stream mode, a user-defined result listener"
+ + " must be provided in the ImageEmbedderOptions.");
+ }
+ } else if (options.resultListener().isPresent()) {
+ throw new IllegalArgumentException(
+ "The image embedder is in the image or video mode, a user-defined result listener"
+ + " shouldn't be provided in ImageEmbedderOptions.");
+ }
+ return options;
+ }
+ }
+
+ abstract BaseOptions baseOptions();
+
+ abstract RunningMode runningMode();
+
+ abstract Optional embedderOptions();
+
+ abstract Optional> resultListener();
+
+ abstract Optional errorListener();
+
+ public static Builder builder() {
+ return new AutoValue_ImageEmbedder_ImageEmbedderOptions.Builder()
+ .setRunningMode(RunningMode.IMAGE);
+ }
+
+ /** Converts a {@link ImageEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */
+ @Override
+ public CalculatorOptions convertToCalculatorOptionsProto() {
+ BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
+ BaseOptionsProto.BaseOptions.newBuilder();
+ baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE);
+ baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
+ ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.Builder taskOptionsBuilder =
+ ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.newBuilder()
+ .setBaseOptions(baseOptionsBuilder);
+ if (embedderOptions().isPresent()) {
+ taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto());
+ }
+ return CalculatorOptions.newBuilder()
+ .setExtension(
+ ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.ext,
+ taskOptionsBuilder.build())
+ .build();
+ }
+ }
+}
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderResult.java
new file mode 100644
index 000000000..ee3f4abc9
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderResult.java
@@ -0,0 +1,54 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.vision.imageembedder;
+
+import com.google.auto.value.AutoValue;
+import com.google.mediapipe.tasks.components.containers.EmbeddingResult;
+import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
+import com.google.mediapipe.tasks.core.TaskResult;
+
+/** Represents the embedding results generated by {@link ImageEmbedder}. */
+@AutoValue
+public abstract class ImageEmbedderResult implements TaskResult {
+
+ /**
+ * Creates an {@link ImageEmbedderResult} instance.
+ *
+ * @param embeddingResult the {@link EmbeddingResult} object containing one embedding per embedder
+ * head.
+ * @param timestampMs a timestamp for this result.
+ */
+ static ImageEmbedderResult create(EmbeddingResult embeddingResult, long timestampMs) {
+ return new AutoValue_ImageEmbedderResult(embeddingResult, timestampMs);
+ }
+
+ /**
+ * Creates an {@link ImageEmbedderResult} instance from a {@link EmbeddingsProto.EmbeddingResult}
+ * protobuf message.
+ *
+ * @param proto the {@link EmbeddingsProto.EmbeddingResult} protobuf message to convert.
+ * @param timestampMs a timestamp for this result.
+ */
+ static ImageEmbedderResult createFromProto(
+ EmbeddingsProto.EmbeddingResult proto, long timestampMs) {
+ return create(EmbeddingResult.createFromProto(proto), timestampMs);
+ }
+
+ /** Contains one embedding per embedder head. */
+ public abstract EmbeddingResult embeddingResult();
+
+ @Override
+ public abstract long timestampMs();
+}
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml
new file mode 100644
index 000000000..db303a439
--- /dev/null
+++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml
@@ -0,0 +1,24 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/BUILD
new file mode 100644
index 000000000..a7f804c64
--- /dev/null
+++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/BUILD
@@ -0,0 +1,19 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+# TODO: Enable this in OSS
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java
new file mode 100644
index 000000000..56249ead9
--- /dev/null
+++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java
@@ -0,0 +1,444 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.vision.imageembedder;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertThrows;
+
+import android.content.res.AssetManager;
+import android.graphics.BitmapFactory;
+import android.graphics.RectF;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import com.google.mediapipe.framework.MediaPipeException;
+import com.google.mediapipe.framework.image.BitmapImageBuilder;
+import com.google.mediapipe.framework.image.MPImage;
+import com.google.mediapipe.tasks.components.processors.EmbedderOptions;
+import com.google.mediapipe.tasks.core.BaseOptions;
+import com.google.mediapipe.tasks.core.TestUtils;
+import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
+import com.google.mediapipe.tasks.vision.core.RunningMode;
+import com.google.mediapipe.tasks.vision.imageembedder.ImageEmbedder.ImageEmbedderOptions;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Suite;
+import org.junit.runners.Suite.SuiteClasses;
+
+/** Test for {@link ImageEmbedder}/ */
+@RunWith(Suite.class)
+@SuiteClasses({ImageEmbedderTest.General.class, ImageEmbedderTest.RunningModeTest.class})
+public class ImageEmbedderTest {
+ private static final String MOBILENET_EMBEDDER = "mobilenet_v3_small_100_224_embedder.tflite";
+ private static final String BURGER_IMAGE = "burger.jpg";
+ private static final String BURGER_CROP_IMAGE = "burger_crop.jpg";
+ private static final String BURGER_ROTATED_IMAGE = "burger_rotated.jpg";
+
+ private static final double DOUBLE_DIFF_TOLERANCE = 1e-4;
+
+ @RunWith(AndroidJUnit4.class)
+ public static final class General extends ImageEmbedderTest {
+
+ @Test
+ public void create_failsWithMissingModel() throws Exception {
+ String nonExistentFile = "/path/to/non/existent/file";
+ MediaPipeException exception =
+ assertThrows(
+ MediaPipeException.class,
+ () ->
+ ImageEmbedder.createFromFile(
+ ApplicationProvider.getApplicationContext(), nonExistentFile));
+ assertThat(exception).hasMessageThat().contains(nonExistentFile);
+ }
+
+ @Test
+ public void create_failsWithInvalidModelBuffer() throws Exception {
+ // Create a non-direct model ByteBuffer.
+ ByteBuffer modelBuffer =
+ TestUtils.loadToNonDirectByteBuffer(
+ ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER);
+
+ IllegalArgumentException exception =
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ ImageEmbedder.createFromBuffer(
+ ApplicationProvider.getApplicationContext(), modelBuffer));
+
+ assertThat(exception)
+ .hasMessageThat()
+ .contains("The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
+ }
+
+ @Test
+ public void embed_succeedsWithNoOptions() throws Exception {
+ ImageEmbedder imageEmbedder =
+ ImageEmbedder.createFromFile(
+ ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER);
+ ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE));
+ ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
+
+ // Check results.
+ assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false);
+ assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false);
+ // Check similarity.
+ double similarity =
+ ImageEmbedder.cosineSimilarity(
+ result.embeddingResult().embeddings().get(0),
+ resultCrop.embeddingResult().embeddings().get(0));
+ assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.925272);
+ }
+
+ @Test
+ public void embed_succeedsWithL2Normalization() throws Exception {
+ BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build();
+ EmbedderOptions embedderOptions = EmbedderOptions.builder().setL2Normalize(true).build();
+ ImageEmbedderOptions options =
+ ImageEmbedderOptions.builder()
+ .setBaseOptions(baseOptions)
+ .setEmbedderOptions(embedderOptions)
+ .build();
+
+ ImageEmbedder imageEmbedder =
+ ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE));
+ ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
+
+ // Check results.
+ assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false);
+ assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false);
+ // Check similarity.
+ double similarity =
+ ImageEmbedder.cosineSimilarity(
+ result.embeddingResult().embeddings().get(0),
+ resultCrop.embeddingResult().embeddings().get(0));
+ assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.925272);
+ }
+
+ @Test
+ public void embed_succeedsWithQuantization() throws Exception {
+ BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build();
+ EmbedderOptions embedderOptions = EmbedderOptions.builder().setQuantize(true).build();
+ ImageEmbedderOptions options =
+ ImageEmbedderOptions.builder()
+ .setBaseOptions(baseOptions)
+ .setEmbedderOptions(embedderOptions)
+ .build();
+
+ ImageEmbedder imageEmbedder =
+ ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE));
+ ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
+
+ // Check results.
+ assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ true);
+ assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ true);
+ // Check similarity.
+ double similarity =
+ ImageEmbedder.cosineSimilarity(
+ result.embeddingResult().embeddings().get(0),
+ resultCrop.embeddingResult().embeddings().get(0));
+ assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.926776);
+ }
+
+ @Test
+ public void embed_succeedsWithRegionOfInterest() throws Exception {
+ ImageEmbedder imageEmbedder =
+ ImageEmbedder.createFromFile(
+ ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER);
+ // RectF around the region in "burger.jpg" corresponding to "burger_crop.jpg".
+ RectF roi = new RectF(0.0f, 0.0f, 0.833333f, 1.0f);
+ ImageProcessingOptions imageProcessingOptions =
+ ImageProcessingOptions.builder().setRegionOfInterest(roi).build();
+ ImageEmbedderResult resultRoi =
+ imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE), imageProcessingOptions);
+ ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
+
+ // Check results.
+ assertHasOneHeadAndCorrectDimension(resultRoi, /*quantized=*/ false);
+ assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false);
+ // Check similarity.
+ double similarity =
+ ImageEmbedder.cosineSimilarity(
+ resultRoi.embeddingResult().embeddings().get(0),
+ resultCrop.embeddingResult().embeddings().get(0));
+ assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.999931f);
+ }
+
+ @Test
+ public void embed_succeedsWithRotation() throws Exception {
+ ImageEmbedder imageEmbedder =
+ ImageEmbedder.createFromFile(
+ ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER);
+ ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE));
+ ImageProcessingOptions imageProcessingOptions =
+ ImageProcessingOptions.builder().setRotationDegrees(-90).build();
+ ImageEmbedderResult resultRotated =
+ imageEmbedder.embed(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions);
+
+ // Check results.
+ assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false);
+ assertHasOneHeadAndCorrectDimension(resultRotated, /*quantized=*/ false);
+ // Check similarity.
+ double similarity =
+ ImageEmbedder.cosineSimilarity(
+ result.embeddingResult().embeddings().get(0),
+ resultRotated.embeddingResult().embeddings().get(0));
+ assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.571648426f);
+ }
+
+ @Test
+ public void embed_succeedsWithRegionOfInterestAndRotation() throws Exception {
+ ImageEmbedder imageEmbedder =
+ ImageEmbedder.createFromFile(
+ ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER);
+ // RectF around the region in "burger_rotated.jpg" corresponding to "burger_crop.jpg".
+ RectF roi = new RectF(0.0f, 0.0f, 1.0f, 0.833333f);
+ ImageProcessingOptions imageProcessingOptions =
+ ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build();
+ ImageEmbedderResult resultRoiRotated =
+ imageEmbedder.embed(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions);
+ ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
+
+ // Check results.
+ assertHasOneHeadAndCorrectDimension(resultRoiRotated, /*quantized=*/ false);
+ assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false);
+ // Check similarity.
+ double similarity =
+ ImageEmbedder.cosineSimilarity(
+ resultRoiRotated.embeddingResult().embeddings().get(0),
+ resultCrop.embeddingResult().embeddings().get(0));
+ assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.62780395f);
+ }
+ }
+
+ @RunWith(AndroidJUnit4.class)
+ public static final class RunningModeTest extends ImageEmbedderTest {
+
+ @Test
+ public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception {
+ for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) {
+ IllegalArgumentException exception =
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ ImageEmbedderOptions.builder()
+ .setBaseOptions(
+ BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build())
+ .setRunningMode(mode)
+ .setResultListener((result, inputImage) -> {})
+ .build());
+ assertThat(exception)
+ .hasMessageThat()
+ .contains("a user-defined result listener shouldn't be provided");
+ }
+ }
+
+ @Test
+ public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception {
+ IllegalArgumentException exception =
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ ImageEmbedderOptions.builder()
+ .setBaseOptions(
+ BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build())
+ .setRunningMode(RunningMode.LIVE_STREAM)
+ .build());
+ assertThat(exception)
+ .hasMessageThat()
+ .contains("a user-defined result listener must be provided");
+ }
+
+ @Test
+ public void embed_failsWithCallingWrongApiInImageMode() throws Exception {
+ ImageEmbedderOptions options =
+ ImageEmbedderOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build())
+ .setRunningMode(RunningMode.IMAGE)
+ .build();
+
+ ImageEmbedder imageEmbedder =
+ ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ MediaPipeException exception =
+ assertThrows(
+ MediaPipeException.class,
+ () ->
+ imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
+ assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
+ exception =
+ assertThrows(
+ MediaPipeException.class,
+ () -> imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
+ assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
+ }
+
+ @Test
+ public void embed_failsWithCallingWrongApiInVideoMode() throws Exception {
+ ImageEmbedderOptions options =
+ ImageEmbedderOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build())
+ .setRunningMode(RunningMode.VIDEO)
+ .build();
+
+ ImageEmbedder imageEmbedder =
+ ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ MediaPipeException exception =
+ assertThrows(
+ MediaPipeException.class, () -> imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)));
+ assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
+ exception =
+ assertThrows(
+ MediaPipeException.class,
+ () -> imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
+ assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
+ }
+
+ @Test
+ public void embed_failsWithCallingWrongApiInLiveSteamMode() throws Exception {
+ ImageEmbedderOptions options =
+ ImageEmbedderOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build())
+ .setRunningMode(RunningMode.LIVE_STREAM)
+ .setResultListener((imageClassificationResult, inputImage) -> {})
+ .build();
+
+ ImageEmbedder imageEmbedder =
+ ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+
+ MediaPipeException exception =
+ assertThrows(
+ MediaPipeException.class, () -> imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)));
+ assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
+ exception =
+ assertThrows(
+ MediaPipeException.class,
+ () ->
+ imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
+ assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
+ }
+
+ @Test
+ public void embed_succeedsWithImageMode() throws Exception {
+ ImageEmbedder imageEmbedder =
+ ImageEmbedder.createFromFile(
+ ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER);
+ ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE));
+ ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
+
+ // Check results.
+ assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false);
+ assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false);
+ // Check similarity.
+ double similarity =
+ ImageEmbedder.cosineSimilarity(
+ result.embeddingResult().embeddings().get(0),
+ resultCrop.embeddingResult().embeddings().get(0));
+ assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.925272);
+ }
+
+ @Test
+ public void embed_succeedsWithVideoMode() throws Exception {
+ BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build();
+ ImageEmbedderOptions options =
+ ImageEmbedderOptions.builder()
+ .setBaseOptions(baseOptions)
+ .setRunningMode(RunningMode.VIDEO)
+ .build();
+ ImageEmbedder imageEmbedder =
+ ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+
+ for (int i = 0; i < 3; ++i) {
+ ImageEmbedderResult result =
+ imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ i);
+ assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false);
+ }
+ }
+
+ @Test
+ public void embed_failsWithOutOfOrderInputTimestamps() throws Exception {
+ MPImage image = getImageFromAsset(BURGER_IMAGE);
+ BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build();
+ ImageEmbedderOptions options =
+ ImageEmbedderOptions.builder()
+ .setBaseOptions(baseOptions)
+ .setRunningMode(RunningMode.LIVE_STREAM)
+ .setResultListener(
+ (imageEmbedderResult, inputImage) -> {
+ assertHasOneHeadAndCorrectDimension(imageEmbedderResult, /*quantized=*/ false);
+ assertImageSizeIsExpected(inputImage);
+ })
+ .build();
+ try (ImageEmbedder imageEmbedder =
+ ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
+ imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1);
+ MediaPipeException exception =
+ assertThrows(
+ MediaPipeException.class,
+ () -> imageEmbedder.embedAsync(image, /*timestampMs=*/ 0));
+ assertThat(exception)
+ .hasMessageThat()
+ .contains("having a smaller timestamp than the processed timestamp");
+ }
+ }
+
+ @Test
+ public void embed_succeedsWithLiveStreamMode() throws Exception {
+ MPImage image = getImageFromAsset(BURGER_IMAGE);
+ BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build();
+ ImageEmbedderOptions options =
+ ImageEmbedderOptions.builder()
+ .setBaseOptions(baseOptions)
+ .setRunningMode(RunningMode.LIVE_STREAM)
+ .setResultListener(
+ (imageEmbedderResult, inputImage) -> {
+ assertHasOneHeadAndCorrectDimension(imageEmbedderResult, /*quantized=*/ false);
+ assertImageSizeIsExpected(inputImage);
+ })
+ .build();
+ try (ImageEmbedder imageEmbedder =
+ ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
+ for (int i = 0; i < 3; ++i) {
+ imageEmbedder.embedAsync(image, /*timestampMs=*/ i);
+ }
+ }
+ }
+ }
+
+ private static MPImage getImageFromAsset(String filePath) throws Exception {
+ AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
+ InputStream istr = assetManager.open(filePath);
+ return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
+ }
+
+ private static void assertHasOneHeadAndCorrectDimension(
+ ImageEmbedderResult result, boolean quantized) {
+ assertThat(result.embeddingResult().embeddings()).hasSize(1);
+ assertThat(result.embeddingResult().embeddings().get(0).headIndex()).isEqualTo(0);
+ assertThat(result.embeddingResult().embeddings().get(0).headName().get()).isEqualTo("feature");
+ if (quantized) {
+ assertThat(result.embeddingResult().embeddings().get(0).quantizedEmbedding()).hasLength(1024);
+ } else {
+ assertThat(result.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(1024);
+ }
+ }
+
+ private static void assertImageSizeIsExpected(MPImage inputImage) {
+ assertThat(inputImage).isNotNull();
+ assertThat(inputImage.getWidth()).isEqualTo(480);
+ assertThat(inputImage.getHeight()).isEqualTo(325);
+ }
+}