Add Java ImageEmbedder API.
PiperOrigin-RevId: 488588010
This commit is contained in:
parent
6f54308c25
commit
ebba119f15
|
@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto";
|
|||
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks.vision.imageembedder.proto";
|
||||
option java_outer_classname = "ImageEmbedderGraphOptionsProto";
|
||||
|
||||
message ImageEmbedderGraphOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional ImageEmbedderGraphOptions ext = 476348187;
|
||||
|
|
|
@ -42,6 +42,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
|||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
|
||||
|
|
|
@ -43,6 +43,7 @@ cc_binary(
|
|||
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
|
||||
"//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
|
||||
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
|
||||
],
|
||||
|
@ -172,6 +173,34 @@ android_library(
|
|||
],
|
||||
)
|
||||
|
||||
android_library(
|
||||
name = "imageembedder",
|
||||
srcs = [
|
||||
"imageembedder/ImageEmbedder.java",
|
||||
"imageembedder/ImageEmbedderResult.java",
|
||||
],
|
||||
javacopts = [
|
||||
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||
],
|
||||
manifest = "imageembedder/AndroidManifest.xml",
|
||||
deps = [
|
||||
":core",
|
||||
"//mediapipe/framework:calculator_options_java_proto_lite",
|
||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
||||
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_vision_aar")
|
||||
|
||||
mediapipe_tasks_vision_aar(
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.google.mediapipe.tasks.vision.imageembedder">
|
||||
|
||||
<uses-sdk android:minSdkVersion="24"
|
||||
android:targetSdkVersion="30" />
|
||||
|
||||
</manifest>
|
|
@ -0,0 +1,448 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.imageembedder;
|
||||
|
||||
import android.content.Context;
|
||||
import android.os.ParcelFileDescriptor;
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||
import com.google.mediapipe.framework.AndroidPacketGetter;
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.Packet;
|
||||
import com.google.mediapipe.framework.PacketGetter;
|
||||
import com.google.mediapipe.framework.ProtoUtil;
|
||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||
import com.google.mediapipe.framework.image.MPImage;
|
||||
import com.google.mediapipe.tasks.components.containers.Embedding;
|
||||
import com.google.mediapipe.tasks.components.containers.EmbeddingResult;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
|
||||
import com.google.mediapipe.tasks.components.processors.EmbedderOptions;
|
||||
import com.google.mediapipe.tasks.components.utils.CosineSimilarity;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
|
||||
import com.google.mediapipe.tasks.core.TaskInfo;
|
||||
import com.google.mediapipe.tasks.core.TaskOptions;
|
||||
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
|
||||
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
|
||||
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
|
||||
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||
import com.google.mediapipe.tasks.vision.imageembedder.proto.ImageEmbedderGraphOptionsProto;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* Performs embedding extraction on images.
|
||||
*
|
||||
* <p>The API expects a TFLite model with optional, but strongly recommended, <a
|
||||
* href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata.</a>.
|
||||
*
|
||||
* <p>The API supports models with one image input tensor and one or more output tensors. To be more
|
||||
* specific, here are the requirements.
|
||||
*
|
||||
* <ul>
|
||||
* <li>Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
|
||||
* <ul>
|
||||
* <li>image input of size {@code [batch x height x width x channels]}.
|
||||
* <li>batch inference is not supported ({@code batch} is required to be 1).
|
||||
* <li>only RGB inputs are supported ({@code channels} is required to be 3).
|
||||
* <li>if type is kTfLiteFloat32, NormalizationOptions are required to be attached to the
|
||||
* metadata for input normalization.
|
||||
* </ul>
|
||||
* <li>At least one output tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) with shape {@code
|
||||
* [1 x N]} where N is the number of dimensions in the produced embeddings.
|
||||
* </ul>
|
||||
*/
|
||||
public final class ImageEmbedder extends BaseVisionTaskApi {
|
||||
private static final String TAG = ImageEmbedder.class.getSimpleName();
|
||||
private static final String IMAGE_IN_STREAM_NAME = "image_in";
|
||||
private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
|
||||
private static final List<String> INPUT_STREAMS =
|
||||
Collections.unmodifiableList(
|
||||
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
||||
private static final List<String> OUTPUT_STREAMS =
|
||||
Collections.unmodifiableList(Arrays.asList("EMBEDDINGS:embeddings_out", "IMAGE:image_out"));
|
||||
private static final int EMBEDDINGS_OUT_STREAM_INDEX = 0;
|
||||
private static final int IMAGE_OUT_STREAM_INDEX = 1;
|
||||
private static final String TASK_GRAPH_NAME =
|
||||
"mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph";
|
||||
|
||||
static {
|
||||
ProtoUtil.registerTypeName(
|
||||
EmbeddingsProto.EmbeddingResult.class,
|
||||
"mediapipe.tasks.components.containers.proto.EmbeddingResult");
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an {@link ImageEmbedder} instance from a model file and default {@link
|
||||
* ImageEmbedderOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param modelPath path to the embedding model in the assets.
|
||||
* @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation.
|
||||
*/
|
||||
public static ImageEmbedder createFromFile(Context context, String modelPath) {
|
||||
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
|
||||
return createFromOptions(
|
||||
context, ImageEmbedderOptions.builder().setBaseOptions(baseOptions).build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an {@link ImageEmbedder} instance from a model file and default {@link
|
||||
* ImageEmbedderOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param modelFile the embedding model {@link File} instance.
|
||||
* @throws IOException if an I/O error occurs when opening the tflite model file.
|
||||
* @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation.
|
||||
*/
|
||||
public static ImageEmbedder createFromFile(Context context, File modelFile) throws IOException {
|
||||
try (ParcelFileDescriptor descriptor =
|
||||
ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
|
||||
BaseOptions baseOptions =
|
||||
BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
|
||||
return createFromOptions(
|
||||
context, ImageEmbedderOptions.builder().setBaseOptions(baseOptions).build());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an {@link ImageEmbedder} instance from a model buffer and default {@link
|
||||
* ImageEmbedderOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the embedding
|
||||
* model.
|
||||
* @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation.
|
||||
*/
|
||||
public static ImageEmbedder createFromBuffer(Context context, final ByteBuffer modelBuffer) {
|
||||
BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
|
||||
return createFromOptions(
|
||||
context, ImageEmbedderOptions.builder().setBaseOptions(baseOptions).build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an {@link ImageEmbedder} instance from an {@link ImageEmbedderOptions} instance.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param options an {@link ImageEmbedderOptions} instance.
|
||||
* @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation.
|
||||
*/
|
||||
public static ImageEmbedder createFromOptions(Context context, ImageEmbedderOptions options) {
|
||||
OutputHandler<ImageEmbedderResult, MPImage> handler = new OutputHandler<>();
|
||||
handler.setOutputPacketConverter(
|
||||
new OutputHandler.OutputPacketConverter<ImageEmbedderResult, MPImage>() {
|
||||
@Override
|
||||
public ImageEmbedderResult convertToTaskResult(List<Packet> packets) {
|
||||
try {
|
||||
return ImageEmbedderResult.create(
|
||||
EmbeddingResult.createFromProto(
|
||||
PacketGetter.getProto(
|
||||
packets.get(EMBEDDINGS_OUT_STREAM_INDEX),
|
||||
EmbeddingsProto.EmbeddingResult.getDefaultInstance())),
|
||||
BaseVisionTaskApi.generateResultTimestampMs(
|
||||
options.runningMode(), packets.get(EMBEDDINGS_OUT_STREAM_INDEX)));
|
||||
} catch (IOException e) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public MPImage convertToTaskInput(List<Packet> packets) {
|
||||
return new BitmapImageBuilder(
|
||||
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
|
||||
.build();
|
||||
}
|
||||
});
|
||||
options.resultListener().ifPresent(handler::setResultListener);
|
||||
options.errorListener().ifPresent(handler::setErrorListener);
|
||||
TaskRunner runner =
|
||||
TaskRunner.create(
|
||||
context,
|
||||
TaskInfo.<ImageEmbedderOptions>builder()
|
||||
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||
.setInputStreams(INPUT_STREAMS)
|
||||
.setOutputStreams(OUTPUT_STREAMS)
|
||||
.setTaskOptions(options)
|
||||
.setEnableFlowLimiting(options.runningMode() == RunningMode.LIVE_STREAM)
|
||||
.build(),
|
||||
handler);
|
||||
return new ImageEmbedder(runner, options.runningMode());
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor to initialize an {@link ImageEmbedder} from a {@link TaskRunner} and {@link
|
||||
* RunningMode}.
|
||||
*
|
||||
* @param taskRunner a {@link TaskRunner}.
|
||||
* @param runningMode a mediapipe vision task {@link RunningMode}.
|
||||
*/
|
||||
private ImageEmbedder(TaskRunner taskRunner, RunningMode runningMode) {
|
||||
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs embedding extraction on the provided single image with default image processing
|
||||
* options, i.e. using the whole image as region-of-interest and without any rotation applied.
|
||||
* Only use this method when the {@link ImageEmbedder} is created with {@link RunningMode.IMAGE}.
|
||||
*
|
||||
* <p>{@link ImageEmbedder} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public ImageEmbedderResult embed(MPImage image) {
|
||||
return embed(image, ImageProcessingOptions.builder().build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs embedding extraction on the provided single image. Only use this method when the
|
||||
* {@link ImageEmbedder} is created with {@link RunningMode.IMAGE}.
|
||||
*
|
||||
* <p>{@link ImageEmbedder} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||
* input image before running inference.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public ImageEmbedderResult embed(MPImage image, ImageProcessingOptions imageProcessingOptions) {
|
||||
return (ImageEmbedderResult) processImageData(image, imageProcessingOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs embedding extraction on the provided video frame with default image processing
|
||||
* options, i.e. using the whole image as region-of-interest and without any rotation applied.
|
||||
* Only use this method when the {@link ImageEmbedder} is created with {@link RunningMode.VIDEO}.
|
||||
*
|
||||
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||
* must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link ImageEmbedder} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public ImageEmbedderResult embedForVideo(MPImage image, long timestampMs) {
|
||||
return embedForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs embedding extraction on the provided video frame. Only use this method when the {@link
|
||||
* ImageEmbedder} is created with {@link RunningMode.VIDEO}.
|
||||
*
|
||||
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||
* must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link ImageEmbedder} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||
* input image before running inference.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public ImageEmbedderResult embedForVideo(
|
||||
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||
return (ImageEmbedderResult) processVideoData(image, imageProcessingOptions, timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends live image data to perform embedding extraction with default image processing options,
|
||||
* i.e. using the whole image as region-of-interest and without any rotation applied, and the
|
||||
* results will be available via the {@link ResultListener} provided in the {@link
|
||||
* ImageEmbedderOptions}. Only use this method when the {@link ImageEmbedder} is created with
|
||||
* {@link RunningMode.LIVE_STREAM}.
|
||||
*
|
||||
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
|
||||
* sent to the object detector. The input timestamps must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link ImageEmbedder} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public void embedAsync(MPImage image, long timestampMs) {
|
||||
embedAsync(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends live image data to perform embedding extraction, and the results will be available via
|
||||
* the {@link ResultListener} provided in the {@link ImageEmbedderOptions}. Only use this method
|
||||
* when the {@link ImageEmbedder} is created with {@link RunningMode.LIVE_STREAM}.
|
||||
*
|
||||
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
|
||||
* sent to the object detector. The input timestamps must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link ImageEmbedder} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||
* input image before running inference.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public void embedAsync(
|
||||
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||
sendLiveStreamData(image, imageProcessingOptions, timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Utility function to compute <a href="https://en.wikipedia.org/wiki/Cosine_similarity">cosine
|
||||
* similarity</a> between two {@link Embedding} objects.
|
||||
*
|
||||
* @throws IllegalArgumentException if the embeddings are of different types (float vs.
|
||||
* quantized), have different sizes, or have an L2-norm of 0.
|
||||
*/
|
||||
public static double cosineSimilarity(Embedding u, Embedding v) {
|
||||
return CosineSimilarity.compute(u, v);
|
||||
}
|
||||
|
||||
/** Options for setting up and {@link ImageEmbedder}. */
|
||||
@AutoValue
|
||||
public abstract static class ImageEmbedderOptions extends TaskOptions {
|
||||
|
||||
/** Builder for {@link ImageEmbedderOptions}. */
|
||||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
/** Sets the {@link BaseOptions} for the image embedder task. */
|
||||
public abstract Builder setBaseOptions(BaseOptions baseOptions);
|
||||
|
||||
/**
|
||||
* Sets the {@link RunningMode} for the image embedder task. Default to the image mode. Image
|
||||
* embedder has three modes:
|
||||
*
|
||||
* <ul>
|
||||
* <li>IMAGE: The mode for performing embedding extraction on single image inputs.
|
||||
* <li>VIDEO: The mode for performing embedding extraction on the decoded frames of a video.
|
||||
* <li>LIVE_STREAM: The mode for for performing embedding extraction on a live stream of
|
||||
* input data, such as from camera. In this mode, {@code setResultListener} must be
|
||||
* called to set up a listener to receive the embedding results asynchronously.
|
||||
* </ul>
|
||||
*/
|
||||
public abstract Builder setRunningMode(RunningMode runningMode);
|
||||
|
||||
/**
|
||||
* Sets the optional {@link EmbedderOptions} controling embedder behavior, such as
|
||||
* L2-normalization and scalar quantization.
|
||||
*/
|
||||
public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions);
|
||||
|
||||
/**
|
||||
* Sets the {@link ResultListener} to receive the embedding results asynchronously when the
|
||||
* image embedder is in the live stream mode.
|
||||
*/
|
||||
public abstract Builder setResultListener(
|
||||
ResultListener<ImageEmbedderResult, MPImage> resultListener);
|
||||
|
||||
/** Sets an optional {@link ErrorListener}. */
|
||||
public abstract Builder setErrorListener(ErrorListener errorListener);
|
||||
|
||||
abstract ImageEmbedderOptions autoBuild();
|
||||
|
||||
/**
|
||||
* Validates and builds the {@link ImageEmbedderOptions} instance. *
|
||||
*
|
||||
* @throws IllegalArgumentException if the result listener and the running mode are not
|
||||
* properly configured. The result listener should only be set when the image embedder is
|
||||
* in the live stream mode.
|
||||
*/
|
||||
public final ImageEmbedderOptions build() {
|
||||
ImageEmbedderOptions options = autoBuild();
|
||||
if (options.runningMode() == RunningMode.LIVE_STREAM) {
|
||||
if (!options.resultListener().isPresent()) {
|
||||
throw new IllegalArgumentException(
|
||||
"The image embedder is in the live stream mode, a user-defined result listener"
|
||||
+ " must be provided in the ImageEmbedderOptions.");
|
||||
}
|
||||
} else if (options.resultListener().isPresent()) {
|
||||
throw new IllegalArgumentException(
|
||||
"The image embedder is in the image or video mode, a user-defined result listener"
|
||||
+ " shouldn't be provided in ImageEmbedderOptions.");
|
||||
}
|
||||
return options;
|
||||
}
|
||||
}
|
||||
|
||||
abstract BaseOptions baseOptions();
|
||||
|
||||
abstract RunningMode runningMode();
|
||||
|
||||
abstract Optional<EmbedderOptions> embedderOptions();
|
||||
|
||||
abstract Optional<ResultListener<ImageEmbedderResult, MPImage>> resultListener();
|
||||
|
||||
abstract Optional<ErrorListener> errorListener();
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_ImageEmbedder_ImageEmbedderOptions.Builder()
|
||||
.setRunningMode(RunningMode.IMAGE);
|
||||
}
|
||||
|
||||
/** Converts a {@link ImageEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */
|
||||
@Override
|
||||
public CalculatorOptions convertToCalculatorOptionsProto() {
|
||||
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
|
||||
BaseOptionsProto.BaseOptions.newBuilder();
|
||||
baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE);
|
||||
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
|
||||
ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.Builder taskOptionsBuilder =
|
||||
ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.newBuilder()
|
||||
.setBaseOptions(baseOptionsBuilder);
|
||||
if (embedderOptions().isPresent()) {
|
||||
taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto());
|
||||
}
|
||||
return CalculatorOptions.newBuilder()
|
||||
.setExtension(
|
||||
ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.ext,
|
||||
taskOptionsBuilder.build())
|
||||
.build();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.imageembedder;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.tasks.components.containers.EmbeddingResult;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
|
||||
import com.google.mediapipe.tasks.core.TaskResult;
|
||||
|
||||
/** Represents the embedding results generated by {@link ImageEmbedder}. */
|
||||
@AutoValue
|
||||
public abstract class ImageEmbedderResult implements TaskResult {
|
||||
|
||||
/**
|
||||
* Creates an {@link ImageEmbedderResult} instance.
|
||||
*
|
||||
* @param embeddingResult the {@link EmbeddingResult} object containing one embedding per embedder
|
||||
* head.
|
||||
* @param timestampMs a timestamp for this result.
|
||||
*/
|
||||
static ImageEmbedderResult create(EmbeddingResult embeddingResult, long timestampMs) {
|
||||
return new AutoValue_ImageEmbedderResult(embeddingResult, timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an {@link ImageEmbedderResult} instance from a {@link EmbeddingsProto.EmbeddingResult}
|
||||
* protobuf message.
|
||||
*
|
||||
* @param proto the {@link EmbeddingsProto.EmbeddingResult} protobuf message to convert.
|
||||
* @param timestampMs a timestamp for this result.
|
||||
*/
|
||||
static ImageEmbedderResult createFromProto(
|
||||
EmbeddingsProto.EmbeddingResult proto, long timestampMs) {
|
||||
return create(EmbeddingResult.createFromProto(proto), timestampMs);
|
||||
}
|
||||
|
||||
/** Contains one embedding per embedder head. */
|
||||
public abstract EmbeddingResult embeddingResult();
|
||||
|
||||
@Override
|
||||
public abstract long timestampMs();
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.google.mediapipe.tasks.vision.imageembeddertest"
|
||||
android:versionCode="1"
|
||||
android:versionName="1.0" >
|
||||
|
||||
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
|
||||
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
|
||||
|
||||
<uses-sdk android:minSdkVersion="24"
|
||||
android:targetSdkVersion="30" />
|
||||
|
||||
<application
|
||||
android:label="imageembeddertest"
|
||||
android:name="android.support.multidex.MultiDexApplication"
|
||||
android:taskAffinity="">
|
||||
<uses-library android:name="android.test.runner" />
|
||||
</application>
|
||||
|
||||
<instrumentation
|
||||
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
|
||||
android:targetPackage="com.google.mediapipe.tasks.vision.imageembeddertest" />
|
||||
|
||||
</manifest>
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
# TODO: Enable this in OSS
|
|
@ -0,0 +1,444 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.imageembedder;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static org.junit.Assert.assertThrows;
|
||||
|
||||
import android.content.res.AssetManager;
|
||||
import android.graphics.BitmapFactory;
|
||||
import android.graphics.RectF;
|
||||
import androidx.test.core.app.ApplicationProvider;
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||
import com.google.mediapipe.framework.image.MPImage;
|
||||
import com.google.mediapipe.tasks.components.processors.EmbedderOptions;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.TestUtils;
|
||||
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
|
||||
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||
import com.google.mediapipe.tasks.vision.imageembedder.ImageEmbedder.ImageEmbedderOptions;
|
||||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Suite;
|
||||
import org.junit.runners.Suite.SuiteClasses;
|
||||
|
||||
/** Test for {@link ImageEmbedder}/ */
|
||||
@RunWith(Suite.class)
|
||||
@SuiteClasses({ImageEmbedderTest.General.class, ImageEmbedderTest.RunningModeTest.class})
|
||||
public class ImageEmbedderTest {
|
||||
private static final String MOBILENET_EMBEDDER = "mobilenet_v3_small_100_224_embedder.tflite";
|
||||
private static final String BURGER_IMAGE = "burger.jpg";
|
||||
private static final String BURGER_CROP_IMAGE = "burger_crop.jpg";
|
||||
private static final String BURGER_ROTATED_IMAGE = "burger_rotated.jpg";
|
||||
|
||||
private static final double DOUBLE_DIFF_TOLERANCE = 1e-4;
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public static final class General extends ImageEmbedderTest {
|
||||
|
||||
@Test
|
||||
public void create_failsWithMissingModel() throws Exception {
|
||||
String nonExistentFile = "/path/to/non/existent/file";
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
ImageEmbedder.createFromFile(
|
||||
ApplicationProvider.getApplicationContext(), nonExistentFile));
|
||||
assertThat(exception).hasMessageThat().contains(nonExistentFile);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void create_failsWithInvalidModelBuffer() throws Exception {
|
||||
// Create a non-direct model ByteBuffer.
|
||||
ByteBuffer modelBuffer =
|
||||
TestUtils.loadToNonDirectByteBuffer(
|
||||
ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER);
|
||||
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
() ->
|
||||
ImageEmbedder.createFromBuffer(
|
||||
ApplicationProvider.getApplicationContext(), modelBuffer));
|
||||
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void embed_succeedsWithNoOptions() throws Exception {
|
||||
ImageEmbedder imageEmbedder =
|
||||
ImageEmbedder.createFromFile(
|
||||
ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER);
|
||||
ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE));
|
||||
ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
|
||||
|
||||
// Check results.
|
||||
assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false);
|
||||
assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false);
|
||||
// Check similarity.
|
||||
double similarity =
|
||||
ImageEmbedder.cosineSimilarity(
|
||||
result.embeddingResult().embeddings().get(0),
|
||||
resultCrop.embeddingResult().embeddings().get(0));
|
||||
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.925272);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void embed_succeedsWithL2Normalization() throws Exception {
|
||||
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build();
|
||||
EmbedderOptions embedderOptions = EmbedderOptions.builder().setL2Normalize(true).build();
|
||||
ImageEmbedderOptions options =
|
||||
ImageEmbedderOptions.builder()
|
||||
.setBaseOptions(baseOptions)
|
||||
.setEmbedderOptions(embedderOptions)
|
||||
.build();
|
||||
|
||||
ImageEmbedder imageEmbedder =
|
||||
ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE));
|
||||
ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
|
||||
|
||||
// Check results.
|
||||
assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false);
|
||||
assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false);
|
||||
// Check similarity.
|
||||
double similarity =
|
||||
ImageEmbedder.cosineSimilarity(
|
||||
result.embeddingResult().embeddings().get(0),
|
||||
resultCrop.embeddingResult().embeddings().get(0));
|
||||
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.925272);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void embed_succeedsWithQuantization() throws Exception {
|
||||
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build();
|
||||
EmbedderOptions embedderOptions = EmbedderOptions.builder().setQuantize(true).build();
|
||||
ImageEmbedderOptions options =
|
||||
ImageEmbedderOptions.builder()
|
||||
.setBaseOptions(baseOptions)
|
||||
.setEmbedderOptions(embedderOptions)
|
||||
.build();
|
||||
|
||||
ImageEmbedder imageEmbedder =
|
||||
ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE));
|
||||
ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
|
||||
|
||||
// Check results.
|
||||
assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ true);
|
||||
assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ true);
|
||||
// Check similarity.
|
||||
double similarity =
|
||||
ImageEmbedder.cosineSimilarity(
|
||||
result.embeddingResult().embeddings().get(0),
|
||||
resultCrop.embeddingResult().embeddings().get(0));
|
||||
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.926776);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void embed_succeedsWithRegionOfInterest() throws Exception {
|
||||
ImageEmbedder imageEmbedder =
|
||||
ImageEmbedder.createFromFile(
|
||||
ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER);
|
||||
// RectF around the region in "burger.jpg" corresponding to "burger_crop.jpg".
|
||||
RectF roi = new RectF(0.0f, 0.0f, 0.833333f, 1.0f);
|
||||
ImageProcessingOptions imageProcessingOptions =
|
||||
ImageProcessingOptions.builder().setRegionOfInterest(roi).build();
|
||||
ImageEmbedderResult resultRoi =
|
||||
imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE), imageProcessingOptions);
|
||||
ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
|
||||
|
||||
// Check results.
|
||||
assertHasOneHeadAndCorrectDimension(resultRoi, /*quantized=*/ false);
|
||||
assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false);
|
||||
// Check similarity.
|
||||
double similarity =
|
||||
ImageEmbedder.cosineSimilarity(
|
||||
resultRoi.embeddingResult().embeddings().get(0),
|
||||
resultCrop.embeddingResult().embeddings().get(0));
|
||||
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.999931f);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void embed_succeedsWithRotation() throws Exception {
|
||||
ImageEmbedder imageEmbedder =
|
||||
ImageEmbedder.createFromFile(
|
||||
ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER);
|
||||
ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE));
|
||||
ImageProcessingOptions imageProcessingOptions =
|
||||
ImageProcessingOptions.builder().setRotationDegrees(-90).build();
|
||||
ImageEmbedderResult resultRotated =
|
||||
imageEmbedder.embed(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions);
|
||||
|
||||
// Check results.
|
||||
assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false);
|
||||
assertHasOneHeadAndCorrectDimension(resultRotated, /*quantized=*/ false);
|
||||
// Check similarity.
|
||||
double similarity =
|
||||
ImageEmbedder.cosineSimilarity(
|
||||
result.embeddingResult().embeddings().get(0),
|
||||
resultRotated.embeddingResult().embeddings().get(0));
|
||||
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.571648426f);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void embed_succeedsWithRegionOfInterestAndRotation() throws Exception {
|
||||
ImageEmbedder imageEmbedder =
|
||||
ImageEmbedder.createFromFile(
|
||||
ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER);
|
||||
// RectF around the region in "burger_rotated.jpg" corresponding to "burger_crop.jpg".
|
||||
RectF roi = new RectF(0.0f, 0.0f, 1.0f, 0.833333f);
|
||||
ImageProcessingOptions imageProcessingOptions =
|
||||
ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build();
|
||||
ImageEmbedderResult resultRoiRotated =
|
||||
imageEmbedder.embed(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions);
|
||||
ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
|
||||
|
||||
// Check results.
|
||||
assertHasOneHeadAndCorrectDimension(resultRoiRotated, /*quantized=*/ false);
|
||||
assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false);
|
||||
// Check similarity.
|
||||
double similarity =
|
||||
ImageEmbedder.cosineSimilarity(
|
||||
resultRoiRotated.embeddingResult().embeddings().get(0),
|
||||
resultCrop.embeddingResult().embeddings().get(0));
|
||||
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.62780395f);
|
||||
}
|
||||
}
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public static final class RunningModeTest extends ImageEmbedderTest {
|
||||
|
||||
@Test
|
||||
public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception {
|
||||
for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) {
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
() ->
|
||||
ImageEmbedderOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build())
|
||||
.setRunningMode(mode)
|
||||
.setResultListener((result, inputImage) -> {})
|
||||
.build());
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("a user-defined result listener shouldn't be provided");
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception {
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
() ->
|
||||
ImageEmbedderOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build())
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.build());
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("a user-defined result listener must be provided");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void embed_failsWithCallingWrongApiInImageMode() throws Exception {
|
||||
ImageEmbedderOptions options =
|
||||
ImageEmbedderOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build())
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.build();
|
||||
|
||||
ImageEmbedder imageEmbedder =
|
||||
ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() -> imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void embed_failsWithCallingWrongApiInVideoMode() throws Exception {
|
||||
ImageEmbedderOptions options =
|
||||
ImageEmbedderOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build())
|
||||
.setRunningMode(RunningMode.VIDEO)
|
||||
.build();
|
||||
|
||||
ImageEmbedder imageEmbedder =
|
||||
ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class, () -> imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() -> imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void embed_failsWithCallingWrongApiInLiveSteamMode() throws Exception {
|
||||
ImageEmbedderOptions options =
|
||||
ImageEmbedderOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build())
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener((imageClassificationResult, inputImage) -> {})
|
||||
.build();
|
||||
|
||||
ImageEmbedder imageEmbedder =
|
||||
ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class, () -> imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void embed_succeedsWithImageMode() throws Exception {
|
||||
ImageEmbedder imageEmbedder =
|
||||
ImageEmbedder.createFromFile(
|
||||
ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER);
|
||||
ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE));
|
||||
ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
|
||||
|
||||
// Check results.
|
||||
assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false);
|
||||
assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false);
|
||||
// Check similarity.
|
||||
double similarity =
|
||||
ImageEmbedder.cosineSimilarity(
|
||||
result.embeddingResult().embeddings().get(0),
|
||||
resultCrop.embeddingResult().embeddings().get(0));
|
||||
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.925272);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void embed_succeedsWithVideoMode() throws Exception {
|
||||
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build();
|
||||
ImageEmbedderOptions options =
|
||||
ImageEmbedderOptions.builder()
|
||||
.setBaseOptions(baseOptions)
|
||||
.setRunningMode(RunningMode.VIDEO)
|
||||
.build();
|
||||
ImageEmbedder imageEmbedder =
|
||||
ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
ImageEmbedderResult result =
|
||||
imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ i);
|
||||
assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void embed_failsWithOutOfOrderInputTimestamps() throws Exception {
|
||||
MPImage image = getImageFromAsset(BURGER_IMAGE);
|
||||
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build();
|
||||
ImageEmbedderOptions options =
|
||||
ImageEmbedderOptions.builder()
|
||||
.setBaseOptions(baseOptions)
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener(
|
||||
(imageEmbedderResult, inputImage) -> {
|
||||
assertHasOneHeadAndCorrectDimension(imageEmbedderResult, /*quantized=*/ false);
|
||||
assertImageSizeIsExpected(inputImage);
|
||||
})
|
||||
.build();
|
||||
try (ImageEmbedder imageEmbedder =
|
||||
ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||
imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() -> imageEmbedder.embedAsync(image, /*timestampMs=*/ 0));
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("having a smaller timestamp than the processed timestamp");
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void embed_succeedsWithLiveStreamMode() throws Exception {
|
||||
MPImage image = getImageFromAsset(BURGER_IMAGE);
|
||||
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build();
|
||||
ImageEmbedderOptions options =
|
||||
ImageEmbedderOptions.builder()
|
||||
.setBaseOptions(baseOptions)
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener(
|
||||
(imageEmbedderResult, inputImage) -> {
|
||||
assertHasOneHeadAndCorrectDimension(imageEmbedderResult, /*quantized=*/ false);
|
||||
assertImageSizeIsExpected(inputImage);
|
||||
})
|
||||
.build();
|
||||
try (ImageEmbedder imageEmbedder =
|
||||
ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
imageEmbedder.embedAsync(image, /*timestampMs=*/ i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static MPImage getImageFromAsset(String filePath) throws Exception {
|
||||
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
|
||||
InputStream istr = assetManager.open(filePath);
|
||||
return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
|
||||
}
|
||||
|
||||
private static void assertHasOneHeadAndCorrectDimension(
|
||||
ImageEmbedderResult result, boolean quantized) {
|
||||
assertThat(result.embeddingResult().embeddings()).hasSize(1);
|
||||
assertThat(result.embeddingResult().embeddings().get(0).headIndex()).isEqualTo(0);
|
||||
assertThat(result.embeddingResult().embeddings().get(0).headName().get()).isEqualTo("feature");
|
||||
if (quantized) {
|
||||
assertThat(result.embeddingResult().embeddings().get(0).quantizedEmbedding()).hasLength(1024);
|
||||
} else {
|
||||
assertThat(result.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(1024);
|
||||
}
|
||||
}
|
||||
|
||||
private static void assertImageSizeIsExpected(MPImage inputImage) {
|
||||
assertThat(inputImage).isNotNull();
|
||||
assertThat(inputImage.getWidth()).isEqualTo(480);
|
||||
assertThat(inputImage.getHeight()).isEqualTo(325);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user