diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc index fa3d8af91..1ddea3358 100644 --- a/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc @@ -139,5 +139,32 @@ TEST_F(EmbedderTest, SucceedsWithQuantization) { MP_ASSERT_OK(text_embedder->Close()); } +TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileBert); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr text_embedder, + TextEmbedder::Create(std::move(options))); + + MP_ASSERT_OK_AND_ASSIGN( + TextEmbedderResult result0, + text_embedder->Embed("When you go to this restaurant, they hold the " + "pancake upside-down before they hand it " + "to you. It's a great gimmick.")); + MP_ASSERT_OK_AND_ASSIGN( + TextEmbedderResult result1, + text_embedder->Embed( + "Let's make a plan to steal the declaration of independence.")); + + // Check cosine similarity. + MP_ASSERT_OK_AND_ASSIGN( + double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0], + result1.embeddings[0])); + // TODO: The similarity should likely be lower + EXPECT_NEAR(similarity, 0.98088, kSimilarityTolerancy); + + MP_ASSERT_OK(text_embedder->Close()); +} + } // namespace } // namespace mediapipe::tasks::text::text_embedder diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java index b6d53c94d..48f214770 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java @@ -95,4 +95,24 @@ public class TextEmbedderTest { result1.embeddingResult().embeddings().get(0)); assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.999937); } + + @Test + public void classify_succeedsWithBertAndDifferentThemes() throws Exception { + TextEmbedder textEmbedder = + TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE); + + TextEmbedderResult result0 = + textEmbedder.embed( + "When you go to this restaurant, they hold the pancake upside-down before they hand " + + "it to you. It's a great gimmick."); + TextEmbedderResult result1 = + textEmbedder.embed("Let\'s make a plan to steal the declaration of independence.'"); + + // Check cosine similarity. + double similarity = + TextEmbedder.cosineSimilarity( + result0.embeddingResult().embeddings().get(0), + result1.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.3477488707202946); + } } diff --git a/mediapipe/tasks/python/test/text/text_embedder_test.py b/mediapipe/tasks/python/test/text/text_embedder_test.py index 1346ba373..455deba03 100644 --- a/mediapipe/tasks/python/test/text/text_embedder_test.py +++ b/mediapipe/tasks/python/test/text/text_embedder_test.py @@ -192,6 +192,36 @@ class TextEmbedderTest(parameterized.TestCase): self._check_embedding_value(result1, expected_result1_value) self._check_cosine_similarity(result0, result1, expected_similarity) + def test_embed_with_mobile_bert_and_different_themes(self): + # Creates embedder. + model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _BERT_MODEL_FILE) + ) + base_options = _BaseOptions(model_asset_path=model_path) + options = _TextEmbedderOptions(base_options=base_options) + embedder = _TextEmbedder.create_from_options(options) + + # Extracts both embeddings. + text0 = ( + 'When you go to this restaurant, they hold the pancake upside-down ' + "before they hand it to you. It's a great gimmick." + ) + result0 = embedder.embed(text0) + + text1 = "Let's make a plan to steal the declaration of independence." + result1 = embedder.embed(text1) + + similarity = _TextEmbedder.cosine_similarity( + result0.embeddings[0], result1.embeddings[0] + ) + + # TODO: The similarity should likely be lower + self.assertAlmostEqual(similarity, 0.980880, delta=_SIMILARITY_TOLERANCE) + + # Closes the embedder explicitly when the embedder is not used in + # a context. + embedder.close() + if __name__ == '__main__': absltest.main()