Merge branch 'master' into ios-text-cocoapods-force-load

This commit is contained in:
Prianka Liz Kariat 2023-04-18 23:03:56 +05:30
commit 24bd7a6b9f
3 changed files with 46 additions and 15 deletions

View File

@ -175,11 +175,7 @@ py_test(
data = [":testdata"],
tags = ["requires-net:external"],
deps = [
":dataset",
":hyperparameters",
":model_spec",
":object_detector",
":object_detector_options",
":object_detector_import",
"//mediapipe/tasks/python/test:test_utils",
],
)

View File

@ -19,11 +19,7 @@ from unittest import mock as unittest_mock
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.vision.object_detector import dataset
from mediapipe.model_maker.python.vision.object_detector import hyperparameters
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
from mediapipe.model_maker.python.vision.object_detector import object_detector
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
from mediapipe.model_maker.python.vision import object_detector
from mediapipe.tasks.python.test import test_utils as task_test_utils
@ -33,7 +29,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
super().setUp()
dataset_folder = task_test_utils.get_test_data_path('coco_data')
cache_dir = self.create_tempdir()
self.data = dataset.Dataset.from_coco_folder(
self.data = object_detector.Dataset.from_coco_folder(
dataset_folder, cache_dir=cache_dir
)
# Mock tempfile.gettempdir() to be unique for each test to avoid race
@ -48,15 +44,16 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
self.addCleanup(mock_gettempdir.stop)
def test_object_detector(self):
hparams = hyperparameters.HParams(
hparams = object_detector.HParams(
epochs=1,
batch_size=2,
learning_rate=0.9,
shuffle=False,
export_dir=self.create_tempdir(),
)
options = object_detector_options.ObjectDetectorOptions(
supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams
options = object_detector.ObjectDetectorOptions(
supported_model=object_detector.SupportedModels.MOBILENET_V2,
hparams=hparams,
)
# Test `create``
model = object_detector.ObjectDetector.create(
@ -79,7 +76,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
self.assertGreater(os.path.getsize(output_metadata_file), 0)
# Test `quantization_aware_training`
qat_hparams = hyperparameters.QATHParams(
qat_hparams = object_detector.QATHParams(
learning_rate=0.9,
batch_size=2,
epochs=1,

View File

@ -32,6 +32,7 @@ _TextEmbedderOptions = text_embedder.TextEmbedderOptions
_BERT_MODEL_FILE = 'mobilebert_embedding_with_metadata.tflite'
_REGEX_MODEL_FILE = 'regex_one_embedding_with_metadata.tflite'
_USE_MODEL_FILE = 'universal_sentence_encoder_qa_with_metadata.tflite'
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
# Tolerance for embedding vector coordinate values.
_EPSILON = 1e-4
@ -138,6 +139,24 @@ class TextEmbedderTest(parameterized.TestCase):
16,
(0.549632, 0.552879),
),
(
False,
False,
_USE_MODEL_FILE,
ModelFileType.FILE_NAME,
0.851961,
100,
(1.422951, 1.404664),
),
(
True,
False,
_USE_MODEL_FILE,
ModelFileType.FILE_CONTENT,
0.851961,
100,
(0.127049, 0.125416),
),
)
def test_embed(self, l2_normalize, quantize, model_name, model_file_type,
expected_similarity, expected_size, expected_first_values):
@ -213,6 +232,24 @@ class TextEmbedderTest(parameterized.TestCase):
16,
(0.549632, 0.552879),
),
(
False,
False,
_USE_MODEL_FILE,
ModelFileType.FILE_NAME,
0.851961,
100,
(1.422951, 1.404664),
),
(
True,
False,
_USE_MODEL_FILE,
ModelFileType.FILE_CONTENT,
0.851961,
100,
(0.127049, 0.125416),
),
)
def test_embed_in_context(self, l2_normalize, quantize, model_name,
model_file_type, expected_similarity, expected_size,
@ -251,6 +288,7 @@ class TextEmbedderTest(parameterized.TestCase):
@parameterized.parameters(
# TODO: The similarity should likely be lower
(_BERT_MODEL_FILE, 0.980880),
(_USE_MODEL_FILE, 0.780334),
)
def test_embed_with_different_themes(self, model_file, expected_similarity):
# Creates embedder.