Internal MediaPipe Tasks change.

PiperOrigin-RevId: 514150403
This commit is contained in:
MediaPipe Team 2023-03-04 22:28:21 -08:00 committed by Copybara-Service
parent 763842289a
commit a43be73ee4
2 changed files with 48 additions and 1 deletions

View File

@ -28,6 +28,7 @@ import org.junit.runner.RunWith;
public class TextEmbedderTest { public class TextEmbedderTest {
private static final String BERT_MODEL_FILE = "mobilebert_embedding_with_metadata.tflite"; private static final String BERT_MODEL_FILE = "mobilebert_embedding_with_metadata.tflite";
private static final String REGEX_MODEL_FILE = "regex_one_embedding_with_metadata.tflite"; private static final String REGEX_MODEL_FILE = "regex_one_embedding_with_metadata.tflite";
private static final String USE_MODEL_FILE = "universal_sentence_encoder_qa_with_metadata.tflite";
private static final double DOUBLE_DIFF_TOLERANCE = 1e-4; private static final double DOUBLE_DIFF_TOLERANCE = 1e-4;
private static final float FLOAT_DIFF_TOLERANCE = 1e-4f; private static final float FLOAT_DIFF_TOLERANCE = 1e-4f;
@ -70,6 +71,32 @@ public class TextEmbedderTest {
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.968879); assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.968879);
} }
@Test
public void embed_succeedsWithUSE() throws Exception {
TextEmbedder textEmbedder =
TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), USE_MODEL_FILE);
TextEmbedderResult result0 = textEmbedder.embed("it's a charming and often affecting journey");
assertThat(result0.embeddingResult().embeddings().size()).isEqualTo(1);
assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(100);
assertThat(result0.embeddingResult().embeddings().get(0).floatEmbedding()[0])
.isWithin(FLOAT_DIFF_TOLERANCE)
.of(1.422951f);
TextEmbedderResult result1 = textEmbedder.embed("what a great and fantastic trip");
assertThat(result1.embeddingResult().embeddings().size()).isEqualTo(1);
assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(100);
assertThat(result1.embeddingResult().embeddings().get(0).floatEmbedding()[0])
.isWithin(FLOAT_DIFF_TOLERANCE)
.of(1.404664f);
// Check cosine similarity.
double similarity =
TextEmbedder.cosineSimilarity(
result0.embeddingResult().embeddings().get(0),
result1.embeddingResult().embeddings().get(0));
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.851961);
}
@Test @Test
public void embed_succeedsWithRegex() throws Exception { public void embed_succeedsWithRegex() throws Exception {
TextEmbedder textEmbedder = TextEmbedder textEmbedder =
@ -115,4 +142,24 @@ public class TextEmbedderTest {
result1.embeddingResult().embeddings().get(0)); result1.embeddingResult().embeddings().get(0));
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.3477488707202946); assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.3477488707202946);
} }
@Test
public void classify_succeedsWithUSEAndDifferentThemes() throws Exception {
TextEmbedder textEmbedder =
TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), USE_MODEL_FILE);
TextEmbedderResult result0 =
textEmbedder.embed(
"When you go to this restaurant, they hold the pancake upside-down before they hand "
+ "it to you. It's a great gimmick.");
TextEmbedderResult result1 =
textEmbedder.embed("Let\'s make a plan to steal the declaration of independence.'");
// Check cosine similarity.
double similarity =
TextEmbedder.cosineSimilarity(
result0.embeddingResult().embeddings().get(0),
result1.embeddingResult().embeddings().get(0));
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.7835510599396296);
}
} }

View File

@ -103,5 +103,5 @@ filegroup(
filegroup( filegroup(
name = "universal_sentence_encoder_qa", name = "universal_sentence_encoder_qa",
data = ["universal_sentence_encoder_qa_with_metadata.tflite"], srcs = ["universal_sentence_encoder_qa_with_metadata.tflite"],
) )