Internal MediaPipe Tasks change.
PiperOrigin-RevId: 514150403
This commit is contained in:
parent
763842289a
commit
a43be73ee4
|
@ -28,6 +28,7 @@ import org.junit.runner.RunWith;
|
||||||
public class TextEmbedderTest {
|
public class TextEmbedderTest {
|
||||||
private static final String BERT_MODEL_FILE = "mobilebert_embedding_with_metadata.tflite";
|
private static final String BERT_MODEL_FILE = "mobilebert_embedding_with_metadata.tflite";
|
||||||
private static final String REGEX_MODEL_FILE = "regex_one_embedding_with_metadata.tflite";
|
private static final String REGEX_MODEL_FILE = "regex_one_embedding_with_metadata.tflite";
|
||||||
|
private static final String USE_MODEL_FILE = "universal_sentence_encoder_qa_with_metadata.tflite";
|
||||||
|
|
||||||
private static final double DOUBLE_DIFF_TOLERANCE = 1e-4;
|
private static final double DOUBLE_DIFF_TOLERANCE = 1e-4;
|
||||||
private static final float FLOAT_DIFF_TOLERANCE = 1e-4f;
|
private static final float FLOAT_DIFF_TOLERANCE = 1e-4f;
|
||||||
|
@ -70,6 +71,32 @@ public class TextEmbedderTest {
|
||||||
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.968879);
|
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.968879);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void embed_succeedsWithUSE() throws Exception {
|
||||||
|
TextEmbedder textEmbedder =
|
||||||
|
TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), USE_MODEL_FILE);
|
||||||
|
|
||||||
|
TextEmbedderResult result0 = textEmbedder.embed("it's a charming and often affecting journey");
|
||||||
|
assertThat(result0.embeddingResult().embeddings().size()).isEqualTo(1);
|
||||||
|
assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(100);
|
||||||
|
assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()[0])
|
||||||
|
.isWithin(FLOAT_DIFF_TOLERANCE)
|
||||||
|
.of(1.422951f);
|
||||||
|
TextEmbedderResult result1 = textEmbedder.embed("what a great and fantastic trip");
|
||||||
|
assertThat(result1.embeddingResult().embeddings().size()).isEqualTo(1);
|
||||||
|
assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(100);
|
||||||
|
assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()[0])
|
||||||
|
.isWithin(FLOAT_DIFF_TOLERANCE)
|
||||||
|
.of(1.404664f);
|
||||||
|
|
||||||
|
// Check cosine similarity.
|
||||||
|
double similarity =
|
||||||
|
TextEmbedder.cosineSimilarity(
|
||||||
|
result0.embeddingResult().embeddings().get(0),
|
||||||
|
result1.embeddingResult().embeddings().get(0));
|
||||||
|
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.851961);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void embed_succeedsWithRegex() throws Exception {
|
public void embed_succeedsWithRegex() throws Exception {
|
||||||
TextEmbedder textEmbedder =
|
TextEmbedder textEmbedder =
|
||||||
|
@ -115,4 +142,24 @@ public class TextEmbedderTest {
|
||||||
result1.embeddingResult().embeddings().get(0));
|
result1.embeddingResult().embeddings().get(0));
|
||||||
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.3477488707202946);
|
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.3477488707202946);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void classify_succeedsWithUSEAndDifferentThemes() throws Exception {
|
||||||
|
TextEmbedder textEmbedder =
|
||||||
|
TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), USE_MODEL_FILE);
|
||||||
|
|
||||||
|
TextEmbedderResult result0 =
|
||||||
|
textEmbedder.embed(
|
||||||
|
"When you go to this restaurant, they hold the pancake upside-down before they hand "
|
||||||
|
+ "it to you. It's a great gimmick.");
|
||||||
|
TextEmbedderResult result1 =
|
||||||
|
textEmbedder.embed("Let\'s make a plan to steal the declaration of independence.'");
|
||||||
|
|
||||||
|
// Check cosine similarity.
|
||||||
|
double similarity =
|
||||||
|
TextEmbedder.cosineSimilarity(
|
||||||
|
result0.embeddingResult().embeddings().get(0),
|
||||||
|
result1.embeddingResult().embeddings().get(0));
|
||||||
|
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.7835510599396296);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
2
mediapipe/tasks/testdata/text/BUILD
vendored
2
mediapipe/tasks/testdata/text/BUILD
vendored
|
@ -103,5 +103,5 @@ filegroup(
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "universal_sentence_encoder_qa",
|
name = "universal_sentence_encoder_qa",
|
||||||
data = ["universal_sentence_encoder_qa_with_metadata.tflite"],
|
srcs = ["universal_sentence_encoder_qa_with_metadata.tflite"],
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user