diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD
index 5c41edbad..aa835894e 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD
@@ -89,3 +89,30 @@ filegroup(
srcs = glob(["*.java"]),
visibility = ["//mediapipe/tasks/java/com/google/mediapipe/tasks/core:__subpackages__"],
)
+
+android_library(
+ name = "embedding", # Builds Embedding.java, the per-head embedding container.
+ srcs = ["Embedding.java"],
+ javacopts = [
+ "-Xep:AndroidJdkLibsChecker:OFF", # Turns off the Error Prone AndroidJdkLibsChecker check.
+ ],
+ deps = [
+ "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
+ "//third_party:autovalue",
+ "@maven//:com_google_guava_guava", # NOTE(review): Embedding.java shows no Guava import; confirm this dep is needed.
+ ],
+)
+
+android_library(
+ name = "embeddingresult", # Builds EmbeddingResult.java, the container for a model's embeddings.
+ srcs = ["EmbeddingResult.java"],
+ javacopts = [
+ "-Xep:AndroidJdkLibsChecker:OFF", # Turns off the Error Prone AndroidJdkLibsChecker check.
+ ],
+ deps = [
+ ":embedding", # EmbeddingResult holds a list of Embedding objects.
+ "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
+ "//third_party:autovalue",
+ "@maven//:com_google_guava_guava", # NOTE(review): EmbeddingResult.java shows no Guava import; confirm this dep is needed.
+ ],
+)
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Embedding.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Embedding.java
new file mode 100644
index 000000000..6cb0e325b
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Embedding.java
@@ -0,0 +1,88 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.components.containers;
+
+import com.google.auto.value.AutoValue;
+import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
+import java.util.Optional;
+
+/**
+ * Represents the embedding for a given embedder head. Typically used in embedding tasks.
+ *
+ * <p>One and only one of the two 'floatEmbedding' and 'quantizedEmbedding' will contain data, based
+ * on whether or not the embedder was configured to perform scalar quantization.
+ */
+@AutoValue
+public abstract class Embedding {
+
+  /**
+   * Creates an {@link Embedding} instance.
+   *
+   * @param floatEmbedding the floating-point embedding.
+   * @param quantizedEmbedding the quantized embedding.
+   * @param headIndex the index of the embedder head.
+   * @param headName the optional name of the embedder head.
+   */
+  public static Embedding create(
+      float[] floatEmbedding, byte[] quantizedEmbedding, int headIndex, Optional<String> headName) {
+    return new AutoValue_Embedding(floatEmbedding, quantizedEmbedding, headIndex, headName);
+  }
+
+  /**
+   * Creates an {@link Embedding} object from an {@link EmbeddingsProto.Embedding} protobuf message.
+   *
+   * @param proto the {@link EmbeddingsProto.Embedding} protobuf message to convert.
+   */
+  public static Embedding createFromProto(EmbeddingsProto.Embedding proto) {
+    float[] floatEmbedding;
+    if (proto.hasFloatEmbedding()) {
+      floatEmbedding = new float[proto.getFloatEmbedding().getValuesCount()];
+      for (int i = 0; i < floatEmbedding.length; i++) {
+        floatEmbedding[i] = proto.getFloatEmbedding().getValues(i);
+      }
+    } else {
+      floatEmbedding = new float[0];
+    }
+    return Embedding.create(
+        floatEmbedding,
+        proto.hasQuantizedEmbedding()
+            ? proto.getQuantizedEmbedding().getValues().toByteArray()
+            : new byte[0],
+        proto.getHeadIndex(),
+        proto.hasHeadName() ? Optional.of(proto.getHeadName()) : Optional.empty());
+  }
+
+  /**
+   * Floating-point embedding.
+   *
+   * <p>Empty if the embedder was configured to perform scalar quantization.
+   */
+  public abstract float[] floatEmbedding();
+
+  /**
+   * Quantized embedding.
+   *
+   * <p>Empty if the embedder was not configured to perform scalar quantization.
+   */
+  public abstract byte[] quantizedEmbedding();
+
+  /**
+   * The index of the embedder head these entries refer to. This is useful for multi-head models.
+   */
+  public abstract int headIndex();
+
+  /** The optional name of the embedder head, which is the corresponding tensor metadata name. */
+  public abstract Optional<String> headName();
+}
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/EmbeddingResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/EmbeddingResult.java
new file mode 100644
index 000000000..32fa7b47d
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/EmbeddingResult.java
@@ -0,0 +1,69 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.components.containers;
+
+import com.google.auto.value.AutoValue;
+import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+
+/** Represents the embedding results of a model. Typically used as a result for embedding tasks. */
+@AutoValue
+public abstract class EmbeddingResult {
+
+  /**
+   * Creates a {@link EmbeddingResult} instance.
+   *
+   * @param embeddings the list of {@link Embedding} objects containing the embedding for each head
+   *     of the model. The list is copied, so later changes to it do not affect the result.
+   * @param timestampMs the optional timestamp (in milliseconds) of the start of the chunk of data
+   *     corresponding to these results.
+   */
+  public static EmbeddingResult create(List<Embedding> embeddings, Optional<Long> timestampMs) {
+    return new AutoValue_EmbeddingResult(
+        Collections.unmodifiableList(new ArrayList<>(embeddings)), timestampMs);
+  }
+
+  /**
+   * Creates an {@link EmbeddingResult} from an {@link EmbeddingsProto.EmbeddingResult} message.
+   *
+   * @param proto the {@link EmbeddingsProto.EmbeddingResult} protobuf message to convert.
+   */
+  public static EmbeddingResult createFromProto(EmbeddingsProto.EmbeddingResult proto) {
+    List<Embedding> embeddings = new ArrayList<>();
+    for (EmbeddingsProto.Embedding embeddingProto : proto.getEmbeddingsList()) {
+      embeddings.add(Embedding.createFromProto(embeddingProto));
+    }
+    Optional<Long> timestampMs =
+        proto.hasTimestampMs() ? Optional.of(proto.getTimestampMs()) : Optional.empty();
+    return create(embeddings, timestampMs);
+  }
+
+  /** The embedding results for each head of the model, as an unmodifiable list. */
+  public abstract List<Embedding> embeddings();
+
+  /**
+   * The optional timestamp (in milliseconds) of the start of the chunk of data corresponding to
+   * these results.
+   *
+   * <p>This is only used for embedding extraction on time series (e.g. audio embedder). In these
+   * use cases, the amount of data to process might exceed the maximum size that the model can
+   * process: to solve this, the input data is split into multiple chunks starting at different
+   * timestamps.
+   */
+  public abstract Optional<Long> timestampMs();
+}
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD
index 1f99f1612..e61e59390 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD
@@ -29,6 +29,19 @@ android_library(
],
)
+android_library(
+ name = "embedderoptions", # Builds EmbedderOptions.java, the options shared by embedding tasks.
+ srcs = ["EmbedderOptions.java"],
+ javacopts = [
+ "-Xep:AndroidJdkLibsChecker:OFF", # Turns off the Error Prone AndroidJdkLibsChecker check.
+ ],
+ deps = [
+ "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite",
+ "//third_party:autovalue",
+ "@maven//:com_google_guava_guava", # NOTE(review): EmbedderOptions.java shows no Guava import; confirm this dep is needed.
+ ],
+)
+
# Expose the java source files for building mediapipe tasks core AAR.
filegroup(
name = "java_src",
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java
new file mode 100644
index 000000000..3cd197234
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java
@@ -0,0 +1,68 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.components.processors;
+
+import com.google.auto.value.AutoValue;
+import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto;
+
+/** Embedder options shared across MediaPipe Java embedding tasks. */
+@AutoValue
+public abstract class EmbedderOptions {
+
+  /** Builder for {@link EmbedderOptions}. */
+  @AutoValue.Builder
+  public abstract static class Builder {
+    /**
+     * Sets whether L2 normalization should be performed on the returned embeddings. Use this option
+     * only if the model does not already contain a native {@code L2_NORMALIZATION} TF Lite Op. In
+     * most cases, this is already the case and L2 norm is thus achieved through TF Lite inference.
+     *
+     * <p>False by default.
+     */
+    public abstract Builder setL2Normalize(boolean l2Normalize);
+
+    /**
+     * Sets whether the returned embedding should be quantized to bytes via scalar quantization.
+     * Embeddings are implicitly assumed to be unit-norm and therefore any dimension is guaranteed
+     * to have value in {@code [-1.0, 1.0]}. Use {@link #setL2Normalize(boolean)} if this is not the
+     * case.
+     *
+     * <p>False by default.
+     */
+    public abstract Builder setQuantize(boolean quantize);
+
+    /** Builds the {@link EmbedderOptions} instance. */
+    public abstract EmbedderOptions build();
+  }
+
+  /** Whether L2 normalization is performed on the returned embeddings. */
+  public abstract boolean l2Normalize();
+
+  /** Whether the returned embeddings are quantized to bytes via scalar quantization. */
+  public abstract boolean quantize();
+
+  /** Creates a builder with both L2 normalization and quantization initially set to false. */
+  public static Builder builder() {
+    return new AutoValue_EmbedderOptions.Builder().setL2Normalize(false).setQuantize(false);
+  }
+
+  /** Converts this object to an {@link EmbedderOptionsProto.EmbedderOptions} protobuf message. */
+  public EmbedderOptionsProto.EmbedderOptions convertToProto() {
+    return EmbedderOptionsProto.EmbedderOptions.newBuilder()
+        .setL2Normalize(l2Normalize())
+        .setQuantize(quantize())
+        .build();
+  }
+}
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD
new file mode 100644
index 000000000..cd2bbafc8
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD
@@ -0,0 +1,26 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+android_library(
+ name = "cosinesimilarity", # Builds CosineSimilarity.java, the cosine-similarity utility.
+ srcs = ["CosineSimilarity.java"],
+ deps = [
+ "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", # Provides the Embedding class.
+ "@maven//:com_google_guava_guava", # NOTE(review): CosineSimilarity.java shows no Guava import; confirm this dep is needed.
+ ],
+)
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/CosineSimilarity.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/CosineSimilarity.java
new file mode 100644
index 000000000..6b995731f
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/CosineSimilarity.java
@@ -0,0 +1,88 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.components.utils;
+
+import com.google.mediapipe.tasks.components.containers.Embedding;
+
+/** Utility class for computing cosine similarity between {@link Embedding} objects. */
+public final class CosineSimilarity {
+  // Non-instantiable class.
+  private CosineSimilarity() {}
+
+  /**
+   * Computes cosine similarity between two {@link Embedding} objects.
+   *
+   * @throws IllegalArgumentException if the embeddings are of different types (float vs.
+   *     quantized), have different sizes, or have an L2-norm of 0.
+   */
+  public static double compute(Embedding u, Embedding v) {
+    if (u.floatEmbedding().length > 0 && v.floatEmbedding().length > 0) {
+      return computeFloat(u.floatEmbedding(), v.floatEmbedding());
+    }
+    if (u.quantizedEmbedding().length > 0 && v.quantizedEmbedding().length > 0) {
+      return computeQuantized(u.quantizedEmbedding(), v.quantizedEmbedding());
+    }
+    throw new IllegalArgumentException(
+        "Cannot compute cosine similarity between quantized and float embeddings.");
+  }
+
+  /** Float variant; widens to double so that products can neither overflow nor underflow. */
+  private static double computeFloat(float[] u, float[] v) {
+    checkSameSize(u.length, v.length);
+    double dotProduct = 0.0;
+    double normU = 0.0;
+    double normV = 0.0;
+    for (int i = 0; i < u.length; i++) {
+      dotProduct += (double) u[i] * v[i];
+      normU += (double) u[i] * u[i];
+      normV += (double) v[i] * v[i];
+    }
+    return cosine(dotProduct, normU, normV);
+  }
+
+  /** Quantized variant; byte operands are promoted to int, so each product is exact. */
+  private static double computeQuantized(byte[] u, byte[] v) {
+    checkSameSize(u.length, v.length);
+    double dotProduct = 0.0;
+    double normU = 0.0;
+    double normV = 0.0;
+    for (int i = 0; i < u.length; i++) {
+      dotProduct += u[i] * v[i];
+      normU += u[i] * u[i];
+      normV += v[i] * v[i];
+    }
+    return cosine(dotProduct, normU, normV);
+  }
+
+  /** Throws {@link IllegalArgumentException} if the two embedding sizes differ. */
+  private static void checkSameSize(int sizeU, int sizeV) {
+    if (sizeU != sizeV) {
+      throw new IllegalArgumentException(
+          String.format(
+              "Cannot compute cosine similarity between embeddings of different sizes (%d vs."
+                  + " %d).",
+              sizeU, sizeV));
+    }
+  }
+
+  /** Divides the dot product by the product of the L2-norms, rejecting zero-norm embeddings. */
+  private static double cosine(double dotProduct, double normU, double normV) {
+    if (normU <= 0 || normV <= 0) {
+      throw new IllegalArgumentException(
+          "Cannot compute cosine similarity on embedding with 0 norm.");
+    }
+    return dotProduct / Math.sqrt(normU * normV);
+  }
+}
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/utils/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/utils/AndroidManifest.xml
new file mode 100644
index 000000000..d0c1ff91f
--- /dev/null
+++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/utils/AndroidManifest.xml
@@ -0,0 +1,24 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/utils/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/utils/BUILD
new file mode 100644
index 000000000..a7f804c64
--- /dev/null
+++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/utils/BUILD
@@ -0,0 +1,19 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+# TODO: Enable this in OSS
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/utils/CosineSimilarityTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/utils/CosineSimilarityTest.java
new file mode 100644
index 000000000..f7a1ae002
--- /dev/null
+++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/utils/CosineSimilarityTest.java
@@ -0,0 +1,116 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.components.utils;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertThrows;
+
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import com.google.mediapipe.tasks.components.containers.Embedding;
+import java.util.Optional;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+/** Tests for {@link CosineSimilarity}. */
+@RunWith(AndroidJUnit4.class)
+public final class CosineSimilarityTest {
+
+  @Test
+  public void failsWithQuantizedAndFloatEmbeddings() {
+    // One operand holds only float data, the other only quantized data.
+    Embedding lhs =
+        Embedding.create(
+            new float[] {1.0f}, new byte[0], /*headIndex=*/ 0, /*headName=*/ Optional.empty());
+    Embedding rhs =
+        Embedding.create(
+            new float[0], new byte[] {1}, /*headIndex=*/ 0, /*headName=*/ Optional.empty());
+    String expected = "Cannot compute cosine similarity between quantized and float embeddings";
+
+    IllegalArgumentException thrown =
+        assertThrows(IllegalArgumentException.class, () -> CosineSimilarity.compute(lhs, rhs));
+    assertThat(thrown).hasMessageThat().contains(expected);
+  }
+
+  @Test
+  public void failsWithZeroNorm() {
+    // A single zero component gives the embedding an L2-norm of 0.
+    Embedding zero =
+        Embedding.create(
+            new float[] {0.0f}, new byte[0], /*headIndex=*/ 0, /*headName=*/ Optional.empty());
+    String expected = "Cannot compute cosine similarity on embedding with 0 norm";
+
+    IllegalArgumentException thrown =
+        assertThrows(IllegalArgumentException.class, () -> CosineSimilarity.compute(zero, zero));
+    assertThat(thrown).hasMessageThat().contains(expected);
+  }
+
+  @Test
+  public void failsWithDifferentSizes() {
+    // Two float embeddings with 2 and 3 components respectively.
+    Embedding lhs =
+        Embedding.create(
+            new float[] {1.0f, 2.0f},
+            new byte[0],
+            /*headIndex=*/ 0,
+            /*headName=*/ Optional.empty());
+    Embedding rhs =
+        Embedding.create(
+            new float[] {1.0f, 2.0f, 3.0f},
+            new byte[0],
+            /*headIndex=*/ 0,
+            /*headName=*/ Optional.empty());
+    String expected = "Cannot compute cosine similarity between embeddings of different sizes";
+
+    IllegalArgumentException thrown =
+        assertThrows(IllegalArgumentException.class, () -> CosineSimilarity.compute(lhs, rhs));
+    assertThat(thrown).hasMessageThat().contains(expected);
+  }
+
+  @Test
+  public void succeedsWithFloatEmbeddings() {
+    Embedding lhs =
+        Embedding.create(
+            new float[] {1.0f, 0.0f, 0.0f, 0.0f},
+            new byte[0],
+            /*headIndex=*/ 0,
+            /*headName=*/ Optional.empty());
+    Embedding rhs =
+        Embedding.create(
+            new float[] {0.5f, 0.5f, 0.5f, 0.5f},
+            new byte[0],
+            /*headIndex=*/ 0,
+            /*headName=*/ Optional.empty());
+    double similarity = CosineSimilarity.compute(lhs, rhs);
+    assertThat(similarity).isEqualTo(0.5);
+  }
+
+  @Test
+  public void succeedsWithQuantizedEmbeddings() {
+    Embedding lhs =
+        Embedding.create(
+            new float[0],
+            new byte[] {127, 0, 0, 0},
+            /*headIndex=*/ 0,
+            /*headName=*/ Optional.empty());
+    Embedding rhs =
+        Embedding.create(
+            new float[0],
+            new byte[] {-128, 0, 0, 0},
+            /*headIndex=*/ 0,
+            /*headName=*/ Optional.empty());
+    double similarity = CosineSimilarity.compute(lhs, rhs);
+    assertThat(similarity).isEqualTo(-1.0);
+  }
+}