Add option, result types and utils for Java embedders.
PiperOrigin-RevId: 487615327
This commit is contained in:
parent
aeb2466844
commit
8ec4427bd7
|
@ -89,3 +89,30 @@ filegroup(
|
||||||
srcs = glob(["*.java"]),
|
srcs = glob(["*.java"]),
|
||||||
visibility = ["//mediapipe/tasks/java/com/google/mediapipe/tasks/core:__subpackages__"],
|
visibility = ["//mediapipe/tasks/java/com/google/mediapipe/tasks/core:__subpackages__"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
android_library(
|
||||||
|
name = "embedding",
|
||||||
|
srcs = ["Embedding.java"],
|
||||||
|
javacopts = [
|
||||||
|
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
|
||||||
|
"//third_party:autovalue",
|
||||||
|
"@maven//:com_google_guava_guava",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
android_library(
|
||||||
|
name = "embeddingresult",
|
||||||
|
srcs = ["EmbeddingResult.java"],
|
||||||
|
javacopts = [
|
||||||
|
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":embedding",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
|
||||||
|
"//third_party:autovalue",
|
||||||
|
"@maven//:com_google_guava_guava",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,88 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.components.containers;
|
||||||
|
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents the embedding for a given embedder head. Typically used in embedding tasks.
|
||||||
|
*
|
||||||
|
* <p>One and only one of the two 'floatEmbedding' and 'quantizedEmbedding' will contain data, based
|
||||||
|
* on whether or not the embedder was configured to perform scalar quantization.
|
||||||
|
*/
|
||||||
|
@AutoValue
|
||||||
|
public abstract class Embedding {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link Embedding} instance.
|
||||||
|
*
|
||||||
|
* @param floatEmbedding the floating-point embedding
|
||||||
|
* @param quantizedEmbedding the quantized embedding.
|
||||||
|
* @param headIndex the index of the embedder head.
|
||||||
|
* @param headName the optional name of the embedder head.
|
||||||
|
*/
|
||||||
|
public static Embedding create(
|
||||||
|
float[] floatEmbedding, byte[] quantizedEmbedding, int headIndex, Optional<String> headName) {
|
||||||
|
return new AutoValue_Embedding(floatEmbedding, quantizedEmbedding, headIndex, headName);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link Embedding} object from an {@link EmbeddingsProto.Embedding} protobuf message.
|
||||||
|
*
|
||||||
|
* @param proto the {@link EmbeddingsProto.Embedding} protobuf message to convert.
|
||||||
|
*/
|
||||||
|
public static Embedding createFromProto(EmbeddingsProto.Embedding proto) {
|
||||||
|
float[] floatEmbedding;
|
||||||
|
if (proto.hasFloatEmbedding()) {
|
||||||
|
floatEmbedding = new float[proto.getFloatEmbedding().getValuesCount()];
|
||||||
|
for (int i = 0; i < floatEmbedding.length; i++) {
|
||||||
|
floatEmbedding[i] = proto.getFloatEmbedding().getValues(i);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
floatEmbedding = new float[0];
|
||||||
|
}
|
||||||
|
return Embedding.create(
|
||||||
|
floatEmbedding,
|
||||||
|
proto.hasQuantizedEmbedding()
|
||||||
|
? proto.getQuantizedEmbedding().getValues().toByteArray()
|
||||||
|
: new byte[0],
|
||||||
|
proto.getHeadIndex(),
|
||||||
|
proto.hasHeadName() ? Optional.of(proto.getHeadName()) : Optional.empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Floating-point embedding.
|
||||||
|
*
|
||||||
|
* <p>Empty if the embedder was configured to perform scalar quantization.
|
||||||
|
*/
|
||||||
|
public abstract float[] floatEmbedding();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Quantized embedding.
|
||||||
|
*
|
||||||
|
* <p>Empty if the embedder was not configured to perform scalar quantization.
|
||||||
|
*/
|
||||||
|
public abstract byte[] quantizedEmbedding();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The index of the embedder head these entries refer to. This is useful for multi-head models.
|
||||||
|
*/
|
||||||
|
public abstract int headIndex();
|
||||||
|
|
||||||
|
/** The optional name of the embedder head, which is the corresponding tensor metadata name. */
|
||||||
|
public abstract Optional<String> headName();
|
||||||
|
}
|
|
@ -0,0 +1,69 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.components.containers;
|
||||||
|
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
/** Represents the embedding results of a model. Typically used as a result for embedding tasks. */
|
||||||
|
@AutoValue
|
||||||
|
public abstract class EmbeddingResult {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a {@link EmbeddingResult} instance.
|
||||||
|
*
|
||||||
|
* @param embeddings the list of {@link Embedding} objects containing the embedding for each head
|
||||||
|
* of the model.
|
||||||
|
* @param timestampMs the optional timestamp (in milliseconds) of the start of the chunk of data
|
||||||
|
* corresponding to these results.
|
||||||
|
*/
|
||||||
|
public static EmbeddingResult create(List<Embedding> embeddings, Optional<Long> timestampMs) {
|
||||||
|
return new AutoValue_EmbeddingResult(Collections.unmodifiableList(embeddings), timestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a {@link EmbeddingResult} object from a {@link EmbeddingsProto.EmbeddingResult}
|
||||||
|
* protobuf message.
|
||||||
|
*
|
||||||
|
* @param proto the {@link EmbeddingsProto.EmbeddingResult} protobuf message to convert.
|
||||||
|
*/
|
||||||
|
public static EmbeddingResult createFromProto(EmbeddingsProto.EmbeddingResult proto) {
|
||||||
|
List<Embedding> embeddings = new ArrayList<>();
|
||||||
|
for (EmbeddingsProto.Embedding embeddingProto : proto.getEmbeddingsList()) {
|
||||||
|
embeddings.add(Embedding.createFromProto(embeddingProto));
|
||||||
|
}
|
||||||
|
Optional<Long> timestampMs =
|
||||||
|
proto.hasTimestampMs() ? Optional.of(proto.getTimestampMs()) : Optional.empty();
|
||||||
|
return create(embeddings, timestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** The embedding results for each head of the model. */
|
||||||
|
public abstract List<Embedding> embeddings();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The optional timestamp (in milliseconds) of the start of the chunk of data corresponding to
|
||||||
|
* these results.
|
||||||
|
*
|
||||||
|
* <p>This is only used for embedding extraction on time series (e.g. audio embedder). In these
|
||||||
|
* use cases, the amount of data to process might exceed the maximum size that the model can
|
||||||
|
* process: to solve this, the input data is split into multiple chunks starting at different
|
||||||
|
* timestamps.
|
||||||
|
*/
|
||||||
|
public abstract Optional<Long> timestampMs();
|
||||||
|
}
|
|
@ -29,6 +29,19 @@ android_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
android_library(
|
||||||
|
name = "embedderoptions",
|
||||||
|
srcs = ["EmbedderOptions.java"],
|
||||||
|
javacopts = [
|
||||||
|
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite",
|
||||||
|
"//third_party:autovalue",
|
||||||
|
"@maven//:com_google_guava_guava",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# Expose the java source files for building mediapipe tasks core AAR.
|
# Expose the java source files for building mediapipe tasks core AAR.
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "java_src",
|
name = "java_src",
|
||||||
|
|
|
@ -0,0 +1,68 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.components.processors;
|
||||||
|
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto;
|
||||||
|
|
||||||
|
/** Embedder options shared across MediaPipe Java embedding tasks. */
|
||||||
|
@AutoValue
|
||||||
|
public abstract class EmbedderOptions {
|
||||||
|
|
||||||
|
/** Builder for {@link EmbedderOptions} */
|
||||||
|
@AutoValue.Builder
|
||||||
|
public abstract static class Builder {
|
||||||
|
/**
|
||||||
|
* Sets whether L2 normalization should be performed on the returned embeddings. Use this option
|
||||||
|
* only if the model does not already contain a native <code>L2_NORMALIZATION</code> TF Lite Op.
|
||||||
|
* In most cases, this is already the case and L2 norm is thus achieved through TF Lite
|
||||||
|
* inference.
|
||||||
|
*
|
||||||
|
* <p>False by default.
|
||||||
|
*/
|
||||||
|
public abstract Builder setL2Normalize(boolean l2Normalize);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets whether the returned embedding should be quantized to bytes via scalar quantization.
|
||||||
|
* Embeddings are implicitly assumed to be unit-norm and therefore any dimensions is guaranteed
|
||||||
|
* to have value in <code>[-1.0, 1.0]</code>. Use {@link #setL2Normalize(boolean)} if this is
|
||||||
|
* not the case.
|
||||||
|
*
|
||||||
|
* <p>False by default.
|
||||||
|
*/
|
||||||
|
public abstract Builder setQuantize(boolean quantize);
|
||||||
|
|
||||||
|
public abstract EmbedderOptions build();
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract boolean l2Normalize();
|
||||||
|
|
||||||
|
public abstract boolean quantize();
|
||||||
|
|
||||||
|
public static Builder builder() {
|
||||||
|
return new AutoValue_EmbedderOptions.Builder().setL2Normalize(false).setQuantize(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts an {@link EmbedderOptions} object to an {@link EmbedderOptionsProto.EmbedderOptions}
|
||||||
|
* protobuf message.
|
||||||
|
*/
|
||||||
|
public EmbedderOptionsProto.EmbedderOptions convertToProto() {
|
||||||
|
return EmbedderOptionsProto.EmbedderOptions.newBuilder()
|
||||||
|
.setL2Normalize(l2Normalize())
|
||||||
|
.setQuantize(quantize())
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,26 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
android_library(
|
||||||
|
name = "cosinesimilarity",
|
||||||
|
srcs = ["CosineSimilarity.java"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding",
|
||||||
|
"@maven//:com_google_guava_guava",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,88 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.components.utils;
|
||||||
|
|
||||||
|
import com.google.mediapipe.tasks.components.containers.Embedding;
|
||||||
|
|
||||||
|
/** Utility class for computing cosine similarity between {@link Embedding} objects. */
|
||||||
|
public class CosineSimilarity {
|
||||||
|
|
||||||
|
// Non-instantiable class.
|
||||||
|
private CosineSimilarity() {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes <a href="https://en.wikipedia.org/wiki/Cosine_similarity">cosine similarity</a>
|
||||||
|
* between two {@link Embedding} objects.
|
||||||
|
*
|
||||||
|
* @throws IllegalArgumentException if the embeddings are of different types (float vs.
|
||||||
|
* quantized), have different sizes, or have an L2-norm of 0.
|
||||||
|
*/
|
||||||
|
public static double compute(Embedding u, Embedding v) {
|
||||||
|
if (u.floatEmbedding().length > 0 && v.floatEmbedding().length > 0) {
|
||||||
|
return computeFloat(u.floatEmbedding(), v.floatEmbedding());
|
||||||
|
}
|
||||||
|
if (u.quantizedEmbedding().length > 0 && v.quantizedEmbedding().length > 0) {
|
||||||
|
return computeQuantized(u.quantizedEmbedding(), v.quantizedEmbedding());
|
||||||
|
}
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Cannot compute cosine similarity between quantized and float embeddings.");
|
||||||
|
}
|
||||||
|
|
||||||
|
private static double computeFloat(float[] u, float[] v) {
|
||||||
|
if (u.length != v.length) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
String.format(
|
||||||
|
"Cannot compute cosine similarity between embeddings of different sizes (%d vs."
|
||||||
|
+ " %d).",
|
||||||
|
u.length, v.length));
|
||||||
|
}
|
||||||
|
double dotProduct = 0.0;
|
||||||
|
double normU = 0.0;
|
||||||
|
double normV = 0.0;
|
||||||
|
for (int i = 0; i < u.length; i++) {
|
||||||
|
dotProduct += u[i] * v[i];
|
||||||
|
normU += u[i] * u[i];
|
||||||
|
normV += v[i] * v[i];
|
||||||
|
}
|
||||||
|
if (normU <= 0 || normV <= 0) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Cannot compute cosine similarity on embedding with 0 norm.");
|
||||||
|
}
|
||||||
|
return dotProduct / Math.sqrt(normU * normV);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static double computeQuantized(byte[] u, byte[] v) {
|
||||||
|
if (u.length != v.length) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
String.format(
|
||||||
|
"Cannot compute cosine similarity between embeddings of different sizes (%d vs."
|
||||||
|
+ " %d).",
|
||||||
|
u.length, v.length));
|
||||||
|
}
|
||||||
|
double dotProduct = 0.0;
|
||||||
|
double normU = 0.0;
|
||||||
|
double normV = 0.0;
|
||||||
|
for (int i = 0; i < u.length; i++) {
|
||||||
|
dotProduct += u[i] * v[i];
|
||||||
|
normU += u[i] * u[i];
|
||||||
|
normV += v[i] * v[i];
|
||||||
|
}
|
||||||
|
if (normU <= 0 || normV <= 0) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Cannot compute cosine similarity on embedding with 0 norm.");
|
||||||
|
}
|
||||||
|
return dotProduct / Math.sqrt(normU * normV);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||||
|
package="com.google.mediapipe.tasks.components.utilstest"
|
||||||
|
android:versionCode="1"
|
||||||
|
android:versionName="1.0" >
|
||||||
|
|
||||||
|
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
|
||||||
|
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
|
||||||
|
|
||||||
|
<uses-sdk android:minSdkVersion="24"
|
||||||
|
android:targetSdkVersion="30" />
|
||||||
|
|
||||||
|
<application
|
||||||
|
android:label="utilstest"
|
||||||
|
android:name="android.support.multidex.MultiDexApplication"
|
||||||
|
android:taskAffinity="">
|
||||||
|
<uses-library android:name="android.test.runner" />
|
||||||
|
</application>
|
||||||
|
|
||||||
|
<instrumentation
|
||||||
|
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
|
||||||
|
android:targetPackage="com.google.mediapipe.tasks.components.utilstest" />
|
||||||
|
|
||||||
|
</manifest>
|
|
@ -0,0 +1,19 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
# TODO: Enable this in OSS
|
|
@ -0,0 +1,116 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.components.utils;
|
||||||
|
|
||||||
|
import static com.google.common.truth.Truth.assertThat;
|
||||||
|
import static org.junit.Assert.assertThrows;
|
||||||
|
|
||||||
|
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.Embedding;
|
||||||
|
import java.util.Optional;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
|
||||||
|
/** Tests for {@link CosineSimilarity}. */
|
||||||
|
@RunWith(AndroidJUnit4.class)
|
||||||
|
public final class CosineSimilarityTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void failsWithQuantizedAndFloatEmbeddings() {
|
||||||
|
Embedding u =
|
||||||
|
Embedding.create(
|
||||||
|
new float[] {1.0f}, new byte[0], /*headIndex=*/ 0, /*headName=*/ Optional.empty());
|
||||||
|
Embedding v =
|
||||||
|
Embedding.create(
|
||||||
|
new float[0], new byte[] {1}, /*headIndex=*/ 0, /*headName=*/ Optional.empty());
|
||||||
|
|
||||||
|
IllegalArgumentException exception =
|
||||||
|
assertThrows(IllegalArgumentException.class, () -> CosineSimilarity.compute(u, v));
|
||||||
|
assertThat(exception)
|
||||||
|
.hasMessageThat()
|
||||||
|
.contains("Cannot compute cosine similarity between quantized and float embeddings");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void failsWithZeroNorm() {
|
||||||
|
Embedding u =
|
||||||
|
Embedding.create(
|
||||||
|
new float[] {0.0f}, new byte[0], /*headIndex=*/ 0, /*headName=*/ Optional.empty());
|
||||||
|
|
||||||
|
IllegalArgumentException exception =
|
||||||
|
assertThrows(IllegalArgumentException.class, () -> CosineSimilarity.compute(u, u));
|
||||||
|
assertThat(exception)
|
||||||
|
.hasMessageThat()
|
||||||
|
.contains("Cannot compute cosine similarity on embedding with 0 norm");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void failsWithDifferentSizes() {
|
||||||
|
Embedding u =
|
||||||
|
Embedding.create(
|
||||||
|
new float[] {1.0f, 2.0f},
|
||||||
|
new byte[0],
|
||||||
|
/*headIndex=*/ 0,
|
||||||
|
/*headName=*/ Optional.empty());
|
||||||
|
Embedding v =
|
||||||
|
Embedding.create(
|
||||||
|
new float[] {1.0f, 2.0f, 3.0f},
|
||||||
|
new byte[0],
|
||||||
|
/*headIndex=*/ 0,
|
||||||
|
/*headName=*/ Optional.empty());
|
||||||
|
|
||||||
|
IllegalArgumentException exception =
|
||||||
|
assertThrows(IllegalArgumentException.class, () -> CosineSimilarity.compute(u, v));
|
||||||
|
assertThat(exception)
|
||||||
|
.hasMessageThat()
|
||||||
|
.contains("Cannot compute cosine similarity between embeddings of different sizes");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void succeedsWithFloatEmbeddings() {
|
||||||
|
Embedding u =
|
||||||
|
Embedding.create(
|
||||||
|
new float[] {1.0f, 0.0f, 0.0f, 0.0f},
|
||||||
|
new byte[0],
|
||||||
|
/*headIndex=*/ 0,
|
||||||
|
/*headName=*/ Optional.empty());
|
||||||
|
Embedding v =
|
||||||
|
Embedding.create(
|
||||||
|
new float[] {0.5f, 0.5f, 0.5f, 0.5f},
|
||||||
|
new byte[0],
|
||||||
|
/*headIndex=*/ 0,
|
||||||
|
/*headName=*/ Optional.empty());
|
||||||
|
|
||||||
|
assertThat(CosineSimilarity.compute(u, v)).isEqualTo(0.5);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void succeedsWithQuantizedEmbeddings() {
|
||||||
|
Embedding u =
|
||||||
|
Embedding.create(
|
||||||
|
new float[0],
|
||||||
|
new byte[] {127, 0, 0, 0},
|
||||||
|
/*headIndex=*/ 0,
|
||||||
|
/*headName=*/ Optional.empty());
|
||||||
|
Embedding v =
|
||||||
|
Embedding.create(
|
||||||
|
new float[0],
|
||||||
|
new byte[] {-128, 0, 0, 0},
|
||||||
|
/*headIndex=*/ 0,
|
||||||
|
/*headName=*/ Optional.empty());
|
||||||
|
|
||||||
|
assertThat(CosineSimilarity.compute(u, v)).isEqualTo(-1.0);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user