Add Text Embedder tests for text with different themes
PiperOrigin-RevId: 506023265
This commit is contained in:
parent
2c4dece023
commit
be3bddc620
|
@ -139,5 +139,32 @@ TEST_F(EmbedderTest, SucceedsWithQuantization) {
|
|||
MP_ASSERT_OK(text_embedder->Close());
|
||||
}
|
||||
|
||||
// Embeds two sentences on unrelated topics with the model at kMobileBert and
// checks the cosine similarity between the first embedding of each result.
TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
  // Input sentences, deliberately about different themes.
  constexpr char kRestaurantText[] =
      "When you go to this restaurant, they hold the pancake upside-down "
      "before they hand it to you. It's a great gimmick.";
  constexpr char kHeistText[] =
      "Let's make a plan to steal the declaration of independence.";

  // Build the embedder from default options plus the model path.
  auto embedder_options = std::make_unique<TextEmbedderOptions>();
  embedder_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileBert);
  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<TextEmbedder> text_embedder,
      TextEmbedder::Create(std::move(embedder_options)));

  // Embed both sentences, in the same order as they are compared below.
  MP_ASSERT_OK_AND_ASSIGN(TextEmbedderResult restaurant_result,
                          text_embedder->Embed(kRestaurantText));
  MP_ASSERT_OK_AND_ASSIGN(TextEmbedderResult heist_result,
                          text_embedder->Embed(kHeistText));

  // Check cosine similarity.
  MP_ASSERT_OK_AND_ASSIGN(
      double similarity,
      TextEmbedder::CosineSimilarity(restaurant_result.embeddings[0],
                                     heist_result.embeddings[0]));
  // TODO: The similarity should likely be lower
  EXPECT_NEAR(similarity, 0.98088, kSimilarityTolerancy);

  MP_ASSERT_OK(text_embedder->Close());
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe::tasks::text::text_embedder
|
||||
|
|
|
@ -95,4 +95,24 @@ public class TextEmbedderTest {
|
|||
result1.embeddingResult().embeddings().get(0));
|
||||
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.999937);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void classify_succeedsWithBertAndDifferentThemes() throws Exception {
|
||||
TextEmbedder textEmbedder =
|
||||
TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE);
|
||||
|
||||
TextEmbedderResult result0 =
|
||||
textEmbedder.embed(
|
||||
"When you go to this restaurant, they hold the pancake upside-down before they hand "
|
||||
+ "it to you. It's a great gimmick.");
|
||||
TextEmbedderResult result1 =
|
||||
textEmbedder.embed("Let\'s make a plan to steal the declaration of independence.'");
|
||||
|
||||
// Check cosine similarity.
|
||||
double similarity =
|
||||
TextEmbedder.cosineSimilarity(
|
||||
result0.embeddingResult().embeddings().get(0),
|
||||
result1.embeddingResult().embeddings().get(0));
|
||||
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.3477488707202946);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -192,6 +192,36 @@ class TextEmbedderTest(parameterized.TestCase):
|
|||
self._check_embedding_value(result1, expected_result1_value)
|
||||
self._check_cosine_similarity(result0, result1, expected_similarity)
|
||||
|
||||
def test_embed_with_mobile_bert_and_different_themes(self):
  # Embeds two sentences about unrelated topics and checks the cosine
  # similarity between the first embedding of each result.

  # Builds the embedder from the model file in the test data directory.
  bert_model_path = test_utils.get_test_data_path(
      os.path.join(_TEST_DATA_DIR, _BERT_MODEL_FILE)
  )
  embedder = _TextEmbedder.create_from_options(
      _TextEmbedderOptions(
          base_options=_BaseOptions(model_asset_path=bert_model_path)
      )
  )

  # Embeds the first sentence.
  restaurant_text = (
      'When you go to this restaurant, they hold the pancake upside-down '
      "before they hand it to you. It's a great gimmick."
  )
  restaurant_result = embedder.embed(restaurant_text)

  # Embeds the second sentence.
  heist_text = "Let's make a plan to steal the declaration of independence."
  heist_result = embedder.embed(heist_text)

  # Compares the first embedding of each result.
  similarity = _TextEmbedder.cosine_similarity(
      restaurant_result.embeddings[0], heist_result.embeddings[0]
  )

  # TODO: The similarity should likely be lower
  self.assertAlmostEqual(similarity, 0.980880, delta=_SIMILARITY_TOLERANCE)

  # The embedder is not used in a context here, so it is closed explicitly.
  embedder.close()
|
||||
|
||||
|
||||
# Runs the absltest entry point when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
|
||||
|
|
Loading…
Reference in New Issue
Block a user