Add Java TextEmbedder API.
PiperOrigin-RevId: 488427327
This commit is contained in:
parent
b40b2ade14
commit
34daba4747
|
@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto";
|
||||||
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
|
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
|
||||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||||
|
|
||||||
|
option java_package = "com.google.mediapipe.tasks.text.textembedder.proto";
|
||||||
|
option java_outer_classname = "TextEmbedderGraphOptionsProto";
|
||||||
|
|
||||||
message TextEmbedderGraphOptions {
|
message TextEmbedderGraphOptions {
|
||||||
extend mediapipe.CalculatorOptions {
|
extend mediapipe.CalculatorOptions {
|
||||||
optional TextEmbedderGraphOptions ext = 477589892;
|
optional TextEmbedderGraphOptions ext = 477589892;
|
||||||
|
|
|
@ -58,8 +58,8 @@ public abstract class TaskOptions {
|
||||||
AccelerationProto.Acceleration.newBuilder();
|
AccelerationProto.Acceleration.newBuilder();
|
||||||
switch (options.delegate()) {
|
switch (options.delegate()) {
|
||||||
case CPU:
|
case CPU:
|
||||||
accelerationBuilder.setXnnpack(
|
accelerationBuilder.setTflite(
|
||||||
InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Xnnpack
|
InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.TfLite
|
||||||
.getDefaultInstance());
|
.getDefaultInstance());
|
||||||
break;
|
break;
|
||||||
case GPU:
|
case GPU:
|
||||||
|
|
|
@ -49,6 +49,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
||||||
|
|
||||||
_TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
_TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
||||||
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite",
|
||||||
]
|
]
|
||||||
|
|
||||||
def mediapipe_tasks_core_aar(name, srcs, manifest):
|
def mediapipe_tasks_core_aar(name, srcs, manifest):
|
||||||
|
|
|
@ -24,6 +24,7 @@ cc_binary(
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
|
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
|
||||||
"//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
|
"//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
|
||||||
|
"//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph",
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -60,6 +61,33 @@ android_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
android_library(
|
||||||
|
name = "textembedder",
|
||||||
|
srcs = [
|
||||||
|
"textembedder/TextEmbedder.java",
|
||||||
|
"textembedder/TextEmbedderResult.java",
|
||||||
|
],
|
||||||
|
javacopts = [
|
||||||
|
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||||
|
],
|
||||||
|
manifest = "textembedder/AndroidManifest.xml",
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_options_java_proto_lite",
|
||||||
|
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",
|
||||||
|
"//third_party:autovalue",
|
||||||
|
"@maven//:com_google_guava_guava",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_text_aar")
|
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_text_aar")
|
||||||
|
|
||||||
mediapipe_tasks_text_aar(
|
mediapipe_tasks_text_aar(
|
||||||
|
|
|
@ -0,0 +1,8 @@
|
||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||||
|
package="com.google.mediapipe.tasks.text.textembedder">
|
||||||
|
|
||||||
|
<uses-sdk android:minSdkVersion="24"
|
||||||
|
android:targetSdkVersion="30" />
|
||||||
|
|
||||||
|
</manifest>
|
|
@ -0,0 +1,256 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.text.textembedder;
|
||||||
|
|
||||||
|
import android.content.Context;
|
||||||
|
import android.os.ParcelFileDescriptor;
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||||
|
import com.google.mediapipe.framework.MediaPipeException;
|
||||||
|
import com.google.mediapipe.framework.Packet;
|
||||||
|
import com.google.mediapipe.framework.PacketGetter;
|
||||||
|
import com.google.mediapipe.framework.ProtoUtil;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.Embedding;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.EmbeddingResult;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
|
||||||
|
import com.google.mediapipe.tasks.components.processors.EmbedderOptions;
|
||||||
|
import com.google.mediapipe.tasks.components.utils.CosineSimilarity;
|
||||||
|
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
|
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskInfo;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskOptions;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||||
|
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
|
||||||
|
import com.google.mediapipe.tasks.text.textembedder.proto.TextEmbedderGraphOptionsProto;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs embedding extraction on text.
|
||||||
|
*
|
||||||
|
* <p>This API expects a TFLite model with (optional) <a
|
||||||
|
* href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata</a>.
|
||||||
|
*
|
||||||
|
* <p>Metadata is required for models with int32 input tensors because it contains the input process
|
||||||
|
* unit for the model's Tokenizer. No metadata is required for models with string input tensors.
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>Input tensors
|
||||||
|
* <ul>
|
||||||
|
* <li>Three input tensors ({@code kTfLiteInt32}) of shape {@code [batch_size x
|
||||||
|
* bert_max_seq_len]} representing the input ids, mask ids, and segment ids. This input
|
||||||
|
* signature requires a Bert Tokenizer process unit in the model metadata.
|
||||||
|
* <li>Or one input tensor ({@code kTfLiteInt32}) of shape {@code [batch_size x
|
||||||
|
* max_seq_len]} representing the input ids. This input signature requires a Regex
|
||||||
|
* Tokenizer process unit in the model metadata.
|
||||||
|
* <li>Or one input tensor ({@code kTfLiteString}) that is shapeless or has shape {@code
|
||||||
|
* [1]} containing the input string.
|
||||||
|
* </ul>
|
||||||
|
* <li>At least one output tensor ({@code kTfLiteFloat32}/{@code kTfLiteUint8}) with shape {@code
|
||||||
|
* [1 x N]} where N is the number of dimensions in the produced embeddings.
|
||||||
|
* </ul>
|
||||||
|
*/
|
||||||
|
public final class TextEmbedder implements AutoCloseable {
|
||||||
|
private static final String TAG = TextEmbedder.class.getSimpleName();
|
||||||
|
private static final String TEXT_IN_STREAM_NAME = "text_in";
|
||||||
|
|
||||||
|
@SuppressWarnings("ConstantCaseForConstants")
|
||||||
|
private static final List<String> INPUT_STREAMS =
|
||||||
|
Collections.unmodifiableList(Arrays.asList("TEXT:" + TEXT_IN_STREAM_NAME));
|
||||||
|
|
||||||
|
@SuppressWarnings("ConstantCaseForConstants")
|
||||||
|
private static final List<String> OUTPUT_STREAMS =
|
||||||
|
Collections.unmodifiableList(Arrays.asList("EMBEDDINGS:embeddings_out"));
|
||||||
|
|
||||||
|
private static final int EMBEDDINGS_OUT_STREAM_INDEX = 0;
|
||||||
|
private static final String TASK_GRAPH_NAME =
|
||||||
|
"mediapipe.tasks.text.text_embedder.TextEmbedderGraph";
|
||||||
|
private final TaskRunner runner;
|
||||||
|
|
||||||
|
static {
|
||||||
|
System.loadLibrary("mediapipe_tasks_text_jni");
|
||||||
|
ProtoUtil.registerTypeName(
|
||||||
|
EmbeddingsProto.EmbeddingResult.class,
|
||||||
|
"mediapipe.tasks.components.containers.proto.EmbeddingResult");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a {@link TextEmbedder} instance from a model file and the default {@link
|
||||||
|
* TextEmbedderOptions}.
|
||||||
|
*
|
||||||
|
* @param context an Android {@link Context}.
|
||||||
|
* @param modelPath path to the text model with metadata in the assets.
|
||||||
|
* @throws MediaPipeException if there is an error during {@link TextEmbedder} creation.
|
||||||
|
*/
|
||||||
|
public static TextEmbedder createFromFile(Context context, String modelPath) {
|
||||||
|
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
|
||||||
|
return createFromOptions(
|
||||||
|
context, TextEmbedderOptions.builder().setBaseOptions(baseOptions).build());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a {@link TextEmbedder} instance from a model file and the default {@link
|
||||||
|
* TextEmbedderOptions}.
|
||||||
|
*
|
||||||
|
* @param context an Android {@link Context}.
|
||||||
|
* @param modelFile the text model {@link File} instance.
|
||||||
|
* @throws IOException if an I/O error occurs when opening the tflite model file.
|
||||||
|
* @throws MediaPipeException if there is an error during {@link TextEmbedder} creation.
|
||||||
|
*/
|
||||||
|
public static TextEmbedder createFromFile(Context context, File modelFile) throws IOException {
|
||||||
|
try (ParcelFileDescriptor descriptor =
|
||||||
|
ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
|
||||||
|
BaseOptions baseOptions =
|
||||||
|
BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
|
||||||
|
return createFromOptions(
|
||||||
|
context, TextEmbedderOptions.builder().setBaseOptions(baseOptions).build());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a {@link TextEmbedder} instance from {@link TextEmbedderOptions}.
|
||||||
|
*
|
||||||
|
* @param context an Android {@link Context}.
|
||||||
|
* @param options a {@link TextEmbedderOptions} instance.
|
||||||
|
* @throws MediaPipeException if there is an error during {@link TextEmbedder} creation.
|
||||||
|
*/
|
||||||
|
public static TextEmbedder createFromOptions(Context context, TextEmbedderOptions options) {
|
||||||
|
OutputHandler<TextEmbedderResult, Void> handler = new OutputHandler<>();
|
||||||
|
handler.setOutputPacketConverter(
|
||||||
|
new OutputHandler.OutputPacketConverter<TextEmbedderResult, Void>() {
|
||||||
|
@Override
|
||||||
|
public TextEmbedderResult convertToTaskResult(List<Packet> packets) {
|
||||||
|
try {
|
||||||
|
return TextEmbedderResult.create(
|
||||||
|
EmbeddingResult.createFromProto(
|
||||||
|
PacketGetter.getProto(
|
||||||
|
packets.get(EMBEDDINGS_OUT_STREAM_INDEX),
|
||||||
|
EmbeddingsProto.EmbeddingResult.getDefaultInstance())),
|
||||||
|
packets.get(EMBEDDINGS_OUT_STREAM_INDEX).getTimestamp());
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new MediaPipeException(
|
||||||
|
MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Void convertToTaskInput(List<Packet> packets) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
TaskRunner runner =
|
||||||
|
TaskRunner.create(
|
||||||
|
context,
|
||||||
|
TaskInfo.<TextEmbedderOptions>builder()
|
||||||
|
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||||
|
.setInputStreams(INPUT_STREAMS)
|
||||||
|
.setOutputStreams(OUTPUT_STREAMS)
|
||||||
|
.setTaskOptions(options)
|
||||||
|
.setEnableFlowLimiting(false)
|
||||||
|
.build(),
|
||||||
|
handler);
|
||||||
|
return new TextEmbedder(runner);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructor to initialize a {@link TextEmbedder} from a {@link TaskRunner}.
|
||||||
|
*
|
||||||
|
* @param runner a {@link TaskRunner}.
|
||||||
|
*/
|
||||||
|
private TextEmbedder(TaskRunner runner) {
|
||||||
|
this.runner = runner;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs embedding extraction on the input text.
|
||||||
|
*
|
||||||
|
* @param inputText a {@link String} for processing.
|
||||||
|
*/
|
||||||
|
public TextEmbedderResult embed(String inputText) {
|
||||||
|
Map<String, Packet> inputPackets = new HashMap<>();
|
||||||
|
inputPackets.put(TEXT_IN_STREAM_NAME, runner.getPacketCreator().createString(inputText));
|
||||||
|
return (TextEmbedderResult) runner.process(inputPackets);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Closes and cleans up the {@link TextEmbedder}. */
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
runner.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Utility function to compute <a href="https://en.wikipedia.org/wiki/Cosine_similarity">cosine
|
||||||
|
* similarity</a> between two {@link Embedding} objects.
|
||||||
|
*
|
||||||
|
* @throws IllegalArgumentException if the embeddings are of different types (float vs.
|
||||||
|
* quantized), have different sizes, or have an L2-norm of 0.
|
||||||
|
*/
|
||||||
|
public static double cosineSimilarity(Embedding u, Embedding v) {
|
||||||
|
return CosineSimilarity.compute(u, v);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Options for setting up a {@link TextEmbedder}. */
|
||||||
|
@AutoValue
|
||||||
|
public abstract static class TextEmbedderOptions extends TaskOptions {
|
||||||
|
|
||||||
|
/** Builder for {@link TextEmbedderOptions}. */
|
||||||
|
@AutoValue.Builder
|
||||||
|
public abstract static class Builder {
|
||||||
|
/** Sets the base options for the text embedder task. */
|
||||||
|
public abstract Builder setBaseOptions(BaseOptions value);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional {@link EmbedderOptions} controlling embedder behavior, such as
|
||||||
|
* L2-normalization and scalar quantization.
|
||||||
|
*/
|
||||||
|
public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions);
|
||||||
|
|
||||||
|
public abstract TextEmbedderOptions build();
|
||||||
|
}
|
||||||
|
|
||||||
|
abstract BaseOptions baseOptions();
|
||||||
|
|
||||||
|
abstract Optional<EmbedderOptions> embedderOptions();
|
||||||
|
|
||||||
|
public static Builder builder() {
|
||||||
|
return new AutoValue_TextEmbedder_TextEmbedderOptions.Builder();
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Converts a {@link TextEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */
|
||||||
|
@Override
|
||||||
|
public CalculatorOptions convertToCalculatorOptionsProto() {
|
||||||
|
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
|
||||||
|
BaseOptionsProto.BaseOptions.newBuilder();
|
||||||
|
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
|
||||||
|
TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.Builder taskOptionsBuilder =
|
||||||
|
TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.newBuilder()
|
||||||
|
.setBaseOptions(baseOptionsBuilder);
|
||||||
|
if (embedderOptions().isPresent()) {
|
||||||
|
taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto());
|
||||||
|
}
|
||||||
|
return CalculatorOptions.newBuilder()
|
||||||
|
.setExtension(
|
||||||
|
TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.ext,
|
||||||
|
taskOptionsBuilder.build())
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,54 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.text.textembedder;
|
||||||
|
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.EmbeddingResult;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskResult;
|
||||||
|
|
||||||
|
/** Represents the embedding results generated by {@link TextEmbedder}. */
|
||||||
|
@AutoValue
|
||||||
|
public abstract class TextEmbedderResult implements TaskResult {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a {@link TextEmbedderResult} instance.
|
||||||
|
*
|
||||||
|
* @param embeddingResult the {@link EmbeddingResult} object containing one embedding per embedder
|
||||||
|
* head.
|
||||||
|
* @param timestampMs a timestamp for this result.
|
||||||
|
*/
|
||||||
|
static TextEmbedderResult create(EmbeddingResult embeddingResult, long timestampMs) {
|
||||||
|
return new AutoValue_TextEmbedderResult(embeddingResult, timestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a {@link TextEmbedderResult} instance from an {@link EmbeddingsProto.EmbeddingResult}
|
||||||
|
* protobuf message.
|
||||||
|
*
|
||||||
|
* @param proto the {@link EmbeddingsProto.EmbeddingResult} protobuf message to convert.
|
||||||
|
* @param timestampMs a timestamp for this result.
|
||||||
|
*/
|
||||||
|
static TextEmbedderResult createFromProto(
|
||||||
|
EmbeddingsProto.EmbeddingResult proto, long timestampMs) {
|
||||||
|
return create(EmbeddingResult.createFromProto(proto), timestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Contains one embedding per embedder head. */
|
||||||
|
public abstract EmbeddingResult embeddingResult();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public abstract long timestampMs();
|
||||||
|
}
|
|
@ -67,9 +67,7 @@ public class TextClassifierTest {
|
||||||
ApplicationProvider.getApplicationContext(), options));
|
ApplicationProvider.getApplicationContext(), options));
|
||||||
// TODO: Make MediaPipe InferenceCalculator report the detailed.
|
// TODO: Make MediaPipe InferenceCalculator report the detailed.
|
||||||
// interpreter errors (e.g., "Encountered unresolved custom op").
|
// interpreter errors (e.g., "Encountered unresolved custom op").
|
||||||
assertThat(exception)
|
assertThat(exception).hasMessageThat().contains("== kTfLiteOk");
|
||||||
.hasMessageThat()
|
|
||||||
.contains("interpreter_builder(&interpreter) == kTfLiteOk");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||||
|
package="com.google.mediapipe.tasks.text.textembeddertest"
|
||||||
|
android:versionCode="1"
|
||||||
|
android:versionName="1.0" >
|
||||||
|
|
||||||
|
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
|
||||||
|
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
|
||||||
|
|
||||||
|
<uses-sdk android:minSdkVersion="24"
|
||||||
|
android:targetSdkVersion="30" />
|
||||||
|
|
||||||
|
<application
|
||||||
|
android:label="textembeddertest"
|
||||||
|
android:name="android.support.multidex.MultiDexApplication"
|
||||||
|
android:taskAffinity="">
|
||||||
|
<uses-library android:name="android.test.runner" />
|
||||||
|
</application>
|
||||||
|
|
||||||
|
<instrumentation
|
||||||
|
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
|
||||||
|
android:targetPackage="com.google.mediapipe.tasks.text.textembeddertest" />
|
||||||
|
|
||||||
|
</manifest>
|
|
@ -0,0 +1,19 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
# TODO: Enable this in OSS
|
|
@ -0,0 +1,98 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.text.textembedder;
|
||||||
|
|
||||||
|
import static com.google.common.truth.Truth.assertThat;
|
||||||
|
import static org.junit.Assert.assertThrows;
|
||||||
|
|
||||||
|
import androidx.test.core.app.ApplicationProvider;
|
||||||
|
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||||
|
import com.google.mediapipe.framework.MediaPipeException;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
|
||||||
|
/** Test for {@link TextEmbedder}. */
|
||||||
|
@RunWith(AndroidJUnit4.class)
|
||||||
|
public class TextEmbedderTest {
|
||||||
|
private static final String BERT_MODEL_FILE = "mobilebert_embedding_with_metadata.tflite";
|
||||||
|
private static final String REGEX_MODEL_FILE = "regex_one_embedding_with_metadata.tflite";
|
||||||
|
|
||||||
|
private static final double DOUBLE_DIFF_TOLERANCE = 1e-4;
|
||||||
|
private static final float FLOAT_DIFF_TOLERANCE = 1e-4f;
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void create_failsWithMissingModel() throws Exception {
|
||||||
|
String nonExistentFile = "/path/to/non/existent/file";
|
||||||
|
MediaPipeException exception =
|
||||||
|
assertThrows(
|
||||||
|
MediaPipeException.class,
|
||||||
|
() ->
|
||||||
|
TextEmbedder.createFromFile(
|
||||||
|
ApplicationProvider.getApplicationContext(), nonExistentFile));
|
||||||
|
assertThat(exception).hasMessageThat().contains(nonExistentFile);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void embed_succeedsWithBert() throws Exception {
|
||||||
|
TextEmbedder textEmbedder =
|
||||||
|
TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE);
|
||||||
|
|
||||||
|
TextEmbedderResult result0 = textEmbedder.embed("it's a charming and often affecting journey");
|
||||||
|
assertThat(result0.embeddingResult().embeddings().size()).isEqualTo(1);
|
||||||
|
assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(512);
|
||||||
|
assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()[0])
|
||||||
|
.isWithin(FLOAT_DIFF_TOLERANCE)
|
||||||
|
.of(20.59746f);
|
||||||
|
TextEmbedderResult result1 = textEmbedder.embed("what a great and fantastic trip");
|
||||||
|
assertThat(result1.embeddingResult().embeddings().size()).isEqualTo(1);
|
||||||
|
assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(512);
|
||||||
|
assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()[0])
|
||||||
|
.isWithin(FLOAT_DIFF_TOLERANCE)
|
||||||
|
.of(21.774776f);
|
||||||
|
|
||||||
|
// Check cosine similarity.
|
||||||
|
double similarity =
|
||||||
|
TextEmbedder.cosineSimilarity(
|
||||||
|
result0.embeddingResult().embeddings().get(0),
|
||||||
|
result1.embeddingResult().embeddings().get(0));
|
||||||
|
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.968879);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void embed_succeedsWithRegex() throws Exception {
|
||||||
|
TextEmbedder textEmbedder =
|
||||||
|
TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), REGEX_MODEL_FILE);
|
||||||
|
|
||||||
|
TextEmbedderResult result0 = textEmbedder.embed("it's a charming and often affecting journey");
|
||||||
|
assertThat(result0.embeddingResult().embeddings().size()).isEqualTo(1);
|
||||||
|
assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(16);
|
||||||
|
assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()[0])
|
||||||
|
.isWithin(FLOAT_DIFF_TOLERANCE)
|
||||||
|
.of(0.030935612f);
|
||||||
|
TextEmbedderResult result1 = textEmbedder.embed("what a great and fantastic trip");
|
||||||
|
assertThat(result1.embeddingResult().embeddings().size()).isEqualTo(1);
|
||||||
|
assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(16);
|
||||||
|
assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()[0])
|
||||||
|
.isWithin(FLOAT_DIFF_TOLERANCE)
|
||||||
|
.of(0.0312863f);
|
||||||
|
|
||||||
|
// Check cosine similarity.
|
||||||
|
double similarity =
|
||||||
|
TextEmbedder.cosineSimilarity(
|
||||||
|
result0.embeddingResult().embeddings().get(0),
|
||||||
|
result1.embeddingResult().embeddings().get(0));
|
||||||
|
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.999937);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user