Add Java TextEmbedder API.
PiperOrigin-RevId: 488427327
This commit is contained in:
		
							parent
							
								
									b40b2ade14
								
							
						
					
					
						commit
						34daba4747
					
				| 
						 | 
				
			
			@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto";
 | 
			
		|||
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
 | 
			
		||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
 | 
			
		||||
 | 
			
		||||
option java_package = "com.google.mediapipe.tasks.text.textembedder.proto";
 | 
			
		||||
option java_outer_classname = "TextEmbedderGraphOptionsProto";
 | 
			
		||||
 | 
			
		||||
message TextEmbedderGraphOptions {
 | 
			
		||||
  extend mediapipe.CalculatorOptions {
 | 
			
		||||
    optional TextEmbedderGraphOptions ext = 477589892;
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -58,8 +58,8 @@ public abstract class TaskOptions {
 | 
			
		|||
        AccelerationProto.Acceleration.newBuilder();
 | 
			
		||||
    switch (options.delegate()) {
 | 
			
		||||
      case CPU:
 | 
			
		||||
        accelerationBuilder.setXnnpack(
 | 
			
		||||
            InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Xnnpack
 | 
			
		||||
        accelerationBuilder.setTflite(
 | 
			
		||||
            InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.TfLite
 | 
			
		||||
                .getDefaultInstance());
 | 
			
		||||
        break;
 | 
			
		||||
      case GPU:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -49,6 +49,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
 | 
			
		|||
 | 
			
		||||
# Java proto-lite targets of the text tasks. Each label must name the package that
# actually defines the proto: the embedder graph options live under text_embedder/proto,
# not under text_classifier/proto.
_TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [
    "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
    "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite",
]
 | 
			
		||||
 | 
			
		||||
def mediapipe_tasks_core_aar(name, srcs, manifest):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -24,6 +24,7 @@ cc_binary(
 | 
			
		|||
    deps = [
 | 
			
		||||
        "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
 | 
			
		||||
        "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
 | 
			
		||||
        "//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph",
 | 
			
		||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			@ -60,6 +61,33 @@ android_library(
 | 
			
		|||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# Android library target for the Java text embedder sources.
android_library(
    name = "textembedder",
    srcs = [
        "textembedder/TextEmbedder.java",
        "textembedder/TextEmbedderResult.java",
    ],
    # Turns off the Error Prone AndroidJdkLibsChecker check for these sources.
    javacopts = [
        "-Xep:AndroidJdkLibsChecker:OFF",
    ],
    manifest = "textembedder/AndroidManifest.xml",
    deps = [
        # Framework and generated proto-lite dependencies.
        "//mediapipe/framework:calculator_options_java_proto_lite",
        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
        "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
        "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite",
        # Java task components.
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",
        # Third-party libraries.
        "//third_party:autovalue",
        "@maven//:com_google_guava_guava",
    ],
)
 | 
			
		||||
 | 
			
		||||
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_text_aar")
 | 
			
		||||
 | 
			
		||||
mediapipe_tasks_text_aar(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,8 @@
 | 
			
		|||
<?xml version="1.0" encoding="utf-8"?>
<!-- Manifest for the com.google.mediapipe.tasks.text.textembedder package; declares the SDK range. -->
<manifest
    xmlns:android="http://schemas.android.com/apk/res/android"
    package="com.google.mediapipe.tasks.text.textembedder">

    <uses-sdk
        android:minSdkVersion="24"
        android:targetSdkVersion="30" />

</manifest>
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,256 @@
 | 
			
		|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
// you may not use this file except in compliance with the License.
 | 
			
		||||
// You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
//      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
// See the License for the specific language governing permissions and
 | 
			
		||||
// limitations under the License.
 | 
			
		||||
 | 
			
		||||
package com.google.mediapipe.tasks.text.textembedder;
 | 
			
		||||
 | 
			
		||||
import android.content.Context;
 | 
			
		||||
import android.os.ParcelFileDescriptor;
 | 
			
		||||
import com.google.auto.value.AutoValue;
 | 
			
		||||
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
 | 
			
		||||
import com.google.mediapipe.framework.MediaPipeException;
 | 
			
		||||
import com.google.mediapipe.framework.Packet;
 | 
			
		||||
import com.google.mediapipe.framework.PacketGetter;
 | 
			
		||||
import com.google.mediapipe.framework.ProtoUtil;
 | 
			
		||||
import com.google.mediapipe.tasks.components.containers.Embedding;
 | 
			
		||||
import com.google.mediapipe.tasks.components.containers.EmbeddingResult;
 | 
			
		||||
import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
 | 
			
		||||
import com.google.mediapipe.tasks.components.processors.EmbedderOptions;
 | 
			
		||||
import com.google.mediapipe.tasks.components.utils.CosineSimilarity;
 | 
			
		||||
import com.google.mediapipe.tasks.core.BaseOptions;
 | 
			
		||||
import com.google.mediapipe.tasks.core.OutputHandler;
 | 
			
		||||
import com.google.mediapipe.tasks.core.TaskInfo;
 | 
			
		||||
import com.google.mediapipe.tasks.core.TaskOptions;
 | 
			
		||||
import com.google.mediapipe.tasks.core.TaskRunner;
 | 
			
		||||
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
 | 
			
		||||
import com.google.mediapipe.tasks.text.textembedder.proto.TextEmbedderGraphOptionsProto;
 | 
			
		||||
import java.io.File;
 | 
			
		||||
import java.io.IOException;
 | 
			
		||||
import java.util.Arrays;
 | 
			
		||||
import java.util.Collections;
 | 
			
		||||
import java.util.HashMap;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
import java.util.Optional;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Performs embedding extraction on text.
 | 
			
		||||
 *
 | 
			
		||||
 * <p>This API expects a TFLite model with (optional) <a
 | 
			
		||||
 * href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata</a>.
 | 
			
		||||
 *
 | 
			
		||||
 * <p>Metadata is required for models with int32 input tensors because it contains the input process
 | 
			
		||||
 * unit for the model's Tokenizer. No metadata is required for models with string input tensors.
 | 
			
		||||
 *
 | 
			
		||||
 * <ul>
 | 
			
		||||
 *   <li>Input tensors
 | 
			
		||||
 *       <ul>
 | 
			
		||||
 *         <li>Three input tensors ({@code kTfLiteInt32}) of shape {@code [batch_size x
 | 
			
		||||
 *             bert_max_seq_len]} representing the input ids, mask ids, and segment ids. This input
 | 
			
		||||
 *             signature requires a Bert Tokenizer process unit in the model metadata.
 | 
			
		||||
 *         <li>Or one input tensor ({@code kTfLiteInt32}) of shape {@code [batch_size x
 | 
			
		||||
 *             max_seq_len]} representing the input ids. This input signature requires a Regex
 | 
			
		||||
 *             Tokenizer process unit in the model metadata.
 | 
			
		||||
 *         <li>Or one input tensor ({@code kTfLiteString}) that is shapeless or has shape {@code
 | 
			
		||||
 *             [1]} containing the input string.
 | 
			
		||||
 *       </ul>
 | 
			
		||||
 *   <li>At least one output tensor ({@code kTfLiteFloat32}/{@code kTfLiteUint8}) with shape {@code
 | 
			
		||||
 *       [1 x N]} where N is the number of dimensions in the produced embeddings.
 | 
			
		||||
 * </ul>
 | 
			
		||||
 */
 | 
			
		||||
public final class TextEmbedder implements AutoCloseable {
 | 
			
		||||
  private static final String TAG = TextEmbedder.class.getSimpleName();
 | 
			
		||||
  private static final String TEXT_IN_STREAM_NAME = "text_in";
 | 
			
		||||
 | 
			
		||||
  @SuppressWarnings("ConstantCaseForConstants")
 | 
			
		||||
  private static final List<String> INPUT_STREAMS =
 | 
			
		||||
      Collections.unmodifiableList(Arrays.asList("TEXT:" + TEXT_IN_STREAM_NAME));
 | 
			
		||||
 | 
			
		||||
  @SuppressWarnings("ConstantCaseForConstants")
 | 
			
		||||
  private static final List<String> OUTPUT_STREAMS =
 | 
			
		||||
      Collections.unmodifiableList(Arrays.asList("EMBEDDINGS:embeddings_out"));
 | 
			
		||||
 | 
			
		||||
  private static final int EMBEDDINGS_OUT_STREAM_INDEX = 0;
 | 
			
		||||
  private static final String TASK_GRAPH_NAME =
 | 
			
		||||
      "mediapipe.tasks.text.text_embedder.TextEmbedderGraph";
 | 
			
		||||
  private final TaskRunner runner;
 | 
			
		||||
 | 
			
		||||
  static {
 | 
			
		||||
    System.loadLibrary("mediapipe_tasks_text_jni");
 | 
			
		||||
    ProtoUtil.registerTypeName(
 | 
			
		||||
        EmbeddingsProto.EmbeddingResult.class,
 | 
			
		||||
        "mediapipe.tasks.components.containers.proto.EmbeddingResult");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Creates a {@link TextEmbedder} instance from a model file and the default {@link
 | 
			
		||||
   * TextEmbedderOptions}.
 | 
			
		||||
   *
 | 
			
		||||
   * @param context an Android {@link Context}.
 | 
			
		||||
   * @param modelPath path to the text model with metadata in the assets.
 | 
			
		||||
   * @throws MediaPipeException if there is is an error during {@link TextEmbedder} creation.
 | 
			
		||||
   */
 | 
			
		||||
  public static TextEmbedder createFromFile(Context context, String modelPath) {
 | 
			
		||||
    BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
 | 
			
		||||
    return createFromOptions(
 | 
			
		||||
        context, TextEmbedderOptions.builder().setBaseOptions(baseOptions).build());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Creates a {@link TextEmbedder} instance from a model file and the default {@link
 | 
			
		||||
   * TextEmbedderOptions}.
 | 
			
		||||
   *
 | 
			
		||||
   * @param context an Android {@link Context}.
 | 
			
		||||
   * @param modelFile the text model {@link File} instance.
 | 
			
		||||
   * @throws IOException if an I/O error occurs when opening the tflite model file.
 | 
			
		||||
   * @throws MediaPipeException if there is an error during {@link TextEmbedder} creation.
 | 
			
		||||
   */
 | 
			
		||||
  public static TextEmbedder createFromFile(Context context, File modelFile) throws IOException {
 | 
			
		||||
    try (ParcelFileDescriptor descriptor =
 | 
			
		||||
        ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
 | 
			
		||||
      BaseOptions baseOptions =
 | 
			
		||||
          BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
 | 
			
		||||
      return createFromOptions(
 | 
			
		||||
          context, TextEmbedderOptions.builder().setBaseOptions(baseOptions).build());
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Creates a {@link TextEmbedder} instance from {@link TextEmbedderOptions}.
 | 
			
		||||
   *
 | 
			
		||||
   * @param context an Android {@link Context}.
 | 
			
		||||
   * @param options a {@link TextEmbedderOptions} instance.
 | 
			
		||||
   * @throws MediaPipeException if there is an error during {@link TextEmbedder} creation.
 | 
			
		||||
   */
 | 
			
		||||
  public static TextEmbedder createFromOptions(Context context, TextEmbedderOptions options) {
 | 
			
		||||
    OutputHandler<TextEmbedderResult, Void> handler = new OutputHandler<>();
 | 
			
		||||
    handler.setOutputPacketConverter(
 | 
			
		||||
        new OutputHandler.OutputPacketConverter<TextEmbedderResult, Void>() {
 | 
			
		||||
          @Override
 | 
			
		||||
          public TextEmbedderResult convertToTaskResult(List<Packet> packets) {
 | 
			
		||||
            try {
 | 
			
		||||
              return TextEmbedderResult.create(
 | 
			
		||||
                  EmbeddingResult.createFromProto(
 | 
			
		||||
                      PacketGetter.getProto(
 | 
			
		||||
                          packets.get(EMBEDDINGS_OUT_STREAM_INDEX),
 | 
			
		||||
                          EmbeddingsProto.EmbeddingResult.getDefaultInstance())),
 | 
			
		||||
                  packets.get(EMBEDDINGS_OUT_STREAM_INDEX).getTimestamp());
 | 
			
		||||
            } catch (IOException e) {
 | 
			
		||||
              throw new MediaPipeException(
 | 
			
		||||
                  MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
 | 
			
		||||
            }
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          @Override
 | 
			
		||||
          public Void convertToTaskInput(List<Packet> packets) {
 | 
			
		||||
            return null;
 | 
			
		||||
          }
 | 
			
		||||
        });
 | 
			
		||||
    TaskRunner runner =
 | 
			
		||||
        TaskRunner.create(
 | 
			
		||||
            context,
 | 
			
		||||
            TaskInfo.<TextEmbedderOptions>builder()
 | 
			
		||||
                .setTaskGraphName(TASK_GRAPH_NAME)
 | 
			
		||||
                .setInputStreams(INPUT_STREAMS)
 | 
			
		||||
                .setOutputStreams(OUTPUT_STREAMS)
 | 
			
		||||
                .setTaskOptions(options)
 | 
			
		||||
                .setEnableFlowLimiting(false)
 | 
			
		||||
                .build(),
 | 
			
		||||
            handler);
 | 
			
		||||
    return new TextEmbedder(runner);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Constructor to initialize a {@link TextEmbedder} from a {@link TaskRunner}.
 | 
			
		||||
   *
 | 
			
		||||
   * @param runner a {@link TaskRunner}.
 | 
			
		||||
   */
 | 
			
		||||
  private TextEmbedder(TaskRunner runner) {
 | 
			
		||||
    this.runner = runner;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Performs embedding extraction on the input text.
 | 
			
		||||
   *
 | 
			
		||||
   * @param inputText a {@link String} for processing.
 | 
			
		||||
   */
 | 
			
		||||
  public TextEmbedderResult embed(String inputText) {
 | 
			
		||||
    Map<String, Packet> inputPackets = new HashMap<>();
 | 
			
		||||
    inputPackets.put(TEXT_IN_STREAM_NAME, runner.getPacketCreator().createString(inputText));
 | 
			
		||||
    return (TextEmbedderResult) runner.process(inputPackets);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /** Closes and cleans up the {@link TextEmbedder}. */
 | 
			
		||||
  @Override
 | 
			
		||||
  public void close() {
 | 
			
		||||
    runner.close();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Utility function to compute <a href="https://en.wikipedia.org/wiki/Cosine_similarity">cosine
 | 
			
		||||
   * similarity</a> between two {@link Embedding} objects.
 | 
			
		||||
   *
 | 
			
		||||
   * @throws IllegalArgumentException if the embeddings are of different types (float vs.
 | 
			
		||||
   *     quantized), have different sizes, or have an L2-norm of 0.
 | 
			
		||||
   */
 | 
			
		||||
  public static double cosineSimilarity(Embedding u, Embedding v) {
 | 
			
		||||
    return CosineSimilarity.compute(u, v);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /** Options for setting up a {@link TextEmbedder}. */
 | 
			
		||||
  @AutoValue
 | 
			
		||||
  public abstract static class TextEmbedderOptions extends TaskOptions {
 | 
			
		||||
 | 
			
		||||
    /** Builder for {@link TextEmbedderOptions}. */
 | 
			
		||||
    @AutoValue.Builder
 | 
			
		||||
    public abstract static class Builder {
 | 
			
		||||
      /** Sets the base options for the text embedder task. */
 | 
			
		||||
      public abstract Builder setBaseOptions(BaseOptions value);
 | 
			
		||||
 | 
			
		||||
      /**
 | 
			
		||||
       * Sets the optional {@link EmbedderOptions} controling embedder behavior, such as
 | 
			
		||||
       * L2-normalization and scalar quantization.
 | 
			
		||||
       */
 | 
			
		||||
      public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions);
 | 
			
		||||
 | 
			
		||||
      public abstract TextEmbedderOptions build();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    abstract BaseOptions baseOptions();
 | 
			
		||||
 | 
			
		||||
    abstract Optional<EmbedderOptions> embedderOptions();
 | 
			
		||||
 | 
			
		||||
    public static Builder builder() {
 | 
			
		||||
      return new AutoValue_TextEmbedder_TextEmbedderOptions.Builder();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /** Converts a {@link TextEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */
 | 
			
		||||
    @Override
 | 
			
		||||
    public CalculatorOptions convertToCalculatorOptionsProto() {
 | 
			
		||||
      BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
 | 
			
		||||
          BaseOptionsProto.BaseOptions.newBuilder();
 | 
			
		||||
      baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
 | 
			
		||||
      TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.Builder taskOptionsBuilder =
 | 
			
		||||
          TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.newBuilder()
 | 
			
		||||
              .setBaseOptions(baseOptionsBuilder);
 | 
			
		||||
      if (embedderOptions().isPresent()) {
 | 
			
		||||
        taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto());
 | 
			
		||||
      }
 | 
			
		||||
      return CalculatorOptions.newBuilder()
 | 
			
		||||
          .setExtension(
 | 
			
		||||
              TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.ext,
 | 
			
		||||
              taskOptionsBuilder.build())
 | 
			
		||||
          .build();
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,54 @@
 | 
			
		|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
// you may not use this file except in compliance with the License.
 | 
			
		||||
// You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
//      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
// See the License for the specific language governing permissions and
 | 
			
		||||
// limitations under the License.
 | 
			
		||||
 | 
			
		||||
package com.google.mediapipe.tasks.text.textembedder;
 | 
			
		||||
 | 
			
		||||
import com.google.auto.value.AutoValue;
 | 
			
		||||
import com.google.mediapipe.tasks.components.containers.EmbeddingResult;
 | 
			
		||||
import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
 | 
			
		||||
import com.google.mediapipe.tasks.core.TaskResult;
 | 
			
		||||
 | 
			
		||||
/** Represents the embedding results generated by {@link TextEmbedder}. */
 | 
			
		||||
@AutoValue
 | 
			
		||||
public abstract class TextEmbedderResult implements TaskResult {
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Creates an {@link TextEmbedderResult} instance.
 | 
			
		||||
   *
 | 
			
		||||
   * @param embeddingResult the {@link EmbeddingResult} object containing one embedding per embedder
 | 
			
		||||
   *     head.
 | 
			
		||||
   * @param timestampMs a timestamp for this result.
 | 
			
		||||
   */
 | 
			
		||||
  static TextEmbedderResult create(EmbeddingResult embeddingResult, long timestampMs) {
 | 
			
		||||
    return new AutoValue_TextEmbedderResult(embeddingResult, timestampMs);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Creates an {@link TextEmbedderResult} instance from a {@link EmbeddingsProto.EmbeddingResult}
 | 
			
		||||
   * protobuf message.
 | 
			
		||||
   *
 | 
			
		||||
   * @param proto the {@link EmbeddingsProto.EmbeddingResult} protobuf message to convert.
 | 
			
		||||
   * @param timestampMs a timestamp for this result.
 | 
			
		||||
   */
 | 
			
		||||
  static TextEmbedderResult createFromProto(
 | 
			
		||||
      EmbeddingsProto.EmbeddingResult proto, long timestampMs) {
 | 
			
		||||
    return create(EmbeddingResult.createFromProto(proto), timestampMs);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /** Contains one embedding per embedder head. */
 | 
			
		||||
  public abstract EmbeddingResult embeddingResult();
 | 
			
		||||
 | 
			
		||||
  @Override
 | 
			
		||||
  public abstract long timestampMs();
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -67,9 +67,7 @@ public class TextClassifierTest {
 | 
			
		|||
                    ApplicationProvider.getApplicationContext(), options));
 | 
			
		||||
    // TODO: Make MediaPipe InferenceCalculator report the detailed
 | 
			
		||||
    // interpreter errors (e.g., "Encountered unresolved custom op").
 | 
			
		||||
    assertThat(exception)
 | 
			
		||||
        .hasMessageThat()
 | 
			
		||||
        .contains("interpreter_builder(&interpreter) == kTfLiteOk");
 | 
			
		||||
    assertThat(exception).hasMessageThat().contains("== kTfLiteOk");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  @Test
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,24 @@
 | 
			
		|||
<?xml version="1.0" encoding="utf-8"?>
<!-- Manifest of the textembeddertest instrumentation package. -->
<manifest
    xmlns:android="http://schemas.android.com/apk/res/android"
    package="com.google.mediapipe.tasks.text.textembeddertest"
    android:versionCode="1"
    android:versionName="1.0">

    <uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
    <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" />

    <uses-sdk
        android:minSdkVersion="24"
        android:targetSdkVersion="30" />

    <application
        android:name="android.support.multidex.MultiDexApplication"
        android:label="textembeddertest"
        android:taskAffinity="">
        <uses-library android:name="android.test.runner" />
    </application>

    <!-- The instrumentation targets this same package. -->
    <instrumentation
        android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
        android:targetPackage="com.google.mediapipe.tasks.text.textembeddertest" />

</manifest>
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,19 @@
 | 
			
		|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
package(default_visibility = ["//mediapipe/tasks:internal"])
 | 
			
		||||
 | 
			
		||||
licenses(["notice"])
 | 
			
		||||
 | 
			
		||||
# TODO: Enable this in OSS
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,98 @@
 | 
			
		|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
// you may not use this file except in compliance with the License.
 | 
			
		||||
// You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
//      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
// See the License for the specific language governing permissions and
 | 
			
		||||
// limitations under the License.
 | 
			
		||||
 | 
			
		||||
package com.google.mediapipe.tasks.text.textembedder;
 | 
			
		||||
 | 
			
		||||
import static com.google.common.truth.Truth.assertThat;
 | 
			
		||||
import static org.junit.Assert.assertThrows;
 | 
			
		||||
 | 
			
		||||
import androidx.test.core.app.ApplicationProvider;
 | 
			
		||||
import androidx.test.ext.junit.runners.AndroidJUnit4;
 | 
			
		||||
import com.google.mediapipe.framework.MediaPipeException;
 | 
			
		||||
import org.junit.Test;
 | 
			
		||||
import org.junit.runner.RunWith;
 | 
			
		||||
 | 
			
		||||
/** Test for {@link TextEmbedder}/ */
 | 
			
		||||
@RunWith(AndroidJUnit4.class)
 | 
			
		||||
public class TextEmbedderTest {
 | 
			
		||||
  private static final String BERT_MODEL_FILE = "mobilebert_embedding_with_metadata.tflite";
 | 
			
		||||
  private static final String REGEX_MODEL_FILE = "regex_one_embedding_with_metadata.tflite";
 | 
			
		||||
 | 
			
		||||
  private static final double DOUBLE_DIFF_TOLERANCE = 1e-4;
 | 
			
		||||
  private static final float FLOAT_DIFF_TOLERANCE = 1e-4f;
 | 
			
		||||
 | 
			
		||||
  @Test
 | 
			
		||||
  public void create_failsWithMissingModel() throws Exception {
 | 
			
		||||
    String nonExistentFile = "/path/to/non/existent/file";
 | 
			
		||||
    MediaPipeException exception =
 | 
			
		||||
        assertThrows(
 | 
			
		||||
            MediaPipeException.class,
 | 
			
		||||
            () ->
 | 
			
		||||
                TextEmbedder.createFromFile(
 | 
			
		||||
                    ApplicationProvider.getApplicationContext(), nonExistentFile));
 | 
			
		||||
    assertThat(exception).hasMessageThat().contains(nonExistentFile);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  @Test
 | 
			
		||||
  public void embed_succeedsWithBert() throws Exception {
 | 
			
		||||
    TextEmbedder textEmbedder =
 | 
			
		||||
        TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE);
 | 
			
		||||
 | 
			
		||||
    TextEmbedderResult result0 = textEmbedder.embed("it's a charming and often affecting journey");
 | 
			
		||||
    assertThat(result0.embeddingResult().embeddings().size()).isEqualTo(1);
 | 
			
		||||
    assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(512);
 | 
			
		||||
    assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()[0])
 | 
			
		||||
        .isWithin(FLOAT_DIFF_TOLERANCE)
 | 
			
		||||
        .of(20.59746f);
 | 
			
		||||
    TextEmbedderResult result1 = textEmbedder.embed("what a great and fantastic trip");
 | 
			
		||||
    assertThat(result1.embeddingResult().embeddings().size()).isEqualTo(1);
 | 
			
		||||
    assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(512);
 | 
			
		||||
    assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()[0])
 | 
			
		||||
        .isWithin(FLOAT_DIFF_TOLERANCE)
 | 
			
		||||
        .of(21.774776f);
 | 
			
		||||
 | 
			
		||||
    // Check cosine similarity.
 | 
			
		||||
    double similarity =
 | 
			
		||||
        TextEmbedder.cosineSimilarity(
 | 
			
		||||
            result0.embeddingResult().embeddings().get(0),
 | 
			
		||||
            result1.embeddingResult().embeddings().get(0));
 | 
			
		||||
    assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.968879);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  @Test
 | 
			
		||||
  public void embed_succeedsWithRegex() throws Exception {
 | 
			
		||||
    TextEmbedder textEmbedder =
 | 
			
		||||
        TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), REGEX_MODEL_FILE);
 | 
			
		||||
 | 
			
		||||
    TextEmbedderResult result0 = textEmbedder.embed("it's a charming and often affecting journey");
 | 
			
		||||
    assertThat(result0.embeddingResult().embeddings().size()).isEqualTo(1);
 | 
			
		||||
    assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(16);
 | 
			
		||||
    assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()[0])
 | 
			
		||||
        .isWithin(FLOAT_DIFF_TOLERANCE)
 | 
			
		||||
        .of(0.030935612f);
 | 
			
		||||
    TextEmbedderResult result1 = textEmbedder.embed("what a great and fantastic trip");
 | 
			
		||||
    assertThat(result1.embeddingResult().embeddings().size()).isEqualTo(1);
 | 
			
		||||
    assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(16);
 | 
			
		||||
    assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()[0])
 | 
			
		||||
        .isWithin(FLOAT_DIFF_TOLERANCE)
 | 
			
		||||
        .of(0.0312863f);
 | 
			
		||||
 | 
			
		||||
    // Check cosine similarity.
 | 
			
		||||
    double similarity =
 | 
			
		||||
        TextEmbedder.cosineSimilarity(
 | 
			
		||||
            result0.embeddingResult().embeddings().get(0),
 | 
			
		||||
            result1.embeddingResult().embeddings().get(0));
 | 
			
		||||
    assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.999937);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user