diff --git a/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto b/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto
index 6b8d41a57..e7e3a63c7 100644
--- a/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto
+++ b/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto
@@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
+option java_package = "com.google.mediapipe.tasks.text.textembedder.proto";
+option java_outer_classname = "TextEmbedderGraphOptionsProto";
+
message TextEmbedderGraphOptions {
extend mediapipe.CalculatorOptions {
optional TextEmbedderGraphOptions ext = 477589892;
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java
index 9bf600360..0fc48742e 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java
@@ -58,8 +58,8 @@ public abstract class TaskOptions {
AccelerationProto.Acceleration.newBuilder();
switch (options.delegate()) {
case CPU:
- accelerationBuilder.setXnnpack(
- InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Xnnpack
+ accelerationBuilder.setTflite(
+ InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.TfLite
.getDefaultInstance());
break;
case GPU:
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
index f0c9f81c6..ab7ad6616 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
@@ -49,6 +49,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
_TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
+ "//mediapipe/tasks/cc/text/text_classifier/proto:text_embedder_graph_options_java_proto_lite",
]
def mediapipe_tasks_core_aar(name, srcs, manifest):
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD
index b49169529..0e72878ab 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD
@@ -24,6 +24,7 @@ cc_binary(
deps = [
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
"//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
+ "//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
],
)
@@ -60,6 +61,33 @@ android_library(
],
)
+android_library(
+ name = "textembedder",
+ srcs = [
+ "textembedder/TextEmbedder.java",
+ "textembedder/TextEmbedderResult.java",
+ ],
+ javacopts = [
+ "-Xep:AndroidJdkLibsChecker:OFF",
+ ],
+ manifest = "textembedder/AndroidManifest.xml",
+ deps = [
+ "//mediapipe/framework:calculator_options_java_proto_lite",
+ "//mediapipe/java/com/google/mediapipe/framework:android_framework",
+ "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
+ "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
+ "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite",
+ "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding",
+ "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult",
+ "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions",
+ "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity",
+ "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
+ "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",
+ "//third_party:autovalue",
+ "@maven//:com_google_guava_guava",
+ ],
+)
+
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_text_aar")
mediapipe_tasks_text_aar(
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml
new file mode 100644
index 000000000..d9c885d16
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java
new file mode 100644
index 000000000..95fa1f087
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java
@@ -0,0 +1,256 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.text.textembedder;
+
+import android.content.Context;
+import android.os.ParcelFileDescriptor;
+import com.google.auto.value.AutoValue;
+import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
+import com.google.mediapipe.framework.MediaPipeException;
+import com.google.mediapipe.framework.Packet;
+import com.google.mediapipe.framework.PacketGetter;
+import com.google.mediapipe.framework.ProtoUtil;
+import com.google.mediapipe.tasks.components.containers.Embedding;
+import com.google.mediapipe.tasks.components.containers.EmbeddingResult;
+import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
+import com.google.mediapipe.tasks.components.processors.EmbedderOptions;
+import com.google.mediapipe.tasks.components.utils.CosineSimilarity;
+import com.google.mediapipe.tasks.core.BaseOptions;
+import com.google.mediapipe.tasks.core.OutputHandler;
+import com.google.mediapipe.tasks.core.TaskInfo;
+import com.google.mediapipe.tasks.core.TaskOptions;
+import com.google.mediapipe.tasks.core.TaskRunner;
+import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
+import com.google.mediapipe.tasks.text.textembedder.proto.TextEmbedderGraphOptionsProto;
+import java.io.File;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Performs embedding extraction on text.
+ *
+ *
This API expects a TFLite model with (optional) TFLite Model Metadata.
+ *
+ *
Metadata is required for models with int32 input tensors because it contains the input process
+ * unit for the model's Tokenizer. No metadata is required for models with string input tensors.
+ *
+ *
+ * - Input tensors
+ *
+ * - Three input tensors ({@code kTfLiteInt32}) of shape {@code [batch_size x
+ * bert_max_seq_len]} representing the input ids, mask ids, and segment ids. This input
+ * signature requires a Bert Tokenizer process unit in the model metadata.
+ *
- Or one input tensor ({@code kTfLiteInt32}) of shape {@code [batch_size x
+ * max_seq_len]} representing the input ids. This input signature requires a Regex
+ * Tokenizer process unit in the model metadata.
+ *
- Or one input tensor ({@code kTfLiteString}) that is shapeless or has shape {@code
+ * [1]} containing the input string.
+ *
+ * - At least one output tensor ({@code kTfLiteFloat32}/{@code kTfLiteUint8}) with shape {@code
+ * [1 x N]} where N is the number of dimensions in the produced embeddings.
+ *
+ */
+public final class TextEmbedder implements AutoCloseable {
+ private static final String TAG = TextEmbedder.class.getSimpleName();
+ private static final String TEXT_IN_STREAM_NAME = "text_in";
+
+ @SuppressWarnings("ConstantCaseForConstants")
+ private static final List INPUT_STREAMS =
+ Collections.unmodifiableList(Arrays.asList("TEXT:" + TEXT_IN_STREAM_NAME));
+
+ @SuppressWarnings("ConstantCaseForConstants")
+ private static final List OUTPUT_STREAMS =
+ Collections.unmodifiableList(Arrays.asList("EMBEDDINGS:embeddings_out"));
+
+ private static final int EMBEDDINGS_OUT_STREAM_INDEX = 0;
+ private static final String TASK_GRAPH_NAME =
+ "mediapipe.tasks.text.text_embedder.TextEmbedderGraph";
+ private final TaskRunner runner;
+
+ static {
+ System.loadLibrary("mediapipe_tasks_text_jni");
+ ProtoUtil.registerTypeName(
+ EmbeddingsProto.EmbeddingResult.class,
+ "mediapipe.tasks.components.containers.proto.EmbeddingResult");
+ }
+
+ /**
+ * Creates a {@link TextEmbedder} instance from a model file and the default {@link
+ * TextEmbedderOptions}.
+ *
+ * @param context an Android {@link Context}.
+ * @param modelPath path to the text model with metadata in the assets.
+ * @throws MediaPipeException if there is is an error during {@link TextEmbedder} creation.
+ */
+ public static TextEmbedder createFromFile(Context context, String modelPath) {
+ BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
+ return createFromOptions(
+ context, TextEmbedderOptions.builder().setBaseOptions(baseOptions).build());
+ }
+
+ /**
+ * Creates a {@link TextEmbedder} instance from a model file and the default {@link
+ * TextEmbedderOptions}.
+ *
+ * @param context an Android {@link Context}.
+ * @param modelFile the text model {@link File} instance.
+ * @throws IOException if an I/O error occurs when opening the tflite model file.
+ * @throws MediaPipeException if there is an error during {@link TextEmbedder} creation.
+ */
+ public static TextEmbedder createFromFile(Context context, File modelFile) throws IOException {
+ try (ParcelFileDescriptor descriptor =
+ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+ BaseOptions baseOptions =
+ BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
+ return createFromOptions(
+ context, TextEmbedderOptions.builder().setBaseOptions(baseOptions).build());
+ }
+ }
+
+ /**
+ * Creates a {@link TextEmbedder} instance from {@link TextEmbedderOptions}.
+ *
+ * @param context an Android {@link Context}.
+ * @param options a {@link TextEmbedderOptions} instance.
+ * @throws MediaPipeException if there is an error during {@link TextEmbedder} creation.
+ */
+ public static TextEmbedder createFromOptions(Context context, TextEmbedderOptions options) {
+ OutputHandler handler = new OutputHandler<>();
+ handler.setOutputPacketConverter(
+ new OutputHandler.OutputPacketConverter() {
+ @Override
+ public TextEmbedderResult convertToTaskResult(List packets) {
+ try {
+ return TextEmbedderResult.create(
+ EmbeddingResult.createFromProto(
+ PacketGetter.getProto(
+ packets.get(EMBEDDINGS_OUT_STREAM_INDEX),
+ EmbeddingsProto.EmbeddingResult.getDefaultInstance())),
+ packets.get(EMBEDDINGS_OUT_STREAM_INDEX).getTimestamp());
+ } catch (IOException e) {
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
+ }
+ }
+
+ @Override
+ public Void convertToTaskInput(List packets) {
+ return null;
+ }
+ });
+ TaskRunner runner =
+ TaskRunner.create(
+ context,
+ TaskInfo.builder()
+ .setTaskGraphName(TASK_GRAPH_NAME)
+ .setInputStreams(INPUT_STREAMS)
+ .setOutputStreams(OUTPUT_STREAMS)
+ .setTaskOptions(options)
+ .setEnableFlowLimiting(false)
+ .build(),
+ handler);
+ return new TextEmbedder(runner);
+ }
+
+ /**
+ * Constructor to initialize a {@link TextEmbedder} from a {@link TaskRunner}.
+ *
+ * @param runner a {@link TaskRunner}.
+ */
+ private TextEmbedder(TaskRunner runner) {
+ this.runner = runner;
+ }
+
+ /**
+ * Performs embedding extraction on the input text.
+ *
+ * @param inputText a {@link String} for processing.
+ */
+ public TextEmbedderResult embed(String inputText) {
+ Map inputPackets = new HashMap<>();
+ inputPackets.put(TEXT_IN_STREAM_NAME, runner.getPacketCreator().createString(inputText));
+ return (TextEmbedderResult) runner.process(inputPackets);
+ }
+
+ /** Closes and cleans up the {@link TextEmbedder}. */
+ @Override
+ public void close() {
+ runner.close();
+ }
+
+ /**
+ * Utility function to compute cosine
+ * similarity between two {@link Embedding} objects.
+ *
+ * @throws IllegalArgumentException if the embeddings are of different types (float vs.
+ * quantized), have different sizes, or have an L2-norm of 0.
+ */
+ public static double cosineSimilarity(Embedding u, Embedding v) {
+ return CosineSimilarity.compute(u, v);
+ }
+
+ /** Options for setting up a {@link TextEmbedder}. */
+ @AutoValue
+ public abstract static class TextEmbedderOptions extends TaskOptions {
+
+ /** Builder for {@link TextEmbedderOptions}. */
+ @AutoValue.Builder
+ public abstract static class Builder {
+ /** Sets the base options for the text embedder task. */
+ public abstract Builder setBaseOptions(BaseOptions value);
+
+ /**
+ * Sets the optional {@link EmbedderOptions} controling embedder behavior, such as
+ * L2-normalization and scalar quantization.
+ */
+ public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions);
+
+ public abstract TextEmbedderOptions build();
+ }
+
+ abstract BaseOptions baseOptions();
+
+ abstract Optional embedderOptions();
+
+ public static Builder builder() {
+ return new AutoValue_TextEmbedder_TextEmbedderOptions.Builder();
+ }
+
+ /** Converts a {@link TextEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */
+ @Override
+ public CalculatorOptions convertToCalculatorOptionsProto() {
+ BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
+ BaseOptionsProto.BaseOptions.newBuilder();
+ baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
+ TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.Builder taskOptionsBuilder =
+ TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.newBuilder()
+ .setBaseOptions(baseOptionsBuilder);
+ if (embedderOptions().isPresent()) {
+ taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto());
+ }
+ return CalculatorOptions.newBuilder()
+ .setExtension(
+ TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.ext,
+ taskOptionsBuilder.build())
+ .build();
+ }
+ }
+}
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedderResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedderResult.java
new file mode 100644
index 000000000..9d8e108ec
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedderResult.java
@@ -0,0 +1,54 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.text.textembedder;
+
+import com.google.auto.value.AutoValue;
+import com.google.mediapipe.tasks.components.containers.EmbeddingResult;
+import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
+import com.google.mediapipe.tasks.core.TaskResult;
+
+/** Represents the embedding results generated by {@link TextEmbedder}. */
+@AutoValue
+public abstract class TextEmbedderResult implements TaskResult {
+
+ /**
+ * Creates an {@link TextEmbedderResult} instance.
+ *
+ * @param embeddingResult the {@link EmbeddingResult} object containing one embedding per embedder
+ * head.
+ * @param timestampMs a timestamp for this result.
+ */
+ static TextEmbedderResult create(EmbeddingResult embeddingResult, long timestampMs) {
+ return new AutoValue_TextEmbedderResult(embeddingResult, timestampMs);
+ }
+
+ /**
+ * Creates an {@link TextEmbedderResult} instance from a {@link EmbeddingsProto.EmbeddingResult}
+ * protobuf message.
+ *
+ * @param proto the {@link EmbeddingsProto.EmbeddingResult} protobuf message to convert.
+ * @param timestampMs a timestamp for this result.
+ */
+ static TextEmbedderResult createFromProto(
+ EmbeddingsProto.EmbeddingResult proto, long timestampMs) {
+ return create(EmbeddingResult.createFromProto(proto), timestampMs);
+ }
+
+ /** Contains one embedding per embedder head. */
+ public abstract EmbeddingResult embeddingResult();
+
+ @Override
+ public abstract long timestampMs();
+}
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java
index d3f0e90f3..5e03d2a4c 100644
--- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java
+++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java
@@ -67,9 +67,7 @@ public class TextClassifierTest {
ApplicationProvider.getApplicationContext(), options));
// TODO: Make MediaPipe InferenceCalculator report the detailed.
// interpreter errors (e.g., "Encountered unresolved custom op").
- assertThat(exception)
- .hasMessageThat()
- .contains("interpreter_builder(&interpreter) == kTfLiteOk");
+ assertThat(exception).hasMessageThat().contains("== kTfLiteOk");
}
@Test
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml
new file mode 100644
index 000000000..5d55d7cfe
--- /dev/null
+++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml
@@ -0,0 +1,24 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/BUILD
new file mode 100644
index 000000000..a7f804c64
--- /dev/null
+++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/BUILD
@@ -0,0 +1,19 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+# TODO: Enable this in OSS
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java
new file mode 100644
index 000000000..b6d53c94d
--- /dev/null
+++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java
@@ -0,0 +1,98 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.text.textembedder;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertThrows;
+
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import com.google.mediapipe.framework.MediaPipeException;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+/** Test for {@link TextEmbedder}/ */
+@RunWith(AndroidJUnit4.class)
+public class TextEmbedderTest {
+ private static final String BERT_MODEL_FILE = "mobilebert_embedding_with_metadata.tflite";
+ private static final String REGEX_MODEL_FILE = "regex_one_embedding_with_metadata.tflite";
+
+ private static final double DOUBLE_DIFF_TOLERANCE = 1e-4;
+ private static final float FLOAT_DIFF_TOLERANCE = 1e-4f;
+
+ @Test
+ public void create_failsWithMissingModel() throws Exception {
+ String nonExistentFile = "/path/to/non/existent/file";
+ MediaPipeException exception =
+ assertThrows(
+ MediaPipeException.class,
+ () ->
+ TextEmbedder.createFromFile(
+ ApplicationProvider.getApplicationContext(), nonExistentFile));
+ assertThat(exception).hasMessageThat().contains(nonExistentFile);
+ }
+
+ @Test
+ public void embed_succeedsWithBert() throws Exception {
+ TextEmbedder textEmbedder =
+ TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE);
+
+ TextEmbedderResult result0 = textEmbedder.embed("it's a charming and often affecting journey");
+ assertThat(result0.embeddingResult().embeddings().size()).isEqualTo(1);
+ assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(512);
+ assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()[0])
+ .isWithin(FLOAT_DIFF_TOLERANCE)
+ .of(20.59746f);
+ TextEmbedderResult result1 = textEmbedder.embed("what a great and fantastic trip");
+ assertThat(result1.embeddingResult().embeddings().size()).isEqualTo(1);
+ assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(512);
+ assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()[0])
+ .isWithin(FLOAT_DIFF_TOLERANCE)
+ .of(21.774776f);
+
+ // Check cosine similarity.
+ double similarity =
+ TextEmbedder.cosineSimilarity(
+ result0.embeddingResult().embeddings().get(0),
+ result1.embeddingResult().embeddings().get(0));
+ assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.968879);
+ }
+
+ @Test
+ public void embed_succeedsWithRegex() throws Exception {
+ TextEmbedder textEmbedder =
+ TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), REGEX_MODEL_FILE);
+
+ TextEmbedderResult result0 = textEmbedder.embed("it's a charming and often affecting journey");
+ assertThat(result0.embeddingResult().embeddings().size()).isEqualTo(1);
+ assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(16);
+ assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()[0])
+ .isWithin(FLOAT_DIFF_TOLERANCE)
+ .of(0.030935612f);
+ TextEmbedderResult result1 = textEmbedder.embed("what a great and fantastic trip");
+ assertThat(result1.embeddingResult().embeddings().size()).isEqualTo(1);
+ assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(16);
+ assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()[0])
+ .isWithin(FLOAT_DIFF_TOLERANCE)
+ .of(0.0312863f);
+
+ // Check cosine similarity.
+ double similarity =
+ TextEmbedder.cosineSimilarity(
+ result0.embeddingResult().embeddings().get(0),
+ result1.embeddingResult().embeddings().get(0));
+ assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.999937);
+ }
+}