Merge branch 'master' into ios-text-cocoapods-force-load

This commit is contained in:
Prianka Liz Kariat 2023-04-18 23:03:56 +05:30
commit 24bd7a6b9f
3 changed files with 46 additions and 15 deletions

View File

@ -175,11 +175,7 @@ py_test(
data = [":testdata"], data = [":testdata"],
tags = ["requires-net:external"], tags = ["requires-net:external"],
deps = [ deps = [
":dataset", ":object_detector_import",
":hyperparameters",
":model_spec",
":object_detector",
":object_detector_options",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
], ],
) )

View File

@ -19,11 +19,7 @@ from unittest import mock as unittest_mock
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.vision.object_detector import dataset from mediapipe.model_maker.python.vision import object_detector
from mediapipe.model_maker.python.vision.object_detector import hyperparameters
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
from mediapipe.model_maker.python.vision.object_detector import object_detector
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
from mediapipe.tasks.python.test import test_utils as task_test_utils from mediapipe.tasks.python.test import test_utils as task_test_utils
@ -33,7 +29,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
super().setUp() super().setUp()
dataset_folder = task_test_utils.get_test_data_path('coco_data') dataset_folder = task_test_utils.get_test_data_path('coco_data')
cache_dir = self.create_tempdir() cache_dir = self.create_tempdir()
self.data = dataset.Dataset.from_coco_folder( self.data = object_detector.Dataset.from_coco_folder(
dataset_folder, cache_dir=cache_dir dataset_folder, cache_dir=cache_dir
) )
# Mock tempfile.gettempdir() to be unique for each test to avoid race # Mock tempfile.gettempdir() to be unique for each test to avoid race
@ -48,15 +44,16 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
self.addCleanup(mock_gettempdir.stop) self.addCleanup(mock_gettempdir.stop)
def test_object_detector(self): def test_object_detector(self):
hparams = hyperparameters.HParams( hparams = object_detector.HParams(
epochs=1, epochs=1,
batch_size=2, batch_size=2,
learning_rate=0.9, learning_rate=0.9,
shuffle=False, shuffle=False,
export_dir=self.create_tempdir(), export_dir=self.create_tempdir(),
) )
options = object_detector_options.ObjectDetectorOptions( options = object_detector.ObjectDetectorOptions(
supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams supported_model=object_detector.SupportedModels.MOBILENET_V2,
hparams=hparams,
) )
# Test `create`` # Test `create``
model = object_detector.ObjectDetector.create( model = object_detector.ObjectDetector.create(
@ -79,7 +76,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
self.assertGreater(os.path.getsize(output_metadata_file), 0) self.assertGreater(os.path.getsize(output_metadata_file), 0)
# Test `quantization_aware_training` # Test `quantization_aware_training`
qat_hparams = hyperparameters.QATHParams( qat_hparams = object_detector.QATHParams(
learning_rate=0.9, learning_rate=0.9,
batch_size=2, batch_size=2,
epochs=1, epochs=1,

View File

@ -32,6 +32,7 @@ _TextEmbedderOptions = text_embedder.TextEmbedderOptions
_BERT_MODEL_FILE = 'mobilebert_embedding_with_metadata.tflite' _BERT_MODEL_FILE = 'mobilebert_embedding_with_metadata.tflite'
_REGEX_MODEL_FILE = 'regex_one_embedding_with_metadata.tflite' _REGEX_MODEL_FILE = 'regex_one_embedding_with_metadata.tflite'
_USE_MODEL_FILE = 'universal_sentence_encoder_qa_with_metadata.tflite'
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text' _TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
# Tolerance for embedding vector coordinate values. # Tolerance for embedding vector coordinate values.
_EPSILON = 1e-4 _EPSILON = 1e-4
@ -138,6 +139,24 @@ class TextEmbedderTest(parameterized.TestCase):
16, 16,
(0.549632, 0.552879), (0.549632, 0.552879),
), ),
(
False,
False,
_USE_MODEL_FILE,
ModelFileType.FILE_NAME,
0.851961,
100,
(1.422951, 1.404664),
),
(
True,
False,
_USE_MODEL_FILE,
ModelFileType.FILE_CONTENT,
0.851961,
100,
(0.127049, 0.125416),
),
) )
def test_embed(self, l2_normalize, quantize, model_name, model_file_type, def test_embed(self, l2_normalize, quantize, model_name, model_file_type,
expected_similarity, expected_size, expected_first_values): expected_similarity, expected_size, expected_first_values):
@ -213,6 +232,24 @@ class TextEmbedderTest(parameterized.TestCase):
16, 16,
(0.549632, 0.552879), (0.549632, 0.552879),
), ),
(
False,
False,
_USE_MODEL_FILE,
ModelFileType.FILE_NAME,
0.851961,
100,
(1.422951, 1.404664),
),
(
True,
False,
_USE_MODEL_FILE,
ModelFileType.FILE_CONTENT,
0.851961,
100,
(0.127049, 0.125416),
),
) )
def test_embed_in_context(self, l2_normalize, quantize, model_name, def test_embed_in_context(self, l2_normalize, quantize, model_name,
model_file_type, expected_similarity, expected_size, model_file_type, expected_similarity, expected_size,
@ -251,6 +288,7 @@ class TextEmbedderTest(parameterized.TestCase):
@parameterized.parameters( @parameterized.parameters(
# TODO: The similarity should likely be lower # TODO: The similarity should likely be lower
(_BERT_MODEL_FILE, 0.980880), (_BERT_MODEL_FILE, 0.980880),
(_USE_MODEL_FILE, 0.780334),
) )
def test_embed_with_different_themes(self, model_file, expected_similarity): def test_embed_with_different_themes(self, model_file, expected_similarity):
# Creates embedder. # Creates embedder.