Internal change

PiperOrigin-RevId: 492188196
This commit is contained in:
MediaPipe Team 2022-12-01 07:15:52 -08:00 committed by Copybara-Service
parent 29c7702984
commit 01010fa248
9 changed files with 126 additions and 151 deletions

View File

@ -92,12 +92,12 @@ android_library(
"//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/audio:libmediapipe_tasks_audio_jni_lib", "//mediapipe/tasks/java/com/google/mediapipe/tasks/audio:libmediapipe_tasks_audio_jni_lib",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue", "//third_party:autovalue",

View File

@ -28,7 +28,7 @@ import com.google.mediapipe.tasks.audio.core.RunningMode;
import com.google.mediapipe.tasks.components.containers.AudioData; import com.google.mediapipe.tasks.components.containers.AudioData;
import com.google.mediapipe.tasks.components.containers.Embedding; import com.google.mediapipe.tasks.components.containers.Embedding;
import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
import com.google.mediapipe.tasks.components.processors.EmbedderOptions; import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto;
import com.google.mediapipe.tasks.components.utils.CosineSimilarity; import com.google.mediapipe.tasks.components.utils.CosineSimilarity;
import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.ErrorListener;
@ -309,10 +309,24 @@ public final class AudioEmbedder extends BaseAudioTaskApi {
public abstract Builder setRunningMode(RunningMode runningMode); public abstract Builder setRunningMode(RunningMode runningMode);
/** /**
* Sets the optional {@link EmbedderOptions} controling embedding behavior, such as score * Sets whether L2 normalization should be performed on the returned embeddings. Use this
* threshold, number of results, etc. * option only if the model does not already contain a native <code>L2_NORMALIZATION</code> TF
* Lite Op. In most cases, this is already the case and L2 norm is thus achieved through TF
* Lite inference.
*
* <p>False by default.
*/ */
public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions); public abstract Builder setL2Normalize(boolean l2Normalize);
/**
* Sets whether the returned embedding should be quantized to bytes via scalar quantization.
* Embeddings are implicitly assumed to be unit-norm and therefore any dimensions is
* guaranteed to have value in <code>[-1.0, 1.0]</code>. Use {@link #setL2Normalize(boolean)}
* if this is not the case.
*
* <p>False by default.
*/
public abstract Builder setQuantize(boolean quantize);
/** /**
* Sets the {@link ResultListener} to receive the embedding results asynchronously when the * Sets the {@link ResultListener} to receive the embedding results asynchronously when the
@ -354,7 +368,9 @@ public final class AudioEmbedder extends BaseAudioTaskApi {
abstract RunningMode runningMode(); abstract RunningMode runningMode();
abstract Optional<EmbedderOptions> embedderOptions(); abstract boolean l2Normalize();
abstract boolean quantize();
abstract Optional<PureResultListener<AudioEmbedderResult>> resultListener(); abstract Optional<PureResultListener<AudioEmbedderResult>> resultListener();
@ -362,7 +378,9 @@ public final class AudioEmbedder extends BaseAudioTaskApi {
public static Builder builder() { public static Builder builder() {
return new AutoValue_AudioEmbedder_AudioEmbedderOptions.Builder() return new AutoValue_AudioEmbedder_AudioEmbedderOptions.Builder()
.setRunningMode(RunningMode.AUDIO_CLIPS); .setRunningMode(RunningMode.AUDIO_CLIPS)
.setL2Normalize(false)
.setQuantize(false);
} }
/** Converts a {@link AudioEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */ /** Converts a {@link AudioEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */
@ -372,12 +390,14 @@ public final class AudioEmbedder extends BaseAudioTaskApi {
BaseOptionsProto.BaseOptions.newBuilder(); BaseOptionsProto.BaseOptions.newBuilder();
baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM); baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM);
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
EmbedderOptionsProto.EmbedderOptions.Builder embedderOptionsBuilder =
EmbedderOptionsProto.EmbedderOptions.newBuilder();
embedderOptionsBuilder.setL2Normalize(l2Normalize());
embedderOptionsBuilder.setQuantize(quantize());
AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.Builder taskOptionsBuilder = AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.Builder taskOptionsBuilder =
AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.newBuilder() AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.newBuilder()
.setBaseOptions(baseOptionsBuilder); .setBaseOptions(baseOptionsBuilder)
if (embedderOptions().isPresent()) { .setEmbedderOptions(embedderOptionsBuilder);
taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto());
}
return CalculatorOptions.newBuilder() return CalculatorOptions.newBuilder()
.setExtension( .setExtension(
AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.ext, AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.ext,

View File

@ -29,19 +29,6 @@ android_library(
], ],
) )
# Android library target built from EmbedderOptions.java.
android_library(
name = "embedderoptions",
srcs = ["EmbedderOptions.java"],
# NOTE(review): presumably switches off the AndroidJdkLibsChecker Error Prone check
# for this target only — confirm the reason it is needed here.
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
# Dependencies: the embedder options Java proto-lite target, AutoValue, and Guava.
deps = [
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)
# Expose the java source files for building mediapipe tasks core AAR. # Expose the java source files for building mediapipe tasks core AAR.
filegroup( filegroup(
name = "java_src", name = "java_src",

View File

@ -1,68 +0,0 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.processors;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto;
/** Embedder options shared across MediaPipe Java embedding tasks. */
@AutoValue
public abstract class EmbedderOptions {
/** Builder for {@link EmbedderOptions} */
@AutoValue.Builder
public abstract static class Builder {
/**
* Sets whether L2 normalization should be performed on the returned embeddings. Use this option
* only if the model does not already contain a native <code>L2_NORMALIZATION</code> TF Lite Op.
* In most cases, this is already the case and L2 norm is thus achieved through TF Lite
* inference.
*
* <p>False by default.
*/
public abstract Builder setL2Normalize(boolean l2Normalize);
/**
* Sets whether the returned embedding should be quantized to bytes via scalar quantization.
* Embeddings are implicitly assumed to be unit-norm and therefore any dimensions is guaranteed
* to have value in <code>[-1.0, 1.0]</code>. Use {@link #setL2Normalize(boolean)} if this is
* not the case.
*
* <p>False by default.
*/
public abstract Builder setQuantize(boolean quantize);
public abstract EmbedderOptions build();
}
public abstract boolean l2Normalize();
public abstract boolean quantize();
public static Builder builder() {
return new AutoValue_EmbedderOptions.Builder().setL2Normalize(false).setQuantize(false);
}
/**
* Converts an {@link EmbedderOptions} object to an {@link EmbedderOptionsProto.EmbedderOptions}
* protobuf message.
*/
public EmbedderOptionsProto.EmbedderOptions convertToProto() {
return EmbedderOptionsProto.EmbedderOptions.newBuilder()
.setL2Normalize(l2Normalize())
.setQuantize(quantize())
.build();
}
}

View File

@ -74,11 +74,11 @@ android_library(
"//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib", "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",

View File

@ -25,7 +25,7 @@ import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.tasks.components.containers.Embedding; import com.google.mediapipe.tasks.components.containers.Embedding;
import com.google.mediapipe.tasks.components.containers.EmbeddingResult; import com.google.mediapipe.tasks.components.containers.EmbeddingResult;
import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
import com.google.mediapipe.tasks.components.processors.EmbedderOptions; import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto;
import com.google.mediapipe.tasks.components.utils.CosineSimilarity; import com.google.mediapipe.tasks.components.utils.CosineSimilarity;
import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.OutputHandler; import com.google.mediapipe.tasks.core.OutputHandler;
@ -41,7 +41,6 @@ import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional;
/** /**
* Performs embedding extraction on text. * Performs embedding extraction on text.
@ -218,20 +217,38 @@ public final class TextEmbedder implements AutoCloseable {
public abstract Builder setBaseOptions(BaseOptions value); public abstract Builder setBaseOptions(BaseOptions value);
/** /**
* Sets the optional {@link EmbedderOptions} controling embedder behavior, such as * Sets whether L2 normalization should be performed on the returned embeddings. Use this
* L2-normalization and scalar quantization. * option only if the model does not already contain a native <code>L2_NORMALIZATION</code> TF
* Lite Op. In most cases, this is already the case and L2 norm is thus achieved through TF
* Lite inference.
*
* <p>False by default.
*/ */
public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions); public abstract Builder setL2Normalize(boolean l2Normalize);
/**
* Sets whether the returned embedding should be quantized to bytes via scalar quantization.
* Embeddings are implicitly assumed to be unit-norm and therefore any dimensions is
* guaranteed to have value in <code>[-1.0, 1.0]</code>. Use {@link #setL2Normalize(boolean)}
* if this is not the case.
*
* <p>False by default.
*/
public abstract Builder setQuantize(boolean quantize);
public abstract TextEmbedderOptions build(); public abstract TextEmbedderOptions build();
} }
abstract BaseOptions baseOptions(); abstract BaseOptions baseOptions();
abstract Optional<EmbedderOptions> embedderOptions(); abstract boolean l2Normalize();
abstract boolean quantize();
public static Builder builder() { public static Builder builder() {
return new AutoValue_TextEmbedder_TextEmbedderOptions.Builder(); return new AutoValue_TextEmbedder_TextEmbedderOptions.Builder()
.setL2Normalize(false)
.setQuantize(false);
} }
/** Converts a {@link TextEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */ /** Converts a {@link TextEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */
@ -240,12 +257,14 @@ public final class TextEmbedder implements AutoCloseable {
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
BaseOptionsProto.BaseOptions.newBuilder(); BaseOptionsProto.BaseOptions.newBuilder();
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
EmbedderOptionsProto.EmbedderOptions.Builder embedderOptionsBuilder =
EmbedderOptionsProto.EmbedderOptions.newBuilder();
embedderOptionsBuilder.setL2Normalize(l2Normalize());
embedderOptionsBuilder.setQuantize(quantize());
TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.Builder taskOptionsBuilder = TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.Builder taskOptionsBuilder =
TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.newBuilder() TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.newBuilder()
.setBaseOptions(baseOptionsBuilder); .setBaseOptions(baseOptionsBuilder)
if (embedderOptions().isPresent()) { .setEmbedderOptions(embedderOptionsBuilder);
taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto());
}
return CalculatorOptions.newBuilder() return CalculatorOptions.newBuilder()
.setExtension( .setExtension(
TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.ext, TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.ext,

View File

@ -190,11 +190,11 @@ android_library(
"//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue", "//third_party:autovalue",

View File

@ -28,7 +28,7 @@ import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.components.containers.Embedding; import com.google.mediapipe.tasks.components.containers.Embedding;
import com.google.mediapipe.tasks.components.containers.EmbeddingResult; import com.google.mediapipe.tasks.components.containers.EmbeddingResult;
import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
import com.google.mediapipe.tasks.components.processors.EmbedderOptions; import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto;
import com.google.mediapipe.tasks.components.utils.CosineSimilarity; import com.google.mediapipe.tasks.components.utils.CosineSimilarity;
import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.ErrorListener;
@ -369,10 +369,24 @@ public final class ImageEmbedder extends BaseVisionTaskApi {
public abstract Builder setRunningMode(RunningMode runningMode); public abstract Builder setRunningMode(RunningMode runningMode);
/** /**
* Sets the optional {@link EmbedderOptions} controling embedder behavior, such as * Sets whether L2 normalization should be performed on the returned embeddings. Use this
* L2-normalization and scalar quantization. * option only if the model does not already contain a native <code>L2_NORMALIZATION</code> TF
* Lite Op. In most cases, this is already the case and L2 norm is thus achieved through TF
* Lite inference.
*
* <p>False by default.
*/ */
public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions); public abstract Builder setL2Normalize(boolean l2Normalize);
/**
* Sets whether the returned embedding should be quantized to bytes via scalar quantization.
* Embeddings are implicitly assumed to be unit-norm and therefore any dimensions is
* guaranteed to have value in <code>[-1.0, 1.0]</code>. Use {@link #setL2Normalize(boolean)}
* if this is not the case.
*
* <p>False by default.
*/
public abstract Builder setQuantize(boolean quantize);
/** /**
* Sets the {@link ResultListener} to receive the embedding results asynchronously when the * Sets the {@link ResultListener} to receive the embedding results asynchronously when the
@ -414,7 +428,9 @@ public final class ImageEmbedder extends BaseVisionTaskApi {
abstract RunningMode runningMode(); abstract RunningMode runningMode();
abstract Optional<EmbedderOptions> embedderOptions(); abstract boolean l2Normalize();
abstract boolean quantize();
abstract Optional<ResultListener<ImageEmbedderResult, MPImage>> resultListener(); abstract Optional<ResultListener<ImageEmbedderResult, MPImage>> resultListener();
@ -422,7 +438,9 @@ public final class ImageEmbedder extends BaseVisionTaskApi {
public static Builder builder() { public static Builder builder() {
return new AutoValue_ImageEmbedder_ImageEmbedderOptions.Builder() return new AutoValue_ImageEmbedder_ImageEmbedderOptions.Builder()
.setRunningMode(RunningMode.IMAGE); .setRunningMode(RunningMode.IMAGE)
.setL2Normalize(false)
.setQuantize(false);
} }
/** Converts a {@link ImageEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */ /** Converts a {@link ImageEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */
@ -432,12 +450,14 @@ public final class ImageEmbedder extends BaseVisionTaskApi {
BaseOptionsProto.BaseOptions.newBuilder(); BaseOptionsProto.BaseOptions.newBuilder();
baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE); baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE);
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
EmbedderOptionsProto.EmbedderOptions.Builder embedderOptionsBuilder =
EmbedderOptionsProto.EmbedderOptions.newBuilder();
embedderOptionsBuilder.setL2Normalize(l2Normalize());
embedderOptionsBuilder.setQuantize(quantize());
ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.Builder taskOptionsBuilder = ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.Builder taskOptionsBuilder =
ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.newBuilder() ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.newBuilder()
.setBaseOptions(baseOptionsBuilder); .setBaseOptions(baseOptionsBuilder)
if (embedderOptions().isPresent()) { .setEmbedderOptions(embedderOptionsBuilder);
taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto());
}
return CalculatorOptions.newBuilder() return CalculatorOptions.newBuilder()
.setExtension( .setExtension(
ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.ext, ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.ext,

View File

@ -25,7 +25,6 @@ import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.components.processors.EmbedderOptions;
import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.TestUtils; import com.google.mediapipe.tasks.core.TestUtils;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
@ -92,8 +91,8 @@ public class ImageEmbedderTest {
ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
// Check results. // Check results.
assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false);
assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false);
// Check similarity. // Check similarity.
double similarity = double similarity =
ImageEmbedder.cosineSimilarity( ImageEmbedder.cosineSimilarity(
@ -105,12 +104,8 @@ public class ImageEmbedderTest {
@Test @Test
public void embed_succeedsWithL2Normalization() throws Exception { public void embed_succeedsWithL2Normalization() throws Exception {
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build();
EmbedderOptions embedderOptions = EmbedderOptions.builder().setL2Normalize(true).build();
ImageEmbedderOptions options = ImageEmbedderOptions options =
ImageEmbedderOptions.builder() ImageEmbedderOptions.builder().setBaseOptions(baseOptions).setL2Normalize(true).build();
.setBaseOptions(baseOptions)
.setEmbedderOptions(embedderOptions)
.build();
ImageEmbedder imageEmbedder = ImageEmbedder imageEmbedder =
ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -118,8 +113,8 @@ public class ImageEmbedderTest {
ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
// Check results. // Check results.
assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false);
assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false);
// Check similarity. // Check similarity.
double similarity = double similarity =
ImageEmbedder.cosineSimilarity( ImageEmbedder.cosineSimilarity(
@ -131,12 +126,8 @@ public class ImageEmbedderTest {
@Test @Test
public void embed_succeedsWithQuantization() throws Exception { public void embed_succeedsWithQuantization() throws Exception {
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build();
EmbedderOptions embedderOptions = EmbedderOptions.builder().setQuantize(true).build();
ImageEmbedderOptions options = ImageEmbedderOptions options =
ImageEmbedderOptions.builder() ImageEmbedderOptions.builder().setBaseOptions(baseOptions).setQuantize(true).build();
.setBaseOptions(baseOptions)
.setEmbedderOptions(embedderOptions)
.build();
ImageEmbedder imageEmbedder = ImageEmbedder imageEmbedder =
ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -144,8 +135,8 @@ public class ImageEmbedderTest {
ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
// Check results. // Check results.
assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ true); assertHasOneHeadAndCorrectDimension(result, /* quantized= */ true);
assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ true); assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ true);
// Check similarity. // Check similarity.
double similarity = double similarity =
ImageEmbedder.cosineSimilarity( ImageEmbedder.cosineSimilarity(
@ -168,8 +159,8 @@ public class ImageEmbedderTest {
ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
// Check results. // Check results.
assertHasOneHeadAndCorrectDimension(resultRoi, /*quantized=*/ false); assertHasOneHeadAndCorrectDimension(resultRoi, /* quantized= */ false);
assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false);
// Check similarity. // Check similarity.
double similarity = double similarity =
ImageEmbedder.cosineSimilarity( ImageEmbedder.cosineSimilarity(
@ -190,8 +181,8 @@ public class ImageEmbedderTest {
imageEmbedder.embed(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions); imageEmbedder.embed(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions);
// Check results. // Check results.
assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false);
assertHasOneHeadAndCorrectDimension(resultRotated, /*quantized=*/ false); assertHasOneHeadAndCorrectDimension(resultRotated, /* quantized= */ false);
// Check similarity. // Check similarity.
double similarity = double similarity =
ImageEmbedder.cosineSimilarity( ImageEmbedder.cosineSimilarity(
@ -214,8 +205,8 @@ public class ImageEmbedderTest {
ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
// Check results. // Check results.
assertHasOneHeadAndCorrectDimension(resultRoiRotated, /*quantized=*/ false); assertHasOneHeadAndCorrectDimension(resultRoiRotated, /* quantized= */ false);
assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false);
// Check similarity. // Check similarity.
double similarity = double similarity =
ImageEmbedder.cosineSimilarity( ImageEmbedder.cosineSimilarity(
@ -277,12 +268,14 @@ public class ImageEmbedderTest {
assertThrows( assertThrows(
MediaPipeException.class, MediaPipeException.class,
() -> () ->
imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); imageEmbedder.embedForVideo(
getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
exception = exception =
assertThrows( assertThrows(
MediaPipeException.class, MediaPipeException.class,
() -> imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); () ->
imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
} }
@ -303,7 +296,8 @@ public class ImageEmbedderTest {
exception = exception =
assertThrows( assertThrows(
MediaPipeException.class, MediaPipeException.class,
() -> imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); () ->
imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
} }
@ -327,7 +321,8 @@ public class ImageEmbedderTest {
assertThrows( assertThrows(
MediaPipeException.class, MediaPipeException.class,
() -> () ->
imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); imageEmbedder.embedForVideo(
getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
} }
@ -340,8 +335,8 @@ public class ImageEmbedderTest {
ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
// Check results. // Check results.
assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false);
assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false);
// Check similarity. // Check similarity.
double similarity = double similarity =
ImageEmbedder.cosineSimilarity( ImageEmbedder.cosineSimilarity(
@ -363,8 +358,8 @@ public class ImageEmbedderTest {
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
ImageEmbedderResult result = ImageEmbedderResult result =
imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ i); imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ i);
assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false);
} }
} }
@ -378,17 +373,18 @@ public class ImageEmbedderTest {
.setRunningMode(RunningMode.LIVE_STREAM) .setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener( .setResultListener(
(imageEmbedderResult, inputImage) -> { (imageEmbedderResult, inputImage) -> {
assertHasOneHeadAndCorrectDimension(imageEmbedderResult, /*quantized=*/ false); assertHasOneHeadAndCorrectDimension(
imageEmbedderResult, /* quantized= */ false);
assertImageSizeIsExpected(inputImage); assertImageSizeIsExpected(inputImage);
}) })
.build(); .build();
try (ImageEmbedder imageEmbedder = try (ImageEmbedder imageEmbedder =
ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1); imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 1);
MediaPipeException exception = MediaPipeException exception =
assertThrows( assertThrows(
MediaPipeException.class, MediaPipeException.class,
() -> imageEmbedder.embedAsync(image, /*timestampMs=*/ 0)); () -> imageEmbedder.embedAsync(image, /* timestampMs= */ 0));
assertThat(exception) assertThat(exception)
.hasMessageThat() .hasMessageThat()
.contains("having a smaller timestamp than the processed timestamp"); .contains("having a smaller timestamp than the processed timestamp");
@ -405,14 +401,15 @@ public class ImageEmbedderTest {
.setRunningMode(RunningMode.LIVE_STREAM) .setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener( .setResultListener(
(imageEmbedderResult, inputImage) -> { (imageEmbedderResult, inputImage) -> {
assertHasOneHeadAndCorrectDimension(imageEmbedderResult, /*quantized=*/ false); assertHasOneHeadAndCorrectDimension(
imageEmbedderResult, /* quantized= */ false);
assertImageSizeIsExpected(inputImage); assertImageSizeIsExpected(inputImage);
}) })
.build(); .build();
try (ImageEmbedder imageEmbedder = try (ImageEmbedder imageEmbedder =
ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
imageEmbedder.embedAsync(image, /*timestampMs=*/ i); imageEmbedder.embedAsync(image, /* timestampMs= */ i);
} }
} }
} }