Add Text Embedder tests for text with different themes
PiperOrigin-RevId: 506023265
This commit is contained in:
parent
2c4dece023
commit
be3bddc620
|
@ -139,5 +139,32 @@ TEST_F(EmbedderTest, SucceedsWithQuantization) {
|
||||||
MP_ASSERT_OK(text_embedder->Close());
|
MP_ASSERT_OK(text_embedder->Close());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
|
||||||
|
auto options = std::make_unique<TextEmbedderOptions>();
|
||||||
|
options->base_options.model_asset_path =
|
||||||
|
JoinPath("./", kTestDataDirectory, kMobileBert);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
|
||||||
|
TextEmbedder::Create(std::move(options)));
|
||||||
|
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
TextEmbedderResult result0,
|
||||||
|
text_embedder->Embed("When you go to this restaurant, they hold the "
|
||||||
|
"pancake upside-down before they hand it "
|
||||||
|
"to you. It's a great gimmick."));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
TextEmbedderResult result1,
|
||||||
|
text_embedder->Embed(
|
||||||
|
"Let's make a plan to steal the declaration of independence."));
|
||||||
|
|
||||||
|
// Check cosine similarity.
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
|
||||||
|
result1.embeddings[0]));
|
||||||
|
// TODO: The similarity should likely be lower
|
||||||
|
EXPECT_NEAR(similarity, 0.98088, kSimilarityTolerancy);
|
||||||
|
|
||||||
|
MP_ASSERT_OK(text_embedder->Close());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace mediapipe::tasks::text::text_embedder
|
} // namespace mediapipe::tasks::text::text_embedder
|
||||||
|
|
|
@ -95,4 +95,24 @@ public class TextEmbedderTest {
|
||||||
result1.embeddingResult().embeddings().get(0));
|
result1.embeddingResult().embeddings().get(0));
|
||||||
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.999937);
|
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.999937);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void classify_succeedsWithBertAndDifferentThemes() throws Exception {
|
||||||
|
TextEmbedder textEmbedder =
|
||||||
|
TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE);
|
||||||
|
|
||||||
|
TextEmbedderResult result0 =
|
||||||
|
textEmbedder.embed(
|
||||||
|
"When you go to this restaurant, they hold the pancake upside-down before they hand "
|
||||||
|
+ "it to you. It's a great gimmick.");
|
||||||
|
TextEmbedderResult result1 =
|
||||||
|
textEmbedder.embed("Let\'s make a plan to steal the declaration of independence.'");
|
||||||
|
|
||||||
|
// Check cosine similarity.
|
||||||
|
double similarity =
|
||||||
|
TextEmbedder.cosineSimilarity(
|
||||||
|
result0.embeddingResult().embeddings().get(0),
|
||||||
|
result1.embeddingResult().embeddings().get(0));
|
||||||
|
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.3477488707202946);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -192,6 +192,36 @@ class TextEmbedderTest(parameterized.TestCase):
|
||||||
self._check_embedding_value(result1, expected_result1_value)
|
self._check_embedding_value(result1, expected_result1_value)
|
||||||
self._check_cosine_similarity(result0, result1, expected_similarity)
|
self._check_cosine_similarity(result0, result1, expected_similarity)
|
||||||
|
|
||||||
|
def test_embed_with_mobile_bert_and_different_themes(self):
|
||||||
|
# Creates embedder.
|
||||||
|
model_path = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, _BERT_MODEL_FILE)
|
||||||
|
)
|
||||||
|
base_options = _BaseOptions(model_asset_path=model_path)
|
||||||
|
options = _TextEmbedderOptions(base_options=base_options)
|
||||||
|
embedder = _TextEmbedder.create_from_options(options)
|
||||||
|
|
||||||
|
# Extracts both embeddings.
|
||||||
|
text0 = (
|
||||||
|
'When you go to this restaurant, they hold the pancake upside-down '
|
||||||
|
"before they hand it to you. It's a great gimmick."
|
||||||
|
)
|
||||||
|
result0 = embedder.embed(text0)
|
||||||
|
|
||||||
|
text1 = "Let's make a plan to steal the declaration of independence."
|
||||||
|
result1 = embedder.embed(text1)
|
||||||
|
|
||||||
|
similarity = _TextEmbedder.cosine_similarity(
|
||||||
|
result0.embeddings[0], result1.embeddings[0]
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: The similarity should likely be lower
|
||||||
|
self.assertAlmostEqual(similarity, 0.980880, delta=_SIMILARITY_TOLERANCE)
|
||||||
|
|
||||||
|
# Closes the embedder explicitly when the embedder is not used in
|
||||||
|
# a context.
|
||||||
|
embedder.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user