Internal MediaPipe Tasks change.
PiperOrigin-RevId: 514150403
This commit is contained in:
parent
763842289a
commit
a43be73ee4
|
@ -28,6 +28,7 @@ import org.junit.runner.RunWith;
|
|||
public class TextEmbedderTest {
|
||||
private static final String BERT_MODEL_FILE = "mobilebert_embedding_with_metadata.tflite";
|
||||
private static final String REGEX_MODEL_FILE = "regex_one_embedding_with_metadata.tflite";
|
||||
private static final String USE_MODEL_FILE = "universal_sentence_encoder_qa_with_metadata.tflite";
|
||||
|
||||
private static final double DOUBLE_DIFF_TOLERANCE = 1e-4;
|
||||
private static final float FLOAT_DIFF_TOLERANCE = 1e-4f;
|
||||
|
@ -70,6 +71,32 @@ public class TextEmbedderTest {
|
|||
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.968879);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void embed_succeedsWithUSE() throws Exception {
|
||||
TextEmbedder textEmbedder =
|
||||
TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), USE_MODEL_FILE);
|
||||
|
||||
TextEmbedderResult result0 = textEmbedder.embed("it's a charming and often affecting journey");
|
||||
assertThat(result0.embeddingResult().embeddings().size()).isEqualTo(1);
|
||||
assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(100);
|
||||
assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()[0])
|
||||
.isWithin(FLOAT_DIFF_TOLERANCE)
|
||||
.of(1.422951f);
|
||||
TextEmbedderResult result1 = textEmbedder.embed("what a great and fantastic trip");
|
||||
assertThat(result1.embeddingResult().embeddings().size()).isEqualTo(1);
|
||||
assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(100);
|
||||
assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()[0])
|
||||
.isWithin(FLOAT_DIFF_TOLERANCE)
|
||||
.of(1.404664f);
|
||||
|
||||
// Check cosine similarity.
|
||||
double similarity =
|
||||
TextEmbedder.cosineSimilarity(
|
||||
result0.embeddingResult().embeddings().get(0),
|
||||
result1.embeddingResult().embeddings().get(0));
|
||||
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.851961);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void embed_succeedsWithRegex() throws Exception {
|
||||
TextEmbedder textEmbedder =
|
||||
|
@ -115,4 +142,24 @@ public class TextEmbedderTest {
|
|||
result1.embeddingResult().embeddings().get(0));
|
||||
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.3477488707202946);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void classify_succeedsWithUSEAndDifferentThemes() throws Exception {
|
||||
TextEmbedder textEmbedder =
|
||||
TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), USE_MODEL_FILE);
|
||||
|
||||
TextEmbedderResult result0 =
|
||||
textEmbedder.embed(
|
||||
"When you go to this restaurant, they hold the pancake upside-down before they hand "
|
||||
+ "it to you. It's a great gimmick.");
|
||||
TextEmbedderResult result1 =
|
||||
textEmbedder.embed("Let\'s make a plan to steal the declaration of independence.'");
|
||||
|
||||
// Check cosine similarity.
|
||||
double similarity =
|
||||
TextEmbedder.cosineSimilarity(
|
||||
result0.embeddingResult().embeddings().get(0),
|
||||
result1.embeddingResult().embeddings().get(0));
|
||||
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.7835510599396296);
|
||||
}
|
||||
}
|
||||
|
|
2
mediapipe/tasks/testdata/text/BUILD
vendored
2
mediapipe/tasks/testdata/text/BUILD
vendored
|
@ -103,5 +103,5 @@ filegroup(
|
|||
|
||||
filegroup(
|
||||
name = "universal_sentence_encoder_qa",
|
||||
data = ["universal_sentence_encoder_qa_with_metadata.tflite"],
|
||||
srcs = ["universal_sentence_encoder_qa_with_metadata.tflite"],
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user