diff --git a/mediapipe/tasks/python/test/text/BUILD b/mediapipe/tasks/python/test/text/BUILD index 0e2b06012..5f2d18bc5 100644 --- a/mediapipe/tasks/python/test/text/BUILD +++ b/mediapipe/tasks/python/test/text/BUILD @@ -40,6 +40,7 @@ py_test( data = [ "//mediapipe/tasks/testdata/text:mobilebert_embedding_model", "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata", + "//mediapipe/tasks/testdata/text:universal_sentence_encoder_qa", ], deps = [ "//mediapipe/tasks/python/components/containers:embedding_result", diff --git a/mediapipe/tasks/python/test/text/text_embedder_test.py b/mediapipe/tasks/python/test/text/text_embedder_test.py index 455deba03..78e98a1b4 100644 --- a/mediapipe/tasks/python/test/text/text_embedder_test.py +++ b/mediapipe/tasks/python/test/text/text_embedder_test.py @@ -102,14 +102,42 @@ class TextEmbedderTest(parameterized.TestCase): similarity, expected_similarity, delta=_SIMILARITY_TOLERANCE) @parameterized.parameters( - (False, False, _BERT_MODEL_FILE, ModelFileType.FILE_NAME, 0.969514, 512, - (19.9016, 22.626251)), - (True, False, _BERT_MODEL_FILE, ModelFileType.FILE_NAME, 0.969514, 512, - (0.0585837, 0.0723035)), - (False, False, _REGEX_MODEL_FILE, ModelFileType.FILE_NAME, 0.999937, 16, - (0.0309356, 0.0312863)), - (True, False, _REGEX_MODEL_FILE, ModelFileType.FILE_CONTENT, 0.999937, 16, - (0.549632, 0.552879)), + ( + False, + False, + _BERT_MODEL_FILE, + ModelFileType.FILE_NAME, + 0.969514, + 512, + (19.9016, 22.626251), + ), + ( + True, + False, + _BERT_MODEL_FILE, + ModelFileType.FILE_NAME, + 0.969514, + 512, + (0.0585837, 0.0723035), + ), + ( + False, + False, + _REGEX_MODEL_FILE, + ModelFileType.FILE_NAME, + 0.999937, + 16, + (0.0309356, 0.0312863), + ), + ( + True, + False, + _REGEX_MODEL_FILE, + ModelFileType.FILE_CONTENT, + 0.999937, + 16, + (0.549632, 0.552879), + ), ) def test_embed(self, l2_normalize, quantize, model_name, model_file_type, expected_similarity, expected_size, expected_first_values): @@ -149,14 +177,42 @@ class TextEmbedderTest(parameterized.TestCase): embedder.close() @parameterized.parameters( - (False, False, _BERT_MODEL_FILE, ModelFileType.FILE_NAME, 0.969514, 512, - (19.9016, 22.626251)), - (True, False, _BERT_MODEL_FILE, ModelFileType.FILE_NAME, 0.969514, 512, - (0.0585837, 0.0723035)), - (False, False, _REGEX_MODEL_FILE, ModelFileType.FILE_NAME, 0.999937, 16, - (0.0309356, 0.0312863)), - (True, False, _REGEX_MODEL_FILE, ModelFileType.FILE_CONTENT, 0.999937, 16, - (0.549632, 0.552879)), + ( + False, + False, + _BERT_MODEL_FILE, + ModelFileType.FILE_NAME, + 0.969514, + 512, + (19.9016, 22.626251), + ), + ( + True, + False, + _BERT_MODEL_FILE, + ModelFileType.FILE_NAME, + 0.969514, + 512, + (0.0585837, 0.0723035), + ), + ( + False, + False, + _REGEX_MODEL_FILE, + ModelFileType.FILE_NAME, + 0.999937, + 16, + (0.0309356, 0.0312863), + ), + ( + True, + False, + _REGEX_MODEL_FILE, + ModelFileType.FILE_CONTENT, + 0.999937, + 16, + (0.549632, 0.552879), + ), ) def test_embed_in_context(self, l2_normalize, quantize, model_name, model_file_type, expected_similarity, expected_size, @@ -192,10 +248,14 @@ class TextEmbedderTest(parameterized.TestCase): self._check_embedding_value(result1, expected_result1_value) self._check_cosine_similarity(result0, result1, expected_similarity) - def test_embed_with_mobile_bert_and_different_themes(self): + @parameterized.parameters( + # TODO: The similarity should likely be lower + (_BERT_MODEL_FILE, 0.980880), + ) + def test_embed_with_different_themes(self, model_file, expected_similarity): # Creates embedder. model_path = test_utils.get_test_data_path( - os.path.join(_TEST_DATA_DIR, _BERT_MODEL_FILE) + os.path.join(_TEST_DATA_DIR, model_file) ) base_options = _BaseOptions(model_asset_path=model_path) options = _TextEmbedderOptions(base_options=base_options) @@ -215,8 +275,9 @@ class TextEmbedderTest(parameterized.TestCase): result0.embeddings[0], result1.embeddings[0] ) - # TODO: The similarity should likely be lower - self.assertAlmostEqual(similarity, 0.980880, delta=_SIMILARITY_TOLERANCE) + self.assertAlmostEqual( + similarity, expected_similarity, delta=_SIMILARITY_TOLERANCE + ) # Closes the embedder explicitly when the embedder is not used in # a context.