diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index 5c41edbad..aa835894e 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -89,3 +89,30 @@ filegroup( srcs = glob(["*.java"]), visibility = ["//mediapipe/tasks/java/com/google/mediapipe/tasks/core:__subpackages__"], ) + +android_library( + name = "embedding", + srcs = ["Embedding.java"], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + deps = [ + "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + +android_library( + name = "embeddingresult", + srcs = ["EmbeddingResult.java"], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + deps = [ + ":embedding", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Embedding.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Embedding.java new file mode 100644 index 000000000..6cb0e325b --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Embedding.java @@ -0,0 +1,88 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.components.containers; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; +import java.util.Optional; + +/** + * Represents the embedding for a given embedder head. Typically used in embedding tasks. + * + *

One and only one of the two 'floatEmbedding' and 'quantizedEmbedding' will contain data, based + * on whether or not the embedder was configured to perform scala quantization. + */ +@AutoValue +public abstract class Embedding { + + /** + * Creates an {@link Embedding} instance. + * + * @param floatEmbedding the floating-point embedding + * @param quantizedEmbedding the quantized embedding. + * @param headIndex the index of the embedder head. + * @param headName the optional name of the embedder head. + */ + public static Embedding create( + float[] floatEmbedding, byte[] quantizedEmbedding, int headIndex, Optional headName) { + return new AutoValue_Embedding(floatEmbedding, quantizedEmbedding, headIndex, headName); + } + + /** + * Creates an {@link Embedding} object from an {@link EmbeddingsProto.Embedding} protobuf message. + * + * @param proto the {@link EmbeddingsProto.Embedding} protobuf message to convert. + */ + public static Embedding createFromProto(EmbeddingsProto.Embedding proto) { + float[] floatEmbedding; + if (proto.hasFloatEmbedding()) { + floatEmbedding = new float[proto.getFloatEmbedding().getValuesCount()]; + for (int i = 0; i < floatEmbedding.length; i++) { + floatEmbedding[i] = proto.getFloatEmbedding().getValues(i); + } + } else { + floatEmbedding = new float[0]; + } + return Embedding.create( + floatEmbedding, + proto.hasQuantizedEmbedding() + ? proto.getQuantizedEmbedding().getValues().toByteArray() + : new byte[0], + proto.getHeadIndex(), + proto.hasHeadName() ? Optional.of(proto.getHeadName()) : Optional.empty()); + } + + /** + * Floating-point embedding. + * + *

Empty if the embedder was configured to perform scalar quantization. + */ + public abstract float[] floatEmbedding(); + + /** + * Quantized embedding. + * + *

Empty if the embedder was not configured to perform scalar quantization. + */ + public abstract byte[] quantizedEmbedding(); + + /** + * The index of the embedder head these entries refer to. This is useful for multi-head models. + */ + public abstract int headIndex(); + + /** The optional name of the embedder head, which is the corresponding tensor metadata name. */ + public abstract Optional headName(); +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/EmbeddingResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/EmbeddingResult.java new file mode 100644 index 000000000..32fa7b47d --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/EmbeddingResult.java @@ -0,0 +1,69 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.components.containers; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** Represents the embedding results of a model. Typically used as a result for embedding tasks. */ +@AutoValue +public abstract class EmbeddingResult { + + /** + * Creates a {@link EmbeddingResult} instance. + * + * @param embeddings the list of {@link Embedding} objects containing the embedding for each head + * of the model. + * @param timestampMs the optional timestamp (in milliseconds) of the start of the chunk of data + * corresponding to these results. + */ + public static EmbeddingResult create(List embeddings, Optional timestampMs) { + return new AutoValue_EmbeddingResult(Collections.unmodifiableList(embeddings), timestampMs); + } + + /** + * Creates a {@link EmbeddingResult} object from a {@link EmbeddingsProto.EmbeddingResult} + * protobuf message. + * + * @param proto the {@link EmbeddingsProto.EmbeddingResult} protobuf message to convert. + */ + public static EmbeddingResult createFromProto(EmbeddingsProto.EmbeddingResult proto) { + List embeddings = new ArrayList<>(); + for (EmbeddingsProto.Embedding embeddingProto : proto.getEmbeddingsList()) { + embeddings.add(Embedding.createFromProto(embeddingProto)); + } + Optional timestampMs = + proto.hasTimestampMs() ? Optional.of(proto.getTimestampMs()) : Optional.empty(); + return create(embeddings, timestampMs); + } + + /** The embedding results for each head of the model. */ + public abstract List embeddings(); + + /** + * The optional timestamp (in milliseconds) of the start of the chunk of data corresponding to + * these results. + * + *

This is only used for embedding extraction on time series (e.g. audio embedder). In these + * use cases, the amount of data to process might exceed the maximum size that the model can + * process: to solve this, the input data is split into multiple chunks starting at different + * timestamps. + */ + public abstract Optional timestampMs(); +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD index 1f99f1612..e61e59390 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD @@ -29,6 +29,19 @@ android_library( ], ) +android_library( + name = "embedderoptions", + srcs = ["EmbedderOptions.java"], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + deps = [ + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + # Expose the java source files for building mediapipe tasks core AAR. filegroup( name = "java_src", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java new file mode 100644 index 000000000..3cd197234 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java @@ -0,0 +1,68 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.components.processors; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto; + +/** Embedder options shared across MediaPipe Java embedding tasks. */ +@AutoValue +public abstract class EmbedderOptions { + + /** Builder for {@link EmbedderOptions} */ + @AutoValue.Builder + public abstract static class Builder { + /** + * Sets whether L2 normalization should be performed on the returned embeddings. Use this option + * only if the model does not already contain a native L2_NORMALIZATION TF Lite Op. + * In most cases, this is already the case and L2 norm is thus achieved through TF Lite + * inference. + * + *

False by default. + */ + public abstract Builder setL2Normalize(boolean l2Normalize); + + /** + * Sets whether the returned embedding should be quantized to bytes via scalar quantization. + * Embeddings are implicitly assumed to be unit-norm and therefore any dimensions is guaranteed + * to have value in [-1.0, 1.0]. Use {@link #setL2Normalize(boolean)} if this is + * not the case. + * + *

False by default. + */ + public abstract Builder setQuantize(boolean quantize); + + public abstract EmbedderOptions build(); + } + + public abstract boolean l2Normalize(); + + public abstract boolean quantize(); + + public static Builder builder() { + return new AutoValue_EmbedderOptions.Builder().setL2Normalize(false).setQuantize(false); + } + + /** + * Converts an {@link EmbedderOptions} object to an {@link EmbedderOptionsProto.EmbedderOptions} + * protobuf message. + */ + public EmbedderOptionsProto.EmbedderOptions convertToProto() { + return EmbedderOptionsProto.EmbedderOptions.newBuilder() + .setL2Normalize(l2Normalize()) + .setQuantize(quantize()) + .build(); + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD new file mode 100644 index 000000000..cd2bbafc8 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/BUILD @@ -0,0 +1,26 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +android_library( + name = "cosinesimilarity", + srcs = ["CosineSimilarity.java"], + deps = [ + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", + "@maven//:com_google_guava_guava", + ], +) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/CosineSimilarity.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/CosineSimilarity.java new file mode 100644 index 000000000..6b995731f --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils/CosineSimilarity.java @@ -0,0 +1,88 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.components.utils; + +import com.google.mediapipe.tasks.components.containers.Embedding; + +/** Utility class for computing cosine similarity between {@link Embedding} objects. */ +public class CosineSimilarity { + + // Non-instantiable class. + private CosineSimilarity() {} + + /** + * Computes cosine similarity + * between two {@link Embedding} objects. + * + * @throws IllegalArgumentException if the embeddings are of different types (float vs. + * quantized), have different sizes, or have an L2-norm of 0. + */ + public static double compute(Embedding u, Embedding v) { + if (u.floatEmbedding().length > 0 && v.floatEmbedding().length > 0) { + return computeFloat(u.floatEmbedding(), v.floatEmbedding()); + } + if (u.quantizedEmbedding().length > 0 && v.quantizedEmbedding().length > 0) { + return computeQuantized(u.quantizedEmbedding(), v.quantizedEmbedding()); + } + throw new IllegalArgumentException( + "Cannot compute cosine similarity between quantized and float embeddings."); + } + + private static double computeFloat(float[] u, float[] v) { + if (u.length != v.length) { + throw new IllegalArgumentException( + String.format( + "Cannot compute cosine similarity between embeddings of different sizes (%d vs." + + " %d).", + u.length, v.length)); + } + double dotProduct = 0.0; + double normU = 0.0; + double normV = 0.0; + for (int i = 0; i < u.length; i++) { + dotProduct += u[i] * v[i]; + normU += u[i] * u[i]; + normV += v[i] * v[i]; + } + if (normU <= 0 || normV <= 0) { + throw new IllegalArgumentException( + "Cannot compute cosine similarity on embedding with 0 norm."); + } + return dotProduct / Math.sqrt(normU * normV); + } + + private static double computeQuantized(byte[] u, byte[] v) { + if (u.length != v.length) { + throw new IllegalArgumentException( + String.format( + "Cannot compute cosine similarity between embeddings of different sizes (%d vs." + + " %d).", + u.length, v.length)); + } + double dotProduct = 0.0; + double normU = 0.0; + double normV = 0.0; + for (int i = 0; i < u.length; i++) { + dotProduct += u[i] * v[i]; + normU += u[i] * u[i]; + normV += v[i] * v[i]; + } + if (normU <= 0 || normV <= 0) { + throw new IllegalArgumentException( + "Cannot compute cosine similarity on embedding with 0 norm."); + } + return dotProduct / Math.sqrt(normU * normV); + } +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/utils/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/utils/AndroidManifest.xml new file mode 100644 index 000000000..d0c1ff91f --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/utils/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/utils/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/utils/BUILD new file mode 100644 index 000000000..a7f804c64 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/utils/BUILD @@ -0,0 +1,19 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/utils/CosineSimilarityTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/utils/CosineSimilarityTest.java new file mode 100644 index 000000000..f7a1ae002 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/utils/CosineSimilarityTest.java @@ -0,0 +1,116 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.components.utils; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.mediapipe.tasks.components.containers.Embedding; +import java.util.Optional; +import org.junit.Test; +import org.junit.runner.RunWith; + +/** Tests for {@link CosineSimilarity}. */ +@RunWith(AndroidJUnit4.class) +public final class CosineSimilarityTest { + + @Test + public void failsWithQuantizedAndFloatEmbeddings() { + Embedding u = + Embedding.create( + new float[] {1.0f}, new byte[0], /*headIndex=*/ 0, /*headName=*/ Optional.empty()); + Embedding v = + Embedding.create( + new float[0], new byte[] {1}, /*headIndex=*/ 0, /*headName=*/ Optional.empty()); + + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> CosineSimilarity.compute(u, v)); + assertThat(exception) + .hasMessageThat() + .contains("Cannot compute cosine similarity between quantized and float embeddings"); + } + + @Test + public void failsWithZeroNorm() { + Embedding u = + Embedding.create( + new float[] {0.0f}, new byte[0], /*headIndex=*/ 0, /*headName=*/ Optional.empty()); + + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> CosineSimilarity.compute(u, u)); + assertThat(exception) + .hasMessageThat() + .contains("Cannot compute cosine similarity on embedding with 0 norm"); + } + + @Test + public void failsWithDifferentSizes() { + Embedding u = + Embedding.create( + new float[] {1.0f, 2.0f}, + new byte[0], + /*headIndex=*/ 0, + /*headName=*/ Optional.empty()); + Embedding v = + Embedding.create( + new float[] {1.0f, 2.0f, 3.0f}, + new byte[0], + /*headIndex=*/ 0, + /*headName=*/ Optional.empty()); + + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> CosineSimilarity.compute(u, v)); + assertThat(exception) + .hasMessageThat() + .contains("Cannot compute cosine similarity between embeddings of different sizes"); + } + + @Test + public void succeedsWithFloatEmbeddings() { + Embedding u = + Embedding.create( + new float[] {1.0f, 0.0f, 0.0f, 0.0f}, + new byte[0], + /*headIndex=*/ 0, + /*headName=*/ Optional.empty()); + Embedding v = + Embedding.create( + new float[] {0.5f, 0.5f, 0.5f, 0.5f}, + new byte[0], + /*headIndex=*/ 0, + /*headName=*/ Optional.empty()); + + assertThat(CosineSimilarity.compute(u, v)).isEqualTo(0.5); + } + + @Test + public void succeedsWithQuantizedEmbeddings() { + Embedding u = + Embedding.create( + new float[0], + new byte[] {127, 0, 0, 0}, + /*headIndex=*/ 0, + /*headName=*/ Optional.empty()); + Embedding v = + Embedding.create( + new float[0], + new byte[] {-128, 0, 0, 0}, + /*headIndex=*/ 0, + /*headName=*/ Optional.empty()); + + assertThat(CosineSimilarity.compute(u, v)).isEqualTo(-1.0); + } +}