From 34daba4747bcb598c576381a5be8e896e7761fc8 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 14 Nov 2022 11:46:37 -0800 Subject: [PATCH] Add Java TextEmbedder API. PiperOrigin-RevId: 488427327 --- .../proto/text_embedder_graph_options.proto | 3 + .../mediapipe/tasks/core/TaskOptions.java | 4 +- .../mediapipe/tasks/mediapipe_tasks_aar.bzl | 1 + .../com/google/mediapipe/tasks/text/BUILD | 28 ++ .../text/textembedder/AndroidManifest.xml | 8 + .../tasks/text/textembedder/TextEmbedder.java | 256 ++++++++++++++++++ .../text/textembedder/TextEmbedderResult.java | 54 ++++ .../textclassifier/TextClassifierTest.java | 4 +- .../text/textembedder/AndroidManifest.xml | 24 ++ .../mediapipe/tasks/text/textembedder/BUILD | 19 ++ .../text/textembedder/TextEmbedderTest.java | 98 +++++++ 11 files changed, 494 insertions(+), 5 deletions(-) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedderResult.java create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/BUILD create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java diff --git a/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto b/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto index 6b8d41a57..e7e3a63c7 100644 --- a/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto @@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; 
import "mediapipe/tasks/cc/core/proto/base_options.proto"; +option java_package = "com.google.mediapipe.tasks.text.textembedder.proto"; +option java_outer_classname = "TextEmbedderGraphOptionsProto"; + message TextEmbedderGraphOptions { extend mediapipe.CalculatorOptions { optional TextEmbedderGraphOptions ext = 477589892; diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java index 9bf600360..0fc48742e 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java @@ -58,8 +58,8 @@ public abstract class TaskOptions { AccelerationProto.Acceleration.newBuilder(); switch (options.delegate()) { case CPU: - accelerationBuilder.setXnnpack( - InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Xnnpack + accelerationBuilder.setTflite( + InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.TfLite .getDefaultInstance()); break; case GPU: diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index f0c9f81c6..ab7ad6616 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -49,6 +49,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite", ] def mediapipe_tasks_core_aar(name, srcs, manifest): diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD index b49169529..0e72878ab 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD +++ 
b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -24,6 +24,7 @@ cc_binary( deps = [ "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph", + "//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", ], ) @@ -60,6 +61,33 @@ android_library( ], ) +android_library( + name = "textembedder", + srcs = [ + "textembedder/TextEmbedder.java", + "textembedder/TextEmbedderResult.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = "textembedder/AndroidManifest.xml", + deps = [ + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_text_aar") mediapipe_tasks_text_aar( diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml new file mode 
100644 index 000000000..d9c885d16 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java new file mode 100644 index 000000000..95fa1f087 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java @@ -0,0 +1,256 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.text.textembedder; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.ProtoUtil; +import com.google.mediapipe.tasks.components.containers.Embedding; +import com.google.mediapipe.tasks.components.containers.EmbeddingResult; +import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; +import com.google.mediapipe.tasks.components.processors.EmbedderOptions; +import com.google.mediapipe.tasks.components.utils.CosineSimilarity; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.text.textembedder.proto.TextEmbedderGraphOptionsProto; +import java.io.File; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * Performs embedding extraction on text. + * + *

This API expects a TFLite model with (optional) TFLite Model Metadata. + * + *

Metadata is required for models with int32 input tensors because it contains the input process + * unit for the model's Tokenizer. No metadata is required for models with string input tensors. + * + *

+ */ +public final class TextEmbedder implements AutoCloseable { + private static final String TAG = TextEmbedder.class.getSimpleName(); + private static final String TEXT_IN_STREAM_NAME = "text_in"; + + @SuppressWarnings("ConstantCaseForConstants") + private static final List INPUT_STREAMS = + Collections.unmodifiableList(Arrays.asList("TEXT:" + TEXT_IN_STREAM_NAME)); + + @SuppressWarnings("ConstantCaseForConstants") + private static final List OUTPUT_STREAMS = + Collections.unmodifiableList(Arrays.asList("EMBEDDINGS:embeddings_out")); + + private static final int EMBEDDINGS_OUT_STREAM_INDEX = 0; + private static final String TASK_GRAPH_NAME = + "mediapipe.tasks.text.text_embedder.TextEmbedderGraph"; + private final TaskRunner runner; + + static { + System.loadLibrary("mediapipe_tasks_text_jni"); + ProtoUtil.registerTypeName( + EmbeddingsProto.EmbeddingResult.class, + "mediapipe.tasks.components.containers.proto.EmbeddingResult"); + } + + /** + * Creates a {@link TextEmbedder} instance from a model file and the default {@link + * TextEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param modelPath path to the text model with metadata in the assets. + * @throws MediaPipeException if there is an error during {@link TextEmbedder} creation. + */ + public static TextEmbedder createFromFile(Context context, String modelPath) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build(); + return createFromOptions( + context, TextEmbedderOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates a {@link TextEmbedder} instance from a model file and the default {@link + * TextEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param modelFile the text model {@link File} instance. + * @throws IOException if an I/O error occurs when opening the tflite model file. + * @throws MediaPipeException if there is an error during {@link TextEmbedder} creation. 
+ */ + public static TextEmbedder createFromFile(Context context, File modelFile) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + BaseOptions baseOptions = + BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build(); + return createFromOptions( + context, TextEmbedderOptions.builder().setBaseOptions(baseOptions).build()); + } + } + + /** + * Creates a {@link TextEmbedder} instance from {@link TextEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param options a {@link TextEmbedderOptions} instance. + * @throws MediaPipeException if there is an error during {@link TextEmbedder} creation. + */ + public static TextEmbedder createFromOptions(Context context, TextEmbedderOptions options) { + OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public TextEmbedderResult convertToTaskResult(List packets) { + try { + return TextEmbedderResult.create( + EmbeddingResult.createFromProto( + PacketGetter.getProto( + packets.get(EMBEDDINGS_OUT_STREAM_INDEX), + EmbeddingsProto.EmbeddingResult.getDefaultInstance())), + packets.get(EMBEDDINGS_OUT_STREAM_INDEX).getTimestamp()); + } catch (IOException e) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage()); + } + } + + @Override + public Void convertToTaskInput(List packets) { + return null; + } + }); + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(INPUT_STREAMS) + .setOutputStreams(OUTPUT_STREAMS) + .setTaskOptions(options) + .setEnableFlowLimiting(false) + .build(), + handler); + return new TextEmbedder(runner); + } + + /** + * Constructor to initialize a {@link TextEmbedder} from a {@link TaskRunner}. + * + * @param runner a {@link TaskRunner}. 
+ */ + private TextEmbedder(TaskRunner runner) { + this.runner = runner; + } + + /** + * Performs embedding extraction on the input text. + * + * @param inputText a {@link String} for processing. + */ + public TextEmbedderResult embed(String inputText) { + Map inputPackets = new HashMap<>(); + inputPackets.put(TEXT_IN_STREAM_NAME, runner.getPacketCreator().createString(inputText)); + return (TextEmbedderResult) runner.process(inputPackets); + } + + /** Closes and cleans up the {@link TextEmbedder}. */ + @Override + public void close() { + runner.close(); + } + + /** + * Utility function to compute cosine + * similarity between two {@link Embedding} objects. + * + * @throws IllegalArgumentException if the embeddings are of different types (float vs. + * quantized), have different sizes, or have an L2-norm of 0. + */ + public static double cosineSimilarity(Embedding u, Embedding v) { + return CosineSimilarity.compute(u, v); + } + + /** Options for setting up a {@link TextEmbedder}. */ + @AutoValue + public abstract static class TextEmbedderOptions extends TaskOptions { + + /** Builder for {@link TextEmbedderOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the base options for the text embedder task. */ + public abstract Builder setBaseOptions(BaseOptions value); + + /** + * Sets the optional {@link EmbedderOptions} controlling embedder behavior, such as + * L2-normalization and scalar quantization. + */ + public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions); + + public abstract TextEmbedderOptions build(); + } + + abstract BaseOptions baseOptions(); + + abstract Optional embedderOptions(); + + public static Builder builder() { + return new AutoValue_TextEmbedder_TextEmbedderOptions.Builder(); + } + + /** Converts a {@link TextEmbedderOptions} to a {@link CalculatorOptions} protobuf message. 
*/ + @Override + public CalculatorOptions convertToCalculatorOptionsProto() { + BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = + BaseOptionsProto.BaseOptions.newBuilder(); + baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.Builder taskOptionsBuilder = + TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.newBuilder() + .setBaseOptions(baseOptionsBuilder); + if (embedderOptions().isPresent()) { + taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto()); + } + return CalculatorOptions.newBuilder() + .setExtension( + TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.ext, + taskOptionsBuilder.build()) + .build(); + } + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedderResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedderResult.java new file mode 100644 index 000000000..9d8e108ec --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedderResult.java @@ -0,0 +1,54 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.text.textembedder; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.tasks.components.containers.EmbeddingResult; +import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; +import com.google.mediapipe.tasks.core.TaskResult; + +/** Represents the embedding results generated by {@link TextEmbedder}. */ +@AutoValue +public abstract class TextEmbedderResult implements TaskResult { + + /** + * Creates a {@link TextEmbedderResult} instance. + * + * @param embeddingResult the {@link EmbeddingResult} object containing one embedding per embedder + * head. + * @param timestampMs a timestamp for this result. + */ + static TextEmbedderResult create(EmbeddingResult embeddingResult, long timestampMs) { + return new AutoValue_TextEmbedderResult(embeddingResult, timestampMs); + } + + /** + * Creates a {@link TextEmbedderResult} instance from a {@link EmbeddingsProto.EmbeddingResult} + * protobuf message. + * + * @param proto the {@link EmbeddingsProto.EmbeddingResult} protobuf message to convert. + * @param timestampMs a timestamp for this result. + */ + static TextEmbedderResult createFromProto( + EmbeddingsProto.EmbeddingResult proto, long timestampMs) { + return create(EmbeddingResult.createFromProto(proto), timestampMs); + } + + /** Contains one embedding per embedder head. 
*/ + public abstract EmbeddingResult embeddingResult(); + + @Override + public abstract long timestampMs(); +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java index d3f0e90f3..5e03d2a4c 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java @@ -67,9 +67,7 @@ public class TextClassifierTest { ApplicationProvider.getApplicationContext(), options)); // TODO: Make MediaPipe InferenceCalculator report the detailed. // interpreter errors (e.g., "Encountered unresolved custom op"). - assertThat(exception) - .hasMessageThat() - .contains("interpreter_builder(&interpreter) == kTfLiteOk"); + assertThat(exception).hasMessageThat().contains("== kTfLiteOk"); } @Test diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml new file mode 100644 index 000000000..5d55d7cfe --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/BUILD new file mode 100644 index 000000000..a7f804c64 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/BUILD @@ -0,0 +1,19 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java new file mode 100644 index 000000000..b6d53c94d --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java @@ -0,0 +1,98 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.text.textembedder; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.mediapipe.framework.MediaPipeException; +import org.junit.Test; +import org.junit.runner.RunWith; + +/** Test for {@link TextEmbedder}. */ +@RunWith(AndroidJUnit4.class) +public class TextEmbedderTest { + private static final String BERT_MODEL_FILE = "mobilebert_embedding_with_metadata.tflite"; + private static final String REGEX_MODEL_FILE = "regex_one_embedding_with_metadata.tflite"; + + private static final double DOUBLE_DIFF_TOLERANCE = 1e-4; + private static final float FLOAT_DIFF_TOLERANCE = 1e-4f; + + @Test + public void create_failsWithMissingModel() throws Exception { + String nonExistentFile = "/path/to/non/existent/file"; + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + TextEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), nonExistentFile)); + assertThat(exception).hasMessageThat().contains(nonExistentFile); + } + + @Test + public void embed_succeedsWithBert() throws Exception { + TextEmbedder textEmbedder = + TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE); + + TextEmbedderResult result0 = textEmbedder.embed("it's a charming and often affecting journey"); + assertThat(result0.embeddingResult().embeddings().size()).isEqualTo(1); + assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(512); + assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()[0]) + .isWithin(FLOAT_DIFF_TOLERANCE) + .of(20.59746f); + TextEmbedderResult result1 = textEmbedder.embed("what a great and fantastic trip"); + assertThat(result1.embeddingResult().embeddings().size()).isEqualTo(1); + 
assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(512); + assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()[0]) + .isWithin(FLOAT_DIFF_TOLERANCE) + .of(21.774776f); + + // Check cosine similarity. + double similarity = + TextEmbedder.cosineSimilarity( + result0.embeddingResult().embeddings().get(0), + result1.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.968879); + } + + @Test + public void embed_succeedsWithRegex() throws Exception { + TextEmbedder textEmbedder = + TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), REGEX_MODEL_FILE); + + TextEmbedderResult result0 = textEmbedder.embed("it's a charming and often affecting journey"); + assertThat(result0.embeddingResult().embeddings().size()).isEqualTo(1); + assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(16); + assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()[0]) + .isWithin(FLOAT_DIFF_TOLERANCE) + .of(0.030935612f); + TextEmbedderResult result1 = textEmbedder.embed("what a great and fantastic trip"); + assertThat(result1.embeddingResult().embeddings().size()).isEqualTo(1); + assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(16); + assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()[0]) + .isWithin(FLOAT_DIFF_TOLERANCE) + .of(0.0312863f); + + // Check cosine similarity. + double similarity = + TextEmbedder.cosineSimilarity( + result0.embeddingResult().embeddings().get(0), + result1.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.999937); + } +}