Add option, result types and utils for Java embedders.

PiperOrigin-RevId: 487615327
This commit is contained in:
MediaPipe Team 2022-11-10 12:51:34 -08:00 committed by Copybara-Service
parent aeb2466844
commit 8ec4427bd7
10 changed files with 538 additions and 0 deletions

View File

@ -89,3 +89,30 @@ filegroup(
srcs = glob(["*.java"]),
visibility = ["//mediapipe/tasks/java/com/google/mediapipe/tasks/core:__subpackages__"],
)
android_library(
name = "embedding",
srcs = ["Embedding.java"],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
deps = [
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)
android_library(
name = "embeddingresult",
srcs = ["EmbeddingResult.java"],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
deps = [
":embedding",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)

View File

@ -0,0 +1,88 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.containers;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
import java.util.Optional;
/**
* Represents the embedding for a given embedder head. Typically used in embedding tasks.
*
* <p>One and only one of the two 'floatEmbedding' and 'quantizedEmbedding' will contain data, based
* on whether or not the embedder was configured to perform scala quantization.
*/
@AutoValue
public abstract class Embedding {
/**
* Creates an {@link Embedding} instance.
*
* @param floatEmbedding the floating-point embedding
* @param quantizedEmbedding the quantized embedding.
* @param headIndex the index of the embedder head.
* @param headName the optional name of the embedder head.
*/
public static Embedding create(
float[] floatEmbedding, byte[] quantizedEmbedding, int headIndex, Optional<String> headName) {
return new AutoValue_Embedding(floatEmbedding, quantizedEmbedding, headIndex, headName);
}
/**
* Creates an {@link Embedding} object from an {@link EmbeddingsProto.Embedding} protobuf message.
*
* @param proto the {@link EmbeddingsProto.Embedding} protobuf message to convert.
*/
public static Embedding createFromProto(EmbeddingsProto.Embedding proto) {
float[] floatEmbedding;
if (proto.hasFloatEmbedding()) {
floatEmbedding = new float[proto.getFloatEmbedding().getValuesCount()];
for (int i = 0; i < floatEmbedding.length; i++) {
floatEmbedding[i] = proto.getFloatEmbedding().getValues(i);
}
} else {
floatEmbedding = new float[0];
}
return Embedding.create(
floatEmbedding,
proto.hasQuantizedEmbedding()
? proto.getQuantizedEmbedding().getValues().toByteArray()
: new byte[0],
proto.getHeadIndex(),
proto.hasHeadName() ? Optional.of(proto.getHeadName()) : Optional.empty());
}
/**
* Floating-point embedding.
*
* <p>Empty if the embedder was configured to perform scalar quantization.
*/
public abstract float[] floatEmbedding();
/**
* Quantized embedding.
*
* <p>Empty if the embedder was not configured to perform scalar quantization.
*/
public abstract byte[] quantizedEmbedding();
/**
* The index of the embedder head these entries refer to. This is useful for multi-head models.
*/
public abstract int headIndex();
/** The optional name of the embedder head, which is the corresponding tensor metadata name. */
public abstract Optional<String> headName();
}

View File

@ -0,0 +1,69 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.containers;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
/** Represents the embedding results of a model. Typically used as a result for embedding tasks. */
@AutoValue
public abstract class EmbeddingResult {
/**
* Creates a {@link EmbeddingResult} instance.
*
* @param embeddings the list of {@link Embedding} objects containing the embedding for each head
* of the model.
* @param timestampMs the optional timestamp (in milliseconds) of the start of the chunk of data
* corresponding to these results.
*/
public static EmbeddingResult create(List<Embedding> embeddings, Optional<Long> timestampMs) {
return new AutoValue_EmbeddingResult(Collections.unmodifiableList(embeddings), timestampMs);
}
/**
* Creates a {@link EmbeddingResult} object from a {@link EmbeddingsProto.EmbeddingResult}
* protobuf message.
*
* @param proto the {@link EmbeddingsProto.EmbeddingResult} protobuf message to convert.
*/
public static EmbeddingResult createFromProto(EmbeddingsProto.EmbeddingResult proto) {
List<Embedding> embeddings = new ArrayList<>();
for (EmbeddingsProto.Embedding embeddingProto : proto.getEmbeddingsList()) {
embeddings.add(Embedding.createFromProto(embeddingProto));
}
Optional<Long> timestampMs =
proto.hasTimestampMs() ? Optional.of(proto.getTimestampMs()) : Optional.empty();
return create(embeddings, timestampMs);
}
/** The embedding results for each head of the model. */
public abstract List<Embedding> embeddings();
/**
* The optional timestamp (in milliseconds) of the start of the chunk of data corresponding to
* these results.
*
* <p>This is only used for embedding extraction on time series (e.g. audio embedder). In these
* use cases, the amount of data to process might exceed the maximum size that the model can
* process: to solve this, the input data is split into multiple chunks starting at different
* timestamps.
*/
public abstract Optional<Long> timestampMs();
}

View File

@ -29,6 +29,19 @@ android_library(
],
)
android_library(
name = "embedderoptions",
srcs = ["EmbedderOptions.java"],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
deps = [
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)
# Expose the java source files for building mediapipe tasks core AAR.
filegroup(
name = "java_src",

View File

@ -0,0 +1,68 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.processors;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto;
/** Embedder options shared across MediaPipe Java embedding tasks. */
@AutoValue
public abstract class EmbedderOptions {
/** Builder for {@link EmbedderOptions} */
@AutoValue.Builder
public abstract static class Builder {
/**
* Sets whether L2 normalization should be performed on the returned embeddings. Use this option
* only if the model does not already contain a native <code>L2_NORMALIZATION</code> TF Lite Op.
* In most cases, this is already the case and L2 norm is thus achieved through TF Lite
* inference.
*
* <p>False by default.
*/
public abstract Builder setL2Normalize(boolean l2Normalize);
/**
* Sets whether the returned embedding should be quantized to bytes via scalar quantization.
* Embeddings are implicitly assumed to be unit-norm and therefore any dimensions is guaranteed
* to have value in <code>[-1.0, 1.0]</code>. Use {@link #setL2Normalize(boolean)} if this is
* not the case.
*
* <p>False by default.
*/
public abstract Builder setQuantize(boolean quantize);
public abstract EmbedderOptions build();
}
public abstract boolean l2Normalize();
public abstract boolean quantize();
public static Builder builder() {
return new AutoValue_EmbedderOptions.Builder().setL2Normalize(false).setQuantize(false);
}
/**
* Converts an {@link EmbedderOptions} object to an {@link EmbedderOptionsProto.EmbedderOptions}
* protobuf message.
*/
public EmbedderOptionsProto.EmbedderOptions convertToProto() {
return EmbedderOptionsProto.EmbedderOptions.newBuilder()
.setL2Normalize(l2Normalize())
.setQuantize(quantize())
.build();
}
}

View File

@ -0,0 +1,26 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
android_library(
name = "cosinesimilarity",
srcs = ["CosineSimilarity.java"],
deps = [
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding",
"@maven//:com_google_guava_guava",
],
)

View File

@ -0,0 +1,88 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.utils;
import com.google.mediapipe.tasks.components.containers.Embedding;
/** Utility class for computing cosine similarity between {@link Embedding} objects. */
public class CosineSimilarity {
// Non-instantiable class.
private CosineSimilarity() {}
/**
* Computes <a href="https://en.wikipedia.org/wiki/Cosine_similarity">cosine similarity</a>
* between two {@link Embedding} objects.
*
* @throws IllegalArgumentException if the embeddings are of different types (float vs.
* quantized), have different sizes, or have an L2-norm of 0.
*/
public static double compute(Embedding u, Embedding v) {
if (u.floatEmbedding().length > 0 && v.floatEmbedding().length > 0) {
return computeFloat(u.floatEmbedding(), v.floatEmbedding());
}
if (u.quantizedEmbedding().length > 0 && v.quantizedEmbedding().length > 0) {
return computeQuantized(u.quantizedEmbedding(), v.quantizedEmbedding());
}
throw new IllegalArgumentException(
"Cannot compute cosine similarity between quantized and float embeddings.");
}
private static double computeFloat(float[] u, float[] v) {
if (u.length != v.length) {
throw new IllegalArgumentException(
String.format(
"Cannot compute cosine similarity between embeddings of different sizes (%d vs."
+ " %d).",
u.length, v.length));
}
double dotProduct = 0.0;
double normU = 0.0;
double normV = 0.0;
for (int i = 0; i < u.length; i++) {
dotProduct += u[i] * v[i];
normU += u[i] * u[i];
normV += v[i] * v[i];
}
if (normU <= 0 || normV <= 0) {
throw new IllegalArgumentException(
"Cannot compute cosine similarity on embedding with 0 norm.");
}
return dotProduct / Math.sqrt(normU * normV);
}
private static double computeQuantized(byte[] u, byte[] v) {
if (u.length != v.length) {
throw new IllegalArgumentException(
String.format(
"Cannot compute cosine similarity between embeddings of different sizes (%d vs."
+ " %d).",
u.length, v.length));
}
double dotProduct = 0.0;
double normU = 0.0;
double normV = 0.0;
for (int i = 0; i < u.length; i++) {
dotProduct += u[i] * v[i];
normU += u[i] * u[i];
normV += v[i] * v[i];
}
if (normU <= 0 || normV <= 0) {
throw new IllegalArgumentException(
"Cannot compute cosine similarity on embedding with 0 norm.");
}
return dotProduct / Math.sqrt(normU * normV);
}
}

View File

@ -0,0 +1,24 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.components.utilstest"
android:versionCode="1"
android:versionName="1.0" >
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
<application
android:label="utilstest"
android:name="android.support.multidex.MultiDexApplication"
android:taskAffinity="">
<uses-library android:name="android.test.runner" />
</application>
<instrumentation
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
android:targetPackage="com.google.mediapipe.tasks.components.utilstest" />
</manifest>

View File

@ -0,0 +1,19 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# TODO: Enable this in OSS

View File

@ -0,0 +1,116 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.utils;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.mediapipe.tasks.components.containers.Embedding;
import java.util.Optional;
import org.junit.Test;
import org.junit.runner.RunWith;
/** Tests for {@link CosineSimilarity}. */
@RunWith(AndroidJUnit4.class)
public final class CosineSimilarityTest {
@Test
public void failsWithQuantizedAndFloatEmbeddings() {
Embedding u =
Embedding.create(
new float[] {1.0f}, new byte[0], /*headIndex=*/ 0, /*headName=*/ Optional.empty());
Embedding v =
Embedding.create(
new float[0], new byte[] {1}, /*headIndex=*/ 0, /*headName=*/ Optional.empty());
IllegalArgumentException exception =
assertThrows(IllegalArgumentException.class, () -> CosineSimilarity.compute(u, v));
assertThat(exception)
.hasMessageThat()
.contains("Cannot compute cosine similarity between quantized and float embeddings");
}
@Test
public void failsWithZeroNorm() {
Embedding u =
Embedding.create(
new float[] {0.0f}, new byte[0], /*headIndex=*/ 0, /*headName=*/ Optional.empty());
IllegalArgumentException exception =
assertThrows(IllegalArgumentException.class, () -> CosineSimilarity.compute(u, u));
assertThat(exception)
.hasMessageThat()
.contains("Cannot compute cosine similarity on embedding with 0 norm");
}
@Test
public void failsWithDifferentSizes() {
Embedding u =
Embedding.create(
new float[] {1.0f, 2.0f},
new byte[0],
/*headIndex=*/ 0,
/*headName=*/ Optional.empty());
Embedding v =
Embedding.create(
new float[] {1.0f, 2.0f, 3.0f},
new byte[0],
/*headIndex=*/ 0,
/*headName=*/ Optional.empty());
IllegalArgumentException exception =
assertThrows(IllegalArgumentException.class, () -> CosineSimilarity.compute(u, v));
assertThat(exception)
.hasMessageThat()
.contains("Cannot compute cosine similarity between embeddings of different sizes");
}
@Test
public void succeedsWithFloatEmbeddings() {
Embedding u =
Embedding.create(
new float[] {1.0f, 0.0f, 0.0f, 0.0f},
new byte[0],
/*headIndex=*/ 0,
/*headName=*/ Optional.empty());
Embedding v =
Embedding.create(
new float[] {0.5f, 0.5f, 0.5f, 0.5f},
new byte[0],
/*headIndex=*/ 0,
/*headName=*/ Optional.empty());
assertThat(CosineSimilarity.compute(u, v)).isEqualTo(0.5);
}
@Test
public void succeedsWithQuantizedEmbeddings() {
Embedding u =
Embedding.create(
new float[0],
new byte[] {127, 0, 0, 0},
/*headIndex=*/ 0,
/*headName=*/ Optional.empty());
Embedding v =
Embedding.create(
new float[0],
new byte[] {-128, 0, 0, 0},
/*headIndex=*/ 0,
/*headName=*/ Optional.empty());
assertThat(CosineSimilarity.compute(u, v)).isEqualTo(-1.0);
}
}