Merge branch 'master' into ios-text-cocoapods-force-load
This commit is contained in:
		
						commit
						24bd7a6b9f
					
				| 
						 | 
					@ -175,11 +175,7 @@ py_test(
 | 
				
			||||||
    data = [":testdata"],
 | 
					    data = [":testdata"],
 | 
				
			||||||
    tags = ["requires-net:external"],
 | 
					    tags = ["requires-net:external"],
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
        ":dataset",
 | 
					        ":object_detector_import",
 | 
				
			||||||
        ":hyperparameters",
 | 
					 | 
				
			||||||
        ":model_spec",
 | 
					 | 
				
			||||||
        ":object_detector",
 | 
					 | 
				
			||||||
        ":object_detector_options",
 | 
					 | 
				
			||||||
        "//mediapipe/tasks/python/test:test_utils",
 | 
					        "//mediapipe/tasks/python/test:test_utils",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -19,11 +19,7 @@ from unittest import mock as unittest_mock
 | 
				
			||||||
from absl.testing import parameterized
 | 
					from absl.testing import parameterized
 | 
				
			||||||
import tensorflow as tf
 | 
					import tensorflow as tf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from mediapipe.model_maker.python.vision.object_detector import dataset
 | 
					from mediapipe.model_maker.python.vision import object_detector
 | 
				
			||||||
from mediapipe.model_maker.python.vision.object_detector import hyperparameters
 | 
					 | 
				
			||||||
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
 | 
					 | 
				
			||||||
from mediapipe.model_maker.python.vision.object_detector import object_detector
 | 
					 | 
				
			||||||
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
 | 
					 | 
				
			||||||
from mediapipe.tasks.python.test import test_utils as task_test_utils
 | 
					from mediapipe.tasks.python.test import test_utils as task_test_utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -33,7 +29,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
 | 
				
			||||||
    super().setUp()
 | 
					    super().setUp()
 | 
				
			||||||
    dataset_folder = task_test_utils.get_test_data_path('coco_data')
 | 
					    dataset_folder = task_test_utils.get_test_data_path('coco_data')
 | 
				
			||||||
    cache_dir = self.create_tempdir()
 | 
					    cache_dir = self.create_tempdir()
 | 
				
			||||||
    self.data = dataset.Dataset.from_coco_folder(
 | 
					    self.data = object_detector.Dataset.from_coco_folder(
 | 
				
			||||||
        dataset_folder, cache_dir=cache_dir
 | 
					        dataset_folder, cache_dir=cache_dir
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    # Mock tempfile.gettempdir() to be unique for each test to avoid race
 | 
					    # Mock tempfile.gettempdir() to be unique for each test to avoid race
 | 
				
			||||||
| 
						 | 
					@ -48,15 +44,16 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
 | 
				
			||||||
    self.addCleanup(mock_gettempdir.stop)
 | 
					    self.addCleanup(mock_gettempdir.stop)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def test_object_detector(self):
 | 
					  def test_object_detector(self):
 | 
				
			||||||
    hparams = hyperparameters.HParams(
 | 
					    hparams = object_detector.HParams(
 | 
				
			||||||
        epochs=1,
 | 
					        epochs=1,
 | 
				
			||||||
        batch_size=2,
 | 
					        batch_size=2,
 | 
				
			||||||
        learning_rate=0.9,
 | 
					        learning_rate=0.9,
 | 
				
			||||||
        shuffle=False,
 | 
					        shuffle=False,
 | 
				
			||||||
        export_dir=self.create_tempdir(),
 | 
					        export_dir=self.create_tempdir(),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    options = object_detector_options.ObjectDetectorOptions(
 | 
					    options = object_detector.ObjectDetectorOptions(
 | 
				
			||||||
        supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams
 | 
					        supported_model=object_detector.SupportedModels.MOBILENET_V2,
 | 
				
			||||||
 | 
					        hparams=hparams,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    # Test `create``
 | 
					    # Test `create``
 | 
				
			||||||
    model = object_detector.ObjectDetector.create(
 | 
					    model = object_detector.ObjectDetector.create(
 | 
				
			||||||
| 
						 | 
					@ -79,7 +76,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
 | 
				
			||||||
    self.assertGreater(os.path.getsize(output_metadata_file), 0)
 | 
					    self.assertGreater(os.path.getsize(output_metadata_file), 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Test `quantization_aware_training`
 | 
					    # Test `quantization_aware_training`
 | 
				
			||||||
    qat_hparams = hyperparameters.QATHParams(
 | 
					    qat_hparams = object_detector.QATHParams(
 | 
				
			||||||
        learning_rate=0.9,
 | 
					        learning_rate=0.9,
 | 
				
			||||||
        batch_size=2,
 | 
					        batch_size=2,
 | 
				
			||||||
        epochs=1,
 | 
					        epochs=1,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -32,6 +32,7 @@ _TextEmbedderOptions = text_embedder.TextEmbedderOptions
 | 
				
			||||||
 | 
					
 | 
				
			||||||
_BERT_MODEL_FILE = 'mobilebert_embedding_with_metadata.tflite'
 | 
					_BERT_MODEL_FILE = 'mobilebert_embedding_with_metadata.tflite'
 | 
				
			||||||
_REGEX_MODEL_FILE = 'regex_one_embedding_with_metadata.tflite'
 | 
					_REGEX_MODEL_FILE = 'regex_one_embedding_with_metadata.tflite'
 | 
				
			||||||
 | 
					_USE_MODEL_FILE = 'universal_sentence_encoder_qa_with_metadata.tflite'
 | 
				
			||||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
 | 
					_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
 | 
				
			||||||
# Tolerance for embedding vector coordinate values.
 | 
					# Tolerance for embedding vector coordinate values.
 | 
				
			||||||
_EPSILON = 1e-4
 | 
					_EPSILON = 1e-4
 | 
				
			||||||
| 
						 | 
					@ -138,6 +139,24 @@ class TextEmbedderTest(parameterized.TestCase):
 | 
				
			||||||
          16,
 | 
					          16,
 | 
				
			||||||
          (0.549632, 0.552879),
 | 
					          (0.549632, 0.552879),
 | 
				
			||||||
      ),
 | 
					      ),
 | 
				
			||||||
 | 
					      (
 | 
				
			||||||
 | 
					          False,
 | 
				
			||||||
 | 
					          False,
 | 
				
			||||||
 | 
					          _USE_MODEL_FILE,
 | 
				
			||||||
 | 
					          ModelFileType.FILE_NAME,
 | 
				
			||||||
 | 
					          0.851961,
 | 
				
			||||||
 | 
					          100,
 | 
				
			||||||
 | 
					          (1.422951, 1.404664),
 | 
				
			||||||
 | 
					      ),
 | 
				
			||||||
 | 
					      (
 | 
				
			||||||
 | 
					          True,
 | 
				
			||||||
 | 
					          False,
 | 
				
			||||||
 | 
					          _USE_MODEL_FILE,
 | 
				
			||||||
 | 
					          ModelFileType.FILE_CONTENT,
 | 
				
			||||||
 | 
					          0.851961,
 | 
				
			||||||
 | 
					          100,
 | 
				
			||||||
 | 
					          (0.127049, 0.125416),
 | 
				
			||||||
 | 
					      ),
 | 
				
			||||||
  )
 | 
					  )
 | 
				
			||||||
  def test_embed(self, l2_normalize, quantize, model_name, model_file_type,
 | 
					  def test_embed(self, l2_normalize, quantize, model_name, model_file_type,
 | 
				
			||||||
                 expected_similarity, expected_size, expected_first_values):
 | 
					                 expected_similarity, expected_size, expected_first_values):
 | 
				
			||||||
| 
						 | 
					@ -213,6 +232,24 @@ class TextEmbedderTest(parameterized.TestCase):
 | 
				
			||||||
          16,
 | 
					          16,
 | 
				
			||||||
          (0.549632, 0.552879),
 | 
					          (0.549632, 0.552879),
 | 
				
			||||||
      ),
 | 
					      ),
 | 
				
			||||||
 | 
					      (
 | 
				
			||||||
 | 
					          False,
 | 
				
			||||||
 | 
					          False,
 | 
				
			||||||
 | 
					          _USE_MODEL_FILE,
 | 
				
			||||||
 | 
					          ModelFileType.FILE_NAME,
 | 
				
			||||||
 | 
					          0.851961,
 | 
				
			||||||
 | 
					          100,
 | 
				
			||||||
 | 
					          (1.422951, 1.404664),
 | 
				
			||||||
 | 
					      ),
 | 
				
			||||||
 | 
					      (
 | 
				
			||||||
 | 
					          True,
 | 
				
			||||||
 | 
					          False,
 | 
				
			||||||
 | 
					          _USE_MODEL_FILE,
 | 
				
			||||||
 | 
					          ModelFileType.FILE_CONTENT,
 | 
				
			||||||
 | 
					          0.851961,
 | 
				
			||||||
 | 
					          100,
 | 
				
			||||||
 | 
					          (0.127049, 0.125416),
 | 
				
			||||||
 | 
					      ),
 | 
				
			||||||
  )
 | 
					  )
 | 
				
			||||||
  def test_embed_in_context(self, l2_normalize, quantize, model_name,
 | 
					  def test_embed_in_context(self, l2_normalize, quantize, model_name,
 | 
				
			||||||
                            model_file_type, expected_similarity, expected_size,
 | 
					                            model_file_type, expected_similarity, expected_size,
 | 
				
			||||||
| 
						 | 
					@ -251,6 +288,7 @@ class TextEmbedderTest(parameterized.TestCase):
 | 
				
			||||||
  @parameterized.parameters(
 | 
					  @parameterized.parameters(
 | 
				
			||||||
      # TODO: The similarity should likely be lower
 | 
					      # TODO: The similarity should likely be lower
 | 
				
			||||||
      (_BERT_MODEL_FILE, 0.980880),
 | 
					      (_BERT_MODEL_FILE, 0.980880),
 | 
				
			||||||
 | 
					      (_USE_MODEL_FILE, 0.780334),
 | 
				
			||||||
  )
 | 
					  )
 | 
				
			||||||
  def test_embed_with_different_themes(self, model_file, expected_similarity):
 | 
					  def test_embed_with_different_themes(self, model_file, expected_similarity):
 | 
				
			||||||
    # Creates embedder.
 | 
					    # Creates embedder.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user