Internal change
PiperOrigin-RevId: 492188196
This commit is contained in:
parent
29c7702984
commit
01010fa248
|
@ -92,12 +92,12 @@ android_library(
|
|||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||
"//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/audio:libmediapipe_tasks_audio_jni_lib",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||
"//third_party:autovalue",
|
||||
|
|
|
@ -28,7 +28,7 @@ import com.google.mediapipe.tasks.audio.core.RunningMode;
|
|||
import com.google.mediapipe.tasks.components.containers.AudioData;
|
||||
import com.google.mediapipe.tasks.components.containers.Embedding;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
|
||||
import com.google.mediapipe.tasks.components.processors.EmbedderOptions;
|
||||
import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto;
|
||||
import com.google.mediapipe.tasks.components.utils.CosineSimilarity;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||
|
@ -309,10 +309,24 @@ public final class AudioEmbedder extends BaseAudioTaskApi {
|
|||
public abstract Builder setRunningMode(RunningMode runningMode);
|
||||
|
||||
/**
|
||||
* Sets the optional {@link EmbedderOptions} controling embedding behavior, such as score
|
||||
* threshold, number of results, etc.
|
||||
* Sets whether L2 normalization should be performed on the returned embeddings. Use this
|
||||
* option only if the model does not already contain a native <code>L2_NORMALIZATION</code> TF
|
||||
* Lite Op. In most cases, this is already the case and L2 norm is thus achieved through TF
|
||||
* Lite inference.
|
||||
*
|
||||
* <p>False by default.
|
||||
*/
|
||||
public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions);
|
||||
public abstract Builder setL2Normalize(boolean l2Normalize);
|
||||
|
||||
/**
|
||||
* Sets whether the returned embedding should be quantized to bytes via scalar quantization.
|
||||
* Embeddings are implicitly assumed to be unit-norm and therefore any dimensions is
|
||||
* guaranteed to have value in <code>[-1.0, 1.0]</code>. Use {@link #setL2Normalize(boolean)}
|
||||
* if this is not the case.
|
||||
*
|
||||
* <p>False by default.
|
||||
*/
|
||||
public abstract Builder setQuantize(boolean quantize);
|
||||
|
||||
/**
|
||||
* Sets the {@link ResultListener} to receive the embedding results asynchronously when the
|
||||
|
@ -354,7 +368,9 @@ public final class AudioEmbedder extends BaseAudioTaskApi {
|
|||
|
||||
abstract RunningMode runningMode();
|
||||
|
||||
abstract Optional<EmbedderOptions> embedderOptions();
|
||||
abstract boolean l2Normalize();
|
||||
|
||||
abstract boolean quantize();
|
||||
|
||||
abstract Optional<PureResultListener<AudioEmbedderResult>> resultListener();
|
||||
|
||||
|
@ -362,7 +378,9 @@ public final class AudioEmbedder extends BaseAudioTaskApi {
|
|||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_AudioEmbedder_AudioEmbedderOptions.Builder()
|
||||
.setRunningMode(RunningMode.AUDIO_CLIPS);
|
||||
.setRunningMode(RunningMode.AUDIO_CLIPS)
|
||||
.setL2Normalize(false)
|
||||
.setQuantize(false);
|
||||
}
|
||||
|
||||
/** Converts a {@link AudioEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */
|
||||
|
@ -372,12 +390,14 @@ public final class AudioEmbedder extends BaseAudioTaskApi {
|
|||
BaseOptionsProto.BaseOptions.newBuilder();
|
||||
baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM);
|
||||
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
|
||||
EmbedderOptionsProto.EmbedderOptions.Builder embedderOptionsBuilder =
|
||||
EmbedderOptionsProto.EmbedderOptions.newBuilder();
|
||||
embedderOptionsBuilder.setL2Normalize(l2Normalize());
|
||||
embedderOptionsBuilder.setQuantize(quantize());
|
||||
AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.Builder taskOptionsBuilder =
|
||||
AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.newBuilder()
|
||||
.setBaseOptions(baseOptionsBuilder);
|
||||
if (embedderOptions().isPresent()) {
|
||||
taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto());
|
||||
}
|
||||
.setBaseOptions(baseOptionsBuilder)
|
||||
.setEmbedderOptions(embedderOptionsBuilder);
|
||||
return CalculatorOptions.newBuilder()
|
||||
.setExtension(
|
||||
AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.ext,
|
||||
|
|
|
@ -29,19 +29,6 @@ android_library(
|
|||
],
|
||||
)
|
||||
|
||||
android_library(
|
||||
name = "embedderoptions",
|
||||
srcs = ["EmbedderOptions.java"],
|
||||
javacopts = [
|
||||
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite",
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
||||
# Expose the java source files for building mediapipe tasks core AAR.
|
||||
filegroup(
|
||||
name = "java_src",
|
||||
|
|
|
@ -1,68 +0,0 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.components.processors;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto;
|
||||
|
||||
/** Embedder options shared across MediaPipe Java embedding tasks. */
|
||||
@AutoValue
|
||||
public abstract class EmbedderOptions {
|
||||
|
||||
/** Builder for {@link EmbedderOptions} */
|
||||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
/**
|
||||
* Sets whether L2 normalization should be performed on the returned embeddings. Use this option
|
||||
* only if the model does not already contain a native <code>L2_NORMALIZATION</code> TF Lite Op.
|
||||
* In most cases, this is already the case and L2 norm is thus achieved through TF Lite
|
||||
* inference.
|
||||
*
|
||||
* <p>False by default.
|
||||
*/
|
||||
public abstract Builder setL2Normalize(boolean l2Normalize);
|
||||
|
||||
/**
|
||||
* Sets whether the returned embedding should be quantized to bytes via scalar quantization.
|
||||
* Embeddings are implicitly assumed to be unit-norm and therefore any dimensions is guaranteed
|
||||
* to have value in <code>[-1.0, 1.0]</code>. Use {@link #setL2Normalize(boolean)} if this is
|
||||
* not the case.
|
||||
*
|
||||
* <p>False by default.
|
||||
*/
|
||||
public abstract Builder setQuantize(boolean quantize);
|
||||
|
||||
public abstract EmbedderOptions build();
|
||||
}
|
||||
|
||||
public abstract boolean l2Normalize();
|
||||
|
||||
public abstract boolean quantize();
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_EmbedderOptions.Builder().setL2Normalize(false).setQuantize(false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts an {@link EmbedderOptions} object to an {@link EmbedderOptionsProto.EmbedderOptions}
|
||||
* protobuf message.
|
||||
*/
|
||||
public EmbedderOptionsProto.EmbedderOptions convertToProto() {
|
||||
return EmbedderOptionsProto.EmbedderOptions.newBuilder()
|
||||
.setL2Normalize(l2Normalize())
|
||||
.setQuantize(quantize())
|
||||
.build();
|
||||
}
|
||||
}
|
|
@ -74,11 +74,11 @@ android_library(
|
|||
"//mediapipe/framework:calculator_options_java_proto_lite",
|
||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",
|
||||
|
|
|
@ -25,7 +25,7 @@ import com.google.mediapipe.framework.ProtoUtil;
|
|||
import com.google.mediapipe.tasks.components.containers.Embedding;
|
||||
import com.google.mediapipe.tasks.components.containers.EmbeddingResult;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
|
||||
import com.google.mediapipe.tasks.components.processors.EmbedderOptions;
|
||||
import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto;
|
||||
import com.google.mediapipe.tasks.components.utils.CosineSimilarity;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||
|
@ -41,7 +41,6 @@ import java.util.Collections;
|
|||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* Performs embedding extraction on text.
|
||||
|
@ -218,20 +217,38 @@ public final class TextEmbedder implements AutoCloseable {
|
|||
public abstract Builder setBaseOptions(BaseOptions value);
|
||||
|
||||
/**
|
||||
* Sets the optional {@link EmbedderOptions} controling embedder behavior, such as
|
||||
* L2-normalization and scalar quantization.
|
||||
* Sets whether L2 normalization should be performed on the returned embeddings. Use this
|
||||
* option only if the model does not already contain a native <code>L2_NORMALIZATION</code> TF
|
||||
* Lite Op. In most cases, this is already the case and L2 norm is thus achieved through TF
|
||||
* Lite inference.
|
||||
*
|
||||
* <p>False by default.
|
||||
*/
|
||||
public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions);
|
||||
public abstract Builder setL2Normalize(boolean l2Normalize);
|
||||
|
||||
/**
|
||||
* Sets whether the returned embedding should be quantized to bytes via scalar quantization.
|
||||
* Embeddings are implicitly assumed to be unit-norm and therefore any dimensions is
|
||||
* guaranteed to have value in <code>[-1.0, 1.0]</code>. Use {@link #setL2Normalize(boolean)}
|
||||
* if this is not the case.
|
||||
*
|
||||
* <p>False by default.
|
||||
*/
|
||||
public abstract Builder setQuantize(boolean quantize);
|
||||
|
||||
public abstract TextEmbedderOptions build();
|
||||
}
|
||||
|
||||
abstract BaseOptions baseOptions();
|
||||
|
||||
abstract Optional<EmbedderOptions> embedderOptions();
|
||||
abstract boolean l2Normalize();
|
||||
|
||||
abstract boolean quantize();
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_TextEmbedder_TextEmbedderOptions.Builder();
|
||||
return new AutoValue_TextEmbedder_TextEmbedderOptions.Builder()
|
||||
.setL2Normalize(false)
|
||||
.setQuantize(false);
|
||||
}
|
||||
|
||||
/** Converts a {@link TextEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */
|
||||
|
@ -240,12 +257,14 @@ public final class TextEmbedder implements AutoCloseable {
|
|||
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
|
||||
BaseOptionsProto.BaseOptions.newBuilder();
|
||||
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
|
||||
EmbedderOptionsProto.EmbedderOptions.Builder embedderOptionsBuilder =
|
||||
EmbedderOptionsProto.EmbedderOptions.newBuilder();
|
||||
embedderOptionsBuilder.setL2Normalize(l2Normalize());
|
||||
embedderOptionsBuilder.setQuantize(quantize());
|
||||
TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.Builder taskOptionsBuilder =
|
||||
TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.newBuilder()
|
||||
.setBaseOptions(baseOptionsBuilder);
|
||||
if (embedderOptions().isPresent()) {
|
||||
taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto());
|
||||
}
|
||||
.setBaseOptions(baseOptionsBuilder)
|
||||
.setEmbedderOptions(embedderOptionsBuilder);
|
||||
return CalculatorOptions.newBuilder()
|
||||
.setExtension(
|
||||
TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.ext,
|
||||
|
|
|
@ -190,11 +190,11 @@ android_library(
|
|||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||
"//third_party:autovalue",
|
||||
|
|
|
@ -28,7 +28,7 @@ import com.google.mediapipe.framework.image.MPImage;
|
|||
import com.google.mediapipe.tasks.components.containers.Embedding;
|
||||
import com.google.mediapipe.tasks.components.containers.EmbeddingResult;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
|
||||
import com.google.mediapipe.tasks.components.processors.EmbedderOptions;
|
||||
import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto;
|
||||
import com.google.mediapipe.tasks.components.utils.CosineSimilarity;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||
|
@ -369,10 +369,24 @@ public final class ImageEmbedder extends BaseVisionTaskApi {
|
|||
public abstract Builder setRunningMode(RunningMode runningMode);
|
||||
|
||||
/**
|
||||
* Sets the optional {@link EmbedderOptions} controling embedder behavior, such as
|
||||
* L2-normalization and scalar quantization.
|
||||
* Sets whether L2 normalization should be performed on the returned embeddings. Use this
|
||||
* option only if the model does not already contain a native <code>L2_NORMALIZATION</code> TF
|
||||
* Lite Op. In most cases, this is already the case and L2 norm is thus achieved through TF
|
||||
* Lite inference.
|
||||
*
|
||||
* <p>False by default.
|
||||
*/
|
||||
public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions);
|
||||
public abstract Builder setL2Normalize(boolean l2Normalize);
|
||||
|
||||
/**
|
||||
* Sets whether the returned embedding should be quantized to bytes via scalar quantization.
|
||||
* Embeddings are implicitly assumed to be unit-norm and therefore any dimensions is
|
||||
* guaranteed to have value in <code>[-1.0, 1.0]</code>. Use {@link #setL2Normalize(boolean)}
|
||||
* if this is not the case.
|
||||
*
|
||||
* <p>False by default.
|
||||
*/
|
||||
public abstract Builder setQuantize(boolean quantize);
|
||||
|
||||
/**
|
||||
* Sets the {@link ResultListener} to receive the embedding results asynchronously when the
|
||||
|
@ -414,7 +428,9 @@ public final class ImageEmbedder extends BaseVisionTaskApi {
|
|||
|
||||
abstract RunningMode runningMode();
|
||||
|
||||
abstract Optional<EmbedderOptions> embedderOptions();
|
||||
abstract boolean l2Normalize();
|
||||
|
||||
abstract boolean quantize();
|
||||
|
||||
abstract Optional<ResultListener<ImageEmbedderResult, MPImage>> resultListener();
|
||||
|
||||
|
@ -422,7 +438,9 @@ public final class ImageEmbedder extends BaseVisionTaskApi {
|
|||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_ImageEmbedder_ImageEmbedderOptions.Builder()
|
||||
.setRunningMode(RunningMode.IMAGE);
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.setL2Normalize(false)
|
||||
.setQuantize(false);
|
||||
}
|
||||
|
||||
/** Converts a {@link ImageEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */
|
||||
|
@ -432,12 +450,14 @@ public final class ImageEmbedder extends BaseVisionTaskApi {
|
|||
BaseOptionsProto.BaseOptions.newBuilder();
|
||||
baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE);
|
||||
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
|
||||
EmbedderOptionsProto.EmbedderOptions.Builder embedderOptionsBuilder =
|
||||
EmbedderOptionsProto.EmbedderOptions.newBuilder();
|
||||
embedderOptionsBuilder.setL2Normalize(l2Normalize());
|
||||
embedderOptionsBuilder.setQuantize(quantize());
|
||||
ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.Builder taskOptionsBuilder =
|
||||
ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.newBuilder()
|
||||
.setBaseOptions(baseOptionsBuilder);
|
||||
if (embedderOptions().isPresent()) {
|
||||
taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto());
|
||||
}
|
||||
.setBaseOptions(baseOptionsBuilder)
|
||||
.setEmbedderOptions(embedderOptionsBuilder);
|
||||
return CalculatorOptions.newBuilder()
|
||||
.setExtension(
|
||||
ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.ext,
|
||||
|
|
|
@ -25,7 +25,6 @@ import androidx.test.ext.junit.runners.AndroidJUnit4;
|
|||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||
import com.google.mediapipe.framework.image.MPImage;
|
||||
import com.google.mediapipe.tasks.components.processors.EmbedderOptions;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.TestUtils;
|
||||
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
|
||||
|
@ -92,8 +91,8 @@ public class ImageEmbedderTest {
|
|||
ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
|
||||
|
||||
// Check results.
|
||||
assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false);
|
||||
assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false);
|
||||
assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false);
|
||||
assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false);
|
||||
// Check similarity.
|
||||
double similarity =
|
||||
ImageEmbedder.cosineSimilarity(
|
||||
|
@ -105,12 +104,8 @@ public class ImageEmbedderTest {
|
|||
@Test
|
||||
public void embed_succeedsWithL2Normalization() throws Exception {
|
||||
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build();
|
||||
EmbedderOptions embedderOptions = EmbedderOptions.builder().setL2Normalize(true).build();
|
||||
ImageEmbedderOptions options =
|
||||
ImageEmbedderOptions.builder()
|
||||
.setBaseOptions(baseOptions)
|
||||
.setEmbedderOptions(embedderOptions)
|
||||
.build();
|
||||
ImageEmbedderOptions.builder().setBaseOptions(baseOptions).setL2Normalize(true).build();
|
||||
|
||||
ImageEmbedder imageEmbedder =
|
||||
ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
|
@ -118,8 +113,8 @@ public class ImageEmbedderTest {
|
|||
ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
|
||||
|
||||
// Check results.
|
||||
assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false);
|
||||
assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false);
|
||||
assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false);
|
||||
assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false);
|
||||
// Check similarity.
|
||||
double similarity =
|
||||
ImageEmbedder.cosineSimilarity(
|
||||
|
@ -131,12 +126,8 @@ public class ImageEmbedderTest {
|
|||
@Test
|
||||
public void embed_succeedsWithQuantization() throws Exception {
|
||||
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build();
|
||||
EmbedderOptions embedderOptions = EmbedderOptions.builder().setQuantize(true).build();
|
||||
ImageEmbedderOptions options =
|
||||
ImageEmbedderOptions.builder()
|
||||
.setBaseOptions(baseOptions)
|
||||
.setEmbedderOptions(embedderOptions)
|
||||
.build();
|
||||
ImageEmbedderOptions.builder().setBaseOptions(baseOptions).setQuantize(true).build();
|
||||
|
||||
ImageEmbedder imageEmbedder =
|
||||
ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
|
@ -144,8 +135,8 @@ public class ImageEmbedderTest {
|
|||
ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
|
||||
|
||||
// Check results.
|
||||
assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ true);
|
||||
assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ true);
|
||||
assertHasOneHeadAndCorrectDimension(result, /* quantized= */ true);
|
||||
assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ true);
|
||||
// Check similarity.
|
||||
double similarity =
|
||||
ImageEmbedder.cosineSimilarity(
|
||||
|
@ -168,8 +159,8 @@ public class ImageEmbedderTest {
|
|||
ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
|
||||
|
||||
// Check results.
|
||||
assertHasOneHeadAndCorrectDimension(resultRoi, /*quantized=*/ false);
|
||||
assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false);
|
||||
assertHasOneHeadAndCorrectDimension(resultRoi, /* quantized= */ false);
|
||||
assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false);
|
||||
// Check similarity.
|
||||
double similarity =
|
||||
ImageEmbedder.cosineSimilarity(
|
||||
|
@ -190,8 +181,8 @@ public class ImageEmbedderTest {
|
|||
imageEmbedder.embed(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions);
|
||||
|
||||
// Check results.
|
||||
assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false);
|
||||
assertHasOneHeadAndCorrectDimension(resultRotated, /*quantized=*/ false);
|
||||
assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false);
|
||||
assertHasOneHeadAndCorrectDimension(resultRotated, /* quantized= */ false);
|
||||
// Check similarity.
|
||||
double similarity =
|
||||
ImageEmbedder.cosineSimilarity(
|
||||
|
@ -214,8 +205,8 @@ public class ImageEmbedderTest {
|
|||
ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
|
||||
|
||||
// Check results.
|
||||
assertHasOneHeadAndCorrectDimension(resultRoiRotated, /*quantized=*/ false);
|
||||
assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false);
|
||||
assertHasOneHeadAndCorrectDimension(resultRoiRotated, /* quantized= */ false);
|
||||
assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false);
|
||||
// Check similarity.
|
||||
double similarity =
|
||||
ImageEmbedder.cosineSimilarity(
|
||||
|
@ -277,12 +268,14 @@ public class ImageEmbedderTest {
|
|||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
|
||||
imageEmbedder.embedForVideo(
|
||||
getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() -> imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
|
||||
() ->
|
||||
imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||
}
|
||||
|
||||
|
@ -303,7 +296,8 @@ public class ImageEmbedderTest {
|
|||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() -> imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
|
||||
() ->
|
||||
imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||
}
|
||||
|
||||
|
@ -327,7 +321,8 @@ public class ImageEmbedderTest {
|
|||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
|
||||
imageEmbedder.embedForVideo(
|
||||
getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||
}
|
||||
|
||||
|
@ -340,8 +335,8 @@ public class ImageEmbedderTest {
|
|||
ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE));
|
||||
|
||||
// Check results.
|
||||
assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false);
|
||||
assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false);
|
||||
assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false);
|
||||
assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false);
|
||||
// Check similarity.
|
||||
double similarity =
|
||||
ImageEmbedder.cosineSimilarity(
|
||||
|
@ -363,8 +358,8 @@ public class ImageEmbedderTest {
|
|||
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
ImageEmbedderResult result =
|
||||
imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ i);
|
||||
assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false);
|
||||
imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ i);
|
||||
assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -378,17 +373,18 @@ public class ImageEmbedderTest {
|
|||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener(
|
||||
(imageEmbedderResult, inputImage) -> {
|
||||
assertHasOneHeadAndCorrectDimension(imageEmbedderResult, /*quantized=*/ false);
|
||||
assertHasOneHeadAndCorrectDimension(
|
||||
imageEmbedderResult, /* quantized= */ false);
|
||||
assertImageSizeIsExpected(inputImage);
|
||||
})
|
||||
.build();
|
||||
try (ImageEmbedder imageEmbedder =
|
||||
ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||
imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1);
|
||||
imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 1);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() -> imageEmbedder.embedAsync(image, /*timestampMs=*/ 0));
|
||||
() -> imageEmbedder.embedAsync(image, /* timestampMs= */ 0));
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("having a smaller timestamp than the processed timestamp");
|
||||
|
@ -405,14 +401,15 @@ public class ImageEmbedderTest {
|
|||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener(
|
||||
(imageEmbedderResult, inputImage) -> {
|
||||
assertHasOneHeadAndCorrectDimension(imageEmbedderResult, /*quantized=*/ false);
|
||||
assertHasOneHeadAndCorrectDimension(
|
||||
imageEmbedderResult, /* quantized= */ false);
|
||||
assertImageSizeIsExpected(inputImage);
|
||||
})
|
||||
.build();
|
||||
try (ImageEmbedder imageEmbedder =
|
||||
ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
imageEmbedder.embedAsync(image, /*timestampMs=*/ i);
|
||||
imageEmbedder.embedAsync(image, /* timestampMs= */ i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user