Add option, result types and utils for Java embedders.
PiperOrigin-RevId: 487615327
This commit is contained in:
parent
aeb2466844
commit
8ec4427bd7
|
@ -89,3 +89,30 @@ filegroup(
|
|||
srcs = glob(["*.java"]),
|
||||
visibility = ["//mediapipe/tasks/java/com/google/mediapipe/tasks/core:__subpackages__"],
|
||||
)
|
||||
|
||||
android_library(
|
||||
name = "embedding",
|
||||
srcs = ["Embedding.java"],
|
||||
javacopts = [
|
||||
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
||||
android_library(
|
||||
name = "embeddingresult",
|
||||
srcs = ["EmbeddingResult.java"],
|
||||
javacopts = [
|
||||
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||
],
|
||||
deps = [
|
||||
":embedding",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.components.containers;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* Represents the embedding for a given embedder head. Typically used in embedding tasks.
|
||||
*
|
||||
* <p>One and only one of the two 'floatEmbedding' and 'quantizedEmbedding' will contain data, based
|
||||
* on whether or not the embedder was configured to perform scala quantization.
|
||||
*/
|
||||
@AutoValue
|
||||
public abstract class Embedding {
|
||||
|
||||
/**
|
||||
* Creates an {@link Embedding} instance.
|
||||
*
|
||||
* @param floatEmbedding the floating-point embedding
|
||||
* @param quantizedEmbedding the quantized embedding.
|
||||
* @param headIndex the index of the embedder head.
|
||||
* @param headName the optional name of the embedder head.
|
||||
*/
|
||||
public static Embedding create(
|
||||
float[] floatEmbedding, byte[] quantizedEmbedding, int headIndex, Optional<String> headName) {
|
||||
return new AutoValue_Embedding(floatEmbedding, quantizedEmbedding, headIndex, headName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an {@link Embedding} object from an {@link EmbeddingsProto.Embedding} protobuf message.
|
||||
*
|
||||
* @param proto the {@link EmbeddingsProto.Embedding} protobuf message to convert.
|
||||
*/
|
||||
public static Embedding createFromProto(EmbeddingsProto.Embedding proto) {
|
||||
float[] floatEmbedding;
|
||||
if (proto.hasFloatEmbedding()) {
|
||||
floatEmbedding = new float[proto.getFloatEmbedding().getValuesCount()];
|
||||
for (int i = 0; i < floatEmbedding.length; i++) {
|
||||
floatEmbedding[i] = proto.getFloatEmbedding().getValues(i);
|
||||
}
|
||||
} else {
|
||||
floatEmbedding = new float[0];
|
||||
}
|
||||
return Embedding.create(
|
||||
floatEmbedding,
|
||||
proto.hasQuantizedEmbedding()
|
||||
? proto.getQuantizedEmbedding().getValues().toByteArray()
|
||||
: new byte[0],
|
||||
proto.getHeadIndex(),
|
||||
proto.hasHeadName() ? Optional.of(proto.getHeadName()) : Optional.empty());
|
||||
}
|
||||
|
||||
/**
|
||||
* Floating-point embedding.
|
||||
*
|
||||
* <p>Empty if the embedder was configured to perform scalar quantization.
|
||||
*/
|
||||
public abstract float[] floatEmbedding();
|
||||
|
||||
/**
|
||||
* Quantized embedding.
|
||||
*
|
||||
* <p>Empty if the embedder was not configured to perform scalar quantization.
|
||||
*/
|
||||
public abstract byte[] quantizedEmbedding();
|
||||
|
||||
/**
|
||||
* The index of the embedder head these entries refer to. This is useful for multi-head models.
|
||||
*/
|
||||
public abstract int headIndex();
|
||||
|
||||
/** The optional name of the embedder head, which is the corresponding tensor metadata name. */
|
||||
public abstract Optional<String> headName();
|
||||
}
|
|
@ -0,0 +1,69 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.components.containers;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
/** Represents the embedding results of a model. Typically used as a result for embedding tasks. */
|
||||
@AutoValue
|
||||
public abstract class EmbeddingResult {
|
||||
|
||||
/**
|
||||
* Creates a {@link EmbeddingResult} instance.
|
||||
*
|
||||
* @param embeddings the list of {@link Embedding} objects containing the embedding for each head
|
||||
* of the model.
|
||||
* @param timestampMs the optional timestamp (in milliseconds) of the start of the chunk of data
|
||||
* corresponding to these results.
|
||||
*/
|
||||
public static EmbeddingResult create(List<Embedding> embeddings, Optional<Long> timestampMs) {
|
||||
return new AutoValue_EmbeddingResult(Collections.unmodifiableList(embeddings), timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link EmbeddingResult} object from a {@link EmbeddingsProto.EmbeddingResult}
|
||||
* protobuf message.
|
||||
*
|
||||
* @param proto the {@link EmbeddingsProto.EmbeddingResult} protobuf message to convert.
|
||||
*/
|
||||
public static EmbeddingResult createFromProto(EmbeddingsProto.EmbeddingResult proto) {
|
||||
List<Embedding> embeddings = new ArrayList<>();
|
||||
for (EmbeddingsProto.Embedding embeddingProto : proto.getEmbeddingsList()) {
|
||||
embeddings.add(Embedding.createFromProto(embeddingProto));
|
||||
}
|
||||
Optional<Long> timestampMs =
|
||||
proto.hasTimestampMs() ? Optional.of(proto.getTimestampMs()) : Optional.empty();
|
||||
return create(embeddings, timestampMs);
|
||||
}
|
||||
|
||||
/** The embedding results for each head of the model. */
|
||||
public abstract List<Embedding> embeddings();
|
||||
|
||||
/**
|
||||
* The optional timestamp (in milliseconds) of the start of the chunk of data corresponding to
|
||||
* these results.
|
||||
*
|
||||
* <p>This is only used for embedding extraction on time series (e.g. audio embedder). In these
|
||||
* use cases, the amount of data to process might exceed the maximum size that the model can
|
||||
* process: to solve this, the input data is split into multiple chunks starting at different
|
||||
* timestamps.
|
||||
*/
|
||||
public abstract Optional<Long> timestampMs();
|
||||
}
|
|
@ -29,6 +29,19 @@ android_library(
|
|||
],
|
||||
)
|
||||
|
||||
android_library(
|
||||
name = "embedderoptions",
|
||||
srcs = ["EmbedderOptions.java"],
|
||||
javacopts = [
|
||||
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite",
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
||||
# Expose the java source files for building mediapipe tasks core AAR.
|
||||
filegroup(
|
||||
name = "java_src",
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.components.processors;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto;
|
||||
|
||||
/** Embedder options shared across MediaPipe Java embedding tasks. */
|
||||
@AutoValue
|
||||
public abstract class EmbedderOptions {
|
||||
|
||||
/** Builder for {@link EmbedderOptions} */
|
||||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
/**
|
||||
* Sets whether L2 normalization should be performed on the returned embeddings. Use this option
|
||||
* only if the model does not already contain a native <code>L2_NORMALIZATION</code> TF Lite Op.
|
||||
* In most cases, this is already the case and L2 norm is thus achieved through TF Lite
|
||||
* inference.
|
||||
*
|
||||
* <p>False by default.
|
||||
*/
|
||||
public abstract Builder setL2Normalize(boolean l2Normalize);
|
||||
|
||||
/**
|
||||
* Sets whether the returned embedding should be quantized to bytes via scalar quantization.
|
||||
* Embeddings are implicitly assumed to be unit-norm and therefore any dimensions is guaranteed
|
||||
* to have value in <code>[-1.0, 1.0]</code>. Use {@link #setL2Normalize(boolean)} if this is
|
||||
* not the case.
|
||||
*
|
||||
* <p>False by default.
|
||||
*/
|
||||
public abstract Builder setQuantize(boolean quantize);
|
||||
|
||||
public abstract EmbedderOptions build();
|
||||
}
|
||||
|
||||
public abstract boolean l2Normalize();
|
||||
|
||||
public abstract boolean quantize();
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_EmbedderOptions.Builder().setL2Normalize(false).setQuantize(false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts an {@link EmbedderOptions} object to an {@link EmbedderOptionsProto.EmbedderOptions}
|
||||
* protobuf message.
|
||||
*/
|
||||
public EmbedderOptionsProto.EmbedderOptions convertToProto() {
|
||||
return EmbedderOptionsProto.EmbedderOptions.newBuilder()
|
||||
.setL2Normalize(l2Normalize())
|
||||
.setQuantize(quantize())
|
||||
.build();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
android_library(
|
||||
name = "cosinesimilarity",
|
||||
srcs = ["CosineSimilarity.java"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
|
@ -0,0 +1,88 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.components.utils;
|
||||
|
||||
import com.google.mediapipe.tasks.components.containers.Embedding;
|
||||
|
||||
/** Utility class for computing cosine similarity between {@link Embedding} objects. */
|
||||
public class CosineSimilarity {
|
||||
|
||||
// Non-instantiable class.
|
||||
private CosineSimilarity() {}
|
||||
|
||||
/**
|
||||
* Computes <a href="https://en.wikipedia.org/wiki/Cosine_similarity">cosine similarity</a>
|
||||
* between two {@link Embedding} objects.
|
||||
*
|
||||
* @throws IllegalArgumentException if the embeddings are of different types (float vs.
|
||||
* quantized), have different sizes, or have an L2-norm of 0.
|
||||
*/
|
||||
public static double compute(Embedding u, Embedding v) {
|
||||
if (u.floatEmbedding().length > 0 && v.floatEmbedding().length > 0) {
|
||||
return computeFloat(u.floatEmbedding(), v.floatEmbedding());
|
||||
}
|
||||
if (u.quantizedEmbedding().length > 0 && v.quantizedEmbedding().length > 0) {
|
||||
return computeQuantized(u.quantizedEmbedding(), v.quantizedEmbedding());
|
||||
}
|
||||
throw new IllegalArgumentException(
|
||||
"Cannot compute cosine similarity between quantized and float embeddings.");
|
||||
}
|
||||
|
||||
private static double computeFloat(float[] u, float[] v) {
|
||||
if (u.length != v.length) {
|
||||
throw new IllegalArgumentException(
|
||||
String.format(
|
||||
"Cannot compute cosine similarity between embeddings of different sizes (%d vs."
|
||||
+ " %d).",
|
||||
u.length, v.length));
|
||||
}
|
||||
double dotProduct = 0.0;
|
||||
double normU = 0.0;
|
||||
double normV = 0.0;
|
||||
for (int i = 0; i < u.length; i++) {
|
||||
dotProduct += u[i] * v[i];
|
||||
normU += u[i] * u[i];
|
||||
normV += v[i] * v[i];
|
||||
}
|
||||
if (normU <= 0 || normV <= 0) {
|
||||
throw new IllegalArgumentException(
|
||||
"Cannot compute cosine similarity on embedding with 0 norm.");
|
||||
}
|
||||
return dotProduct / Math.sqrt(normU * normV);
|
||||
}
|
||||
|
||||
private static double computeQuantized(byte[] u, byte[] v) {
|
||||
if (u.length != v.length) {
|
||||
throw new IllegalArgumentException(
|
||||
String.format(
|
||||
"Cannot compute cosine similarity between embeddings of different sizes (%d vs."
|
||||
+ " %d).",
|
||||
u.length, v.length));
|
||||
}
|
||||
double dotProduct = 0.0;
|
||||
double normU = 0.0;
|
||||
double normV = 0.0;
|
||||
for (int i = 0; i < u.length; i++) {
|
||||
dotProduct += u[i] * v[i];
|
||||
normU += u[i] * u[i];
|
||||
normV += v[i] * v[i];
|
||||
}
|
||||
if (normU <= 0 || normV <= 0) {
|
||||
throw new IllegalArgumentException(
|
||||
"Cannot compute cosine similarity on embedding with 0 norm.");
|
||||
}
|
||||
return dotProduct / Math.sqrt(normU * normV);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.google.mediapipe.tasks.components.utilstest"
|
||||
android:versionCode="1"
|
||||
android:versionName="1.0" >
|
||||
|
||||
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
|
||||
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
|
||||
|
||||
<uses-sdk android:minSdkVersion="24"
|
||||
android:targetSdkVersion="30" />
|
||||
|
||||
<application
|
||||
android:label="utilstest"
|
||||
android:name="android.support.multidex.MultiDexApplication"
|
||||
android:taskAffinity="">
|
||||
<uses-library android:name="android.test.runner" />
|
||||
</application>
|
||||
|
||||
<instrumentation
|
||||
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
|
||||
android:targetPackage="com.google.mediapipe.tasks.components.utilstest" />
|
||||
|
||||
</manifest>
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
# TODO: Enable this in OSS
|
|
@ -0,0 +1,116 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.components.utils;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static org.junit.Assert.assertThrows;
|
||||
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||
import com.google.mediapipe.tasks.components.containers.Embedding;
|
||||
import java.util.Optional;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
|
||||
/** Tests for {@link CosineSimilarity}. */
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public final class CosineSimilarityTest {
|
||||
|
||||
@Test
|
||||
public void failsWithQuantizedAndFloatEmbeddings() {
|
||||
Embedding u =
|
||||
Embedding.create(
|
||||
new float[] {1.0f}, new byte[0], /*headIndex=*/ 0, /*headName=*/ Optional.empty());
|
||||
Embedding v =
|
||||
Embedding.create(
|
||||
new float[0], new byte[] {1}, /*headIndex=*/ 0, /*headName=*/ Optional.empty());
|
||||
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(IllegalArgumentException.class, () -> CosineSimilarity.compute(u, v));
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("Cannot compute cosine similarity between quantized and float embeddings");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void failsWithZeroNorm() {
|
||||
Embedding u =
|
||||
Embedding.create(
|
||||
new float[] {0.0f}, new byte[0], /*headIndex=*/ 0, /*headName=*/ Optional.empty());
|
||||
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(IllegalArgumentException.class, () -> CosineSimilarity.compute(u, u));
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("Cannot compute cosine similarity on embedding with 0 norm");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void failsWithDifferentSizes() {
|
||||
Embedding u =
|
||||
Embedding.create(
|
||||
new float[] {1.0f, 2.0f},
|
||||
new byte[0],
|
||||
/*headIndex=*/ 0,
|
||||
/*headName=*/ Optional.empty());
|
||||
Embedding v =
|
||||
Embedding.create(
|
||||
new float[] {1.0f, 2.0f, 3.0f},
|
||||
new byte[0],
|
||||
/*headIndex=*/ 0,
|
||||
/*headName=*/ Optional.empty());
|
||||
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(IllegalArgumentException.class, () -> CosineSimilarity.compute(u, v));
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("Cannot compute cosine similarity between embeddings of different sizes");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void succeedsWithFloatEmbeddings() {
|
||||
Embedding u =
|
||||
Embedding.create(
|
||||
new float[] {1.0f, 0.0f, 0.0f, 0.0f},
|
||||
new byte[0],
|
||||
/*headIndex=*/ 0,
|
||||
/*headName=*/ Optional.empty());
|
||||
Embedding v =
|
||||
Embedding.create(
|
||||
new float[] {0.5f, 0.5f, 0.5f, 0.5f},
|
||||
new byte[0],
|
||||
/*headIndex=*/ 0,
|
||||
/*headName=*/ Optional.empty());
|
||||
|
||||
assertThat(CosineSimilarity.compute(u, v)).isEqualTo(0.5);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void succeedsWithQuantizedEmbeddings() {
|
||||
Embedding u =
|
||||
Embedding.create(
|
||||
new float[0],
|
||||
new byte[] {127, 0, 0, 0},
|
||||
/*headIndex=*/ 0,
|
||||
/*headName=*/ Optional.empty());
|
||||
Embedding v =
|
||||
Embedding.create(
|
||||
new float[0],
|
||||
new byte[] {-128, 0, 0, 0},
|
||||
/*headIndex=*/ 0,
|
||||
/*headName=*/ Optional.empty());
|
||||
|
||||
assertThat(CosineSimilarity.compute(u, v)).isEqualTo(-1.0);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user