diff --git a/mediapipe/model_maker/python/vision/object_detector/BUILD b/mediapipe/model_maker/python/vision/object_detector/BUILD index f3d4407d8..b97d215da 100644 --- a/mediapipe/model_maker/python/vision/object_detector/BUILD +++ b/mediapipe/model_maker/python/vision/object_detector/BUILD @@ -175,11 +175,7 @@ py_test( data = [":testdata"], tags = ["requires-net:external"], deps = [ - ":dataset", - ":hyperparameters", - ":model_spec", - ":object_detector", - ":object_detector_options", + ":object_detector_import", "//mediapipe/tasks/python/test:test_utils", ], ) diff --git a/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py b/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py index df6b58a07..02f773e69 100644 --- a/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py +++ b/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py @@ -19,11 +19,7 @@ from unittest import mock as unittest_mock from absl.testing import parameterized import tensorflow as tf -from mediapipe.model_maker.python.vision.object_detector import dataset -from mediapipe.model_maker.python.vision.object_detector import hyperparameters -from mediapipe.model_maker.python.vision.object_detector import model_spec as ms -from mediapipe.model_maker.python.vision.object_detector import object_detector -from mediapipe.model_maker.python.vision.object_detector import object_detector_options +from mediapipe.model_maker.python.vision import object_detector from mediapipe.tasks.python.test import test_utils as task_test_utils @@ -33,7 +29,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): super().setUp() dataset_folder = task_test_utils.get_test_data_path('coco_data') cache_dir = self.create_tempdir() - self.data = dataset.Dataset.from_coco_folder( + self.data = object_detector.Dataset.from_coco_folder( dataset_folder, cache_dir=cache_dir ) # Mock tempfile.gettempdir() to be unique for each test to avoid race @@ -48,15 +44,16 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): self.addCleanup(mock_gettempdir.stop) def test_object_detector(self): - hparams = hyperparameters.HParams( + hparams = object_detector.HParams( epochs=1, batch_size=2, learning_rate=0.9, shuffle=False, export_dir=self.create_tempdir(), ) - options = object_detector_options.ObjectDetectorOptions( - supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams + options = object_detector.ObjectDetectorOptions( + supported_model=object_detector.SupportedModels.MOBILENET_V2, + hparams=hparams, ) # Test `create`` model = object_detector.ObjectDetector.create( @@ -79,7 +76,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): self.assertGreater(os.path.getsize(output_metadata_file), 0) # Test `quantization_aware_training` - qat_hparams = hyperparameters.QATHParams( + qat_hparams = object_detector.QATHParams( learning_rate=0.9, batch_size=2, epochs=1, diff --git a/mediapipe/tasks/python/test/text/text_embedder_test.py b/mediapipe/tasks/python/test/text/text_embedder_test.py index 78e98a1b4..62d162f6e 100644 --- a/mediapipe/tasks/python/test/text/text_embedder_test.py +++ b/mediapipe/tasks/python/test/text/text_embedder_test.py @@ -32,6 +32,7 @@ _TextEmbedderOptions = text_embedder.TextEmbedderOptions _BERT_MODEL_FILE = 'mobilebert_embedding_with_metadata.tflite' _REGEX_MODEL_FILE = 'regex_one_embedding_with_metadata.tflite' +_USE_MODEL_FILE = 'universal_sentence_encoder_qa_with_metadata.tflite' _TEST_DATA_DIR = 'mediapipe/tasks/testdata/text' # Tolerance for embedding vector coordinate values. _EPSILON = 1e-4 @@ -138,6 +139,24 @@ class TextEmbedderTest(parameterized.TestCase): 16, (0.549632, 0.552879), ), + ( + False, + False, + _USE_MODEL_FILE, + ModelFileType.FILE_NAME, + 0.851961, + 100, + (1.422951, 1.404664), + ), + ( + True, + False, + _USE_MODEL_FILE, + ModelFileType.FILE_CONTENT, + 0.851961, + 100, + (0.127049, 0.125416), + ), ) def test_embed(self, l2_normalize, quantize, model_name, model_file_type, expected_similarity, expected_size, expected_first_values): @@ -213,6 +232,24 @@ class TextEmbedderTest(parameterized.TestCase): 16, (0.549632, 0.552879), ), + ( + False, + False, + _USE_MODEL_FILE, + ModelFileType.FILE_NAME, + 0.851961, + 100, + (1.422951, 1.404664), + ), + ( + True, + False, + _USE_MODEL_FILE, + ModelFileType.FILE_CONTENT, + 0.851961, + 100, + (0.127049, 0.125416), + ), ) def test_embed_in_context(self, l2_normalize, quantize, model_name, model_file_type, expected_similarity, expected_size, @@ -251,6 +288,7 @@ class TextEmbedderTest(parameterized.TestCase): @parameterized.parameters( # TODO: The similarity should likely be lower (_BERT_MODEL_FILE, 0.980880), + (_USE_MODEL_FILE, 0.780334), ) def test_embed_with_different_themes(self, model_file, expected_similarity): # Creates embedder.