Merge branch 'master' into ios-text-cocoapods-force-load
This commit is contained in:
commit
24bd7a6b9f
|
@ -175,11 +175,7 @@ py_test(
|
||||||
data = [":testdata"],
|
data = [":testdata"],
|
||||||
tags = ["requires-net:external"],
|
tags = ["requires-net:external"],
|
||||||
deps = [
|
deps = [
|
||||||
":dataset",
|
":object_detector_import",
|
||||||
":hyperparameters",
|
|
||||||
":model_spec",
|
|
||||||
":object_detector",
|
|
||||||
":object_detector_options",
|
|
||||||
"//mediapipe/tasks/python/test:test_utils",
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -19,11 +19,7 @@ from unittest import mock as unittest_mock
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from mediapipe.model_maker.python.vision.object_detector import dataset
|
from mediapipe.model_maker.python.vision import object_detector
|
||||||
from mediapipe.model_maker.python.vision.object_detector import hyperparameters
|
|
||||||
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
|
|
||||||
from mediapipe.model_maker.python.vision.object_detector import object_detector
|
|
||||||
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
|
|
||||||
from mediapipe.tasks.python.test import test_utils as task_test_utils
|
from mediapipe.tasks.python.test import test_utils as task_test_utils
|
||||||
|
|
||||||
|
|
||||||
|
@ -33,7 +29,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
dataset_folder = task_test_utils.get_test_data_path('coco_data')
|
dataset_folder = task_test_utils.get_test_data_path('coco_data')
|
||||||
cache_dir = self.create_tempdir()
|
cache_dir = self.create_tempdir()
|
||||||
self.data = dataset.Dataset.from_coco_folder(
|
self.data = object_detector.Dataset.from_coco_folder(
|
||||||
dataset_folder, cache_dir=cache_dir
|
dataset_folder, cache_dir=cache_dir
|
||||||
)
|
)
|
||||||
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||||
|
@ -48,15 +44,16 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
self.addCleanup(mock_gettempdir.stop)
|
self.addCleanup(mock_gettempdir.stop)
|
||||||
|
|
||||||
def test_object_detector(self):
|
def test_object_detector(self):
|
||||||
hparams = hyperparameters.HParams(
|
hparams = object_detector.HParams(
|
||||||
epochs=1,
|
epochs=1,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
learning_rate=0.9,
|
learning_rate=0.9,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
export_dir=self.create_tempdir(),
|
export_dir=self.create_tempdir(),
|
||||||
)
|
)
|
||||||
options = object_detector_options.ObjectDetectorOptions(
|
options = object_detector.ObjectDetectorOptions(
|
||||||
supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams
|
supported_model=object_detector.SupportedModels.MOBILENET_V2,
|
||||||
|
hparams=hparams,
|
||||||
)
|
)
|
||||||
# Test `create``
|
# Test `create``
|
||||||
model = object_detector.ObjectDetector.create(
|
model = object_detector.ObjectDetector.create(
|
||||||
|
@ -79,7 +76,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
||||||
|
|
||||||
# Test `quantization_aware_training`
|
# Test `quantization_aware_training`
|
||||||
qat_hparams = hyperparameters.QATHParams(
|
qat_hparams = object_detector.QATHParams(
|
||||||
learning_rate=0.9,
|
learning_rate=0.9,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
epochs=1,
|
epochs=1,
|
||||||
|
|
|
@ -32,6 +32,7 @@ _TextEmbedderOptions = text_embedder.TextEmbedderOptions
|
||||||
|
|
||||||
_BERT_MODEL_FILE = 'mobilebert_embedding_with_metadata.tflite'
|
_BERT_MODEL_FILE = 'mobilebert_embedding_with_metadata.tflite'
|
||||||
_REGEX_MODEL_FILE = 'regex_one_embedding_with_metadata.tflite'
|
_REGEX_MODEL_FILE = 'regex_one_embedding_with_metadata.tflite'
|
||||||
|
_USE_MODEL_FILE = 'universal_sentence_encoder_qa_with_metadata.tflite'
|
||||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
|
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
|
||||||
# Tolerance for embedding vector coordinate values.
|
# Tolerance for embedding vector coordinate values.
|
||||||
_EPSILON = 1e-4
|
_EPSILON = 1e-4
|
||||||
|
@ -138,6 +139,24 @@ class TextEmbedderTest(parameterized.TestCase):
|
||||||
16,
|
16,
|
||||||
(0.549632, 0.552879),
|
(0.549632, 0.552879),
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
_USE_MODEL_FILE,
|
||||||
|
ModelFileType.FILE_NAME,
|
||||||
|
0.851961,
|
||||||
|
100,
|
||||||
|
(1.422951, 1.404664),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
_USE_MODEL_FILE,
|
||||||
|
ModelFileType.FILE_CONTENT,
|
||||||
|
0.851961,
|
||||||
|
100,
|
||||||
|
(0.127049, 0.125416),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
def test_embed(self, l2_normalize, quantize, model_name, model_file_type,
|
def test_embed(self, l2_normalize, quantize, model_name, model_file_type,
|
||||||
expected_similarity, expected_size, expected_first_values):
|
expected_similarity, expected_size, expected_first_values):
|
||||||
|
@ -213,6 +232,24 @@ class TextEmbedderTest(parameterized.TestCase):
|
||||||
16,
|
16,
|
||||||
(0.549632, 0.552879),
|
(0.549632, 0.552879),
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
_USE_MODEL_FILE,
|
||||||
|
ModelFileType.FILE_NAME,
|
||||||
|
0.851961,
|
||||||
|
100,
|
||||||
|
(1.422951, 1.404664),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
_USE_MODEL_FILE,
|
||||||
|
ModelFileType.FILE_CONTENT,
|
||||||
|
0.851961,
|
||||||
|
100,
|
||||||
|
(0.127049, 0.125416),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
def test_embed_in_context(self, l2_normalize, quantize, model_name,
|
def test_embed_in_context(self, l2_normalize, quantize, model_name,
|
||||||
model_file_type, expected_similarity, expected_size,
|
model_file_type, expected_similarity, expected_size,
|
||||||
|
@ -251,6 +288,7 @@ class TextEmbedderTest(parameterized.TestCase):
|
||||||
@parameterized.parameters(
|
@parameterized.parameters(
|
||||||
# TODO: The similarity should likely be lower
|
# TODO: The similarity should likely be lower
|
||||||
(_BERT_MODEL_FILE, 0.980880),
|
(_BERT_MODEL_FILE, 0.980880),
|
||||||
|
(_USE_MODEL_FILE, 0.780334),
|
||||||
)
|
)
|
||||||
def test_embed_with_different_themes(self, model_file, expected_similarity):
|
def test_embed_with_different_themes(self, model_file, expected_similarity):
|
||||||
# Creates embedder.
|
# Creates embedder.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user