Internal MediaPipe Tasks change.

PiperOrigin-RevId: 514637484
This commit is contained in:
MediaPipe Team 2023-03-06 23:15:37 -08:00 committed by Copybara-Service
parent c1b460920c
commit 2f2a74da6a
2 changed files with 82 additions and 20 deletions

View File

@@ -40,6 +40,7 @@ py_test(
data = [ data = [
"//mediapipe/tasks/testdata/text:mobilebert_embedding_model", "//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
"//mediapipe/tasks/testdata/text:regex_embedding_with_metadata", "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
"//mediapipe/tasks/testdata/text:universal_sentence_encoder_qa",
], ],
deps = [ deps = [
"//mediapipe/tasks/python/components/containers:embedding_result", "//mediapipe/tasks/python/components/containers:embedding_result",

View File

@@ -102,14 +102,42 @@ class TextEmbedderTest(parameterized.TestCase):
similarity, expected_similarity, delta=_SIMILARITY_TOLERANCE) similarity, expected_similarity, delta=_SIMILARITY_TOLERANCE)
@parameterized.parameters( @parameterized.parameters(
(False, False, _BERT_MODEL_FILE, ModelFileType.FILE_NAME, 0.969514, 512, (
(19.9016, 22.626251)), False,
(True, False, _BERT_MODEL_FILE, ModelFileType.FILE_NAME, 0.969514, 512, False,
(0.0585837, 0.0723035)), _BERT_MODEL_FILE,
(False, False, _REGEX_MODEL_FILE, ModelFileType.FILE_NAME, 0.999937, 16, ModelFileType.FILE_NAME,
(0.0309356, 0.0312863)), 0.969514,
(True, False, _REGEX_MODEL_FILE, ModelFileType.FILE_CONTENT, 0.999937, 16, 512,
(0.549632, 0.552879)), (19.9016, 22.626251),
),
(
True,
False,
_BERT_MODEL_FILE,
ModelFileType.FILE_NAME,
0.969514,
512,
(0.0585837, 0.0723035),
),
(
False,
False,
_REGEX_MODEL_FILE,
ModelFileType.FILE_NAME,
0.999937,
16,
(0.0309356, 0.0312863),
),
(
True,
False,
_REGEX_MODEL_FILE,
ModelFileType.FILE_CONTENT,
0.999937,
16,
(0.549632, 0.552879),
),
) )
def test_embed(self, l2_normalize, quantize, model_name, model_file_type, def test_embed(self, l2_normalize, quantize, model_name, model_file_type,
expected_similarity, expected_size, expected_first_values): expected_similarity, expected_size, expected_first_values):
@@ -149,14 +177,42 @@ class TextEmbedderTest(parameterized.TestCase):
embedder.close() embedder.close()
@parameterized.parameters( @parameterized.parameters(
(False, False, _BERT_MODEL_FILE, ModelFileType.FILE_NAME, 0.969514, 512, (
(19.9016, 22.626251)), False,
(True, False, _BERT_MODEL_FILE, ModelFileType.FILE_NAME, 0.969514, 512, False,
(0.0585837, 0.0723035)), _BERT_MODEL_FILE,
(False, False, _REGEX_MODEL_FILE, ModelFileType.FILE_NAME, 0.999937, 16, ModelFileType.FILE_NAME,
(0.0309356, 0.0312863)), 0.969514,
(True, False, _REGEX_MODEL_FILE, ModelFileType.FILE_CONTENT, 0.999937, 16, 512,
(0.549632, 0.552879)), (19.9016, 22.626251),
),
(
True,
False,
_BERT_MODEL_FILE,
ModelFileType.FILE_NAME,
0.969514,
512,
(0.0585837, 0.0723035),
),
(
False,
False,
_REGEX_MODEL_FILE,
ModelFileType.FILE_NAME,
0.999937,
16,
(0.0309356, 0.0312863),
),
(
True,
False,
_REGEX_MODEL_FILE,
ModelFileType.FILE_CONTENT,
0.999937,
16,
(0.549632, 0.552879),
),
) )
def test_embed_in_context(self, l2_normalize, quantize, model_name, def test_embed_in_context(self, l2_normalize, quantize, model_name,
model_file_type, expected_similarity, expected_size, model_file_type, expected_similarity, expected_size,
@@ -192,10 +248,14 @@ class TextEmbedderTest(parameterized.TestCase):
self._check_embedding_value(result1, expected_result1_value) self._check_embedding_value(result1, expected_result1_value)
self._check_cosine_similarity(result0, result1, expected_similarity) self._check_cosine_similarity(result0, result1, expected_similarity)
def test_embed_with_mobile_bert_and_different_themes(self): @parameterized.parameters(
# TODO: The similarity should likely be lower
(_BERT_MODEL_FILE, 0.980880),
)
def test_embed_with_different_themes(self, model_file, expected_similarity):
# Creates embedder. # Creates embedder.
model_path = test_utils.get_test_data_path( model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _BERT_MODEL_FILE) os.path.join(_TEST_DATA_DIR, model_file)
) )
base_options = _BaseOptions(model_asset_path=model_path) base_options = _BaseOptions(model_asset_path=model_path)
options = _TextEmbedderOptions(base_options=base_options) options = _TextEmbedderOptions(base_options=base_options)
@@ -215,8 +275,9 @@ class TextEmbedderTest(parameterized.TestCase):
result0.embeddings[0], result1.embeddings[0] result0.embeddings[0], result1.embeddings[0]
) )
# TODO: The similarity should likely be lower self.assertAlmostEqual(
self.assertAlmostEqual(similarity, 0.980880, delta=_SIMILARITY_TOLERANCE) similarity, expected_similarity, delta=_SIMILARITY_TOLERANCE
)
# Closes the embedder explicitly when the embedder is not used in # Closes the embedder explicitly when the embedder is not used in
# a context. # a context.