Add Text Embedder tests for text with different themes

PiperOrigin-RevId: 506023265
This commit is contained in:
Sebastian Schmidt 2023-01-31 09:18:07 -08:00 committed by Copybara-Service
parent 2c4dece023
commit be3bddc620
3 changed files with 77 additions and 0 deletions

View File

@ -139,5 +139,32 @@ TEST_F(EmbedderTest, SucceedsWithQuantization) {
MP_ASSERT_OK(text_embedder->Close()); MP_ASSERT_OK(text_embedder->Close());
} }
TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
auto options = std::make_unique<TextEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileBert);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
TextEmbedder::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
TextEmbedderResult result0,
text_embedder->Embed("When you go to this restaurant, they hold the "
"pancake upside-down before they hand it "
"to you. It's a great gimmick."));
MP_ASSERT_OK_AND_ASSIGN(
TextEmbedderResult result1,
text_embedder->Embed(
"Let's make a plan to steal the declaration of independence."));
// Check cosine similarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
result1.embeddings[0]));
// TODO: The similarity should likely be lower
EXPECT_NEAR(similarity, 0.98088, kSimilarityTolerancy);
MP_ASSERT_OK(text_embedder->Close());
}
} // namespace } // namespace
} // namespace mediapipe::tasks::text::text_embedder } // namespace mediapipe::tasks::text::text_embedder

View File

@ -95,4 +95,24 @@ public class TextEmbedderTest {
result1.embeddingResult().embeddings().get(0)); result1.embeddingResult().embeddings().get(0));
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.999937); assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.999937);
} }
@Test
public void classify_succeedsWithBertAndDifferentThemes() throws Exception {
TextEmbedder textEmbedder =
TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE);
TextEmbedderResult result0 =
textEmbedder.embed(
"When you go to this restaurant, they hold the pancake upside-down before they hand "
+ "it to you. It's a great gimmick.");
TextEmbedderResult result1 =
textEmbedder.embed("Let\'s make a plan to steal the declaration of independence.'");
// Check cosine similarity.
double similarity =
TextEmbedder.cosineSimilarity(
result0.embeddingResult().embeddings().get(0),
result1.embeddingResult().embeddings().get(0));
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.3477488707202946);
}
} }

View File

@ -192,6 +192,36 @@ class TextEmbedderTest(parameterized.TestCase):
self._check_embedding_value(result1, expected_result1_value) self._check_embedding_value(result1, expected_result1_value)
self._check_cosine_similarity(result0, result1, expected_similarity) self._check_cosine_similarity(result0, result1, expected_similarity)
def test_embed_with_mobile_bert_and_different_themes(self):
# Creates embedder.
model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _BERT_MODEL_FILE)
)
base_options = _BaseOptions(model_asset_path=model_path)
options = _TextEmbedderOptions(base_options=base_options)
embedder = _TextEmbedder.create_from_options(options)
# Extracts both embeddings.
text0 = (
'When you go to this restaurant, they hold the pancake upside-down '
"before they hand it to you. It's a great gimmick."
)
result0 = embedder.embed(text0)
text1 = "Let's make a plan to steal the declaration of independence."
result1 = embedder.embed(text1)
similarity = _TextEmbedder.cosine_similarity(
result0.embeddings[0], result1.embeddings[0]
)
# TODO: The similarity should likely be lower
self.assertAlmostEqual(similarity, 0.980880, delta=_SIMILARITY_TOLERANCE)
# Closes the embedder explicitly when the embedder is not used in
# a context.
embedder.close()
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()