Internal Changes
PiperOrigin-RevId: 525180095
This commit is contained in:
parent
88a10de345
commit
3e0ed2ced0
|
@ -175,11 +175,7 @@ py_test(
|
|||
data = [":testdata"],
|
||||
tags = ["requires-net:external"],
|
||||
deps = [
|
||||
":dataset",
|
||||
":hyperparameters",
|
||||
":model_spec",
|
||||
":object_detector",
|
||||
":object_detector_options",
|
||||
":object_detector_import",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -19,11 +19,7 @@ from unittest import mock as unittest_mock
|
|||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.vision.object_detector import dataset
|
||||
from mediapipe.model_maker.python.vision.object_detector import hyperparameters
|
||||
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
|
||||
from mediapipe.model_maker.python.vision.object_detector import object_detector
|
||||
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
|
||||
from mediapipe.model_maker.python.vision import object_detector
|
||||
from mediapipe.tasks.python.test import test_utils as task_test_utils
|
||||
|
||||
|
||||
|
@ -33,7 +29,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
|||
super().setUp()
|
||||
dataset_folder = task_test_utils.get_test_data_path('coco_data')
|
||||
cache_dir = self.create_tempdir()
|
||||
self.data = dataset.Dataset.from_coco_folder(
|
||||
self.data = object_detector.Dataset.from_coco_folder(
|
||||
dataset_folder, cache_dir=cache_dir
|
||||
)
|
||||
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||
|
@ -48,15 +44,16 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.addCleanup(mock_gettempdir.stop)
|
||||
|
||||
def test_object_detector(self):
|
||||
hparams = hyperparameters.HParams(
|
||||
hparams = object_detector.HParams(
|
||||
epochs=1,
|
||||
batch_size=2,
|
||||
learning_rate=0.9,
|
||||
shuffle=False,
|
||||
export_dir=self.create_tempdir(),
|
||||
)
|
||||
options = object_detector_options.ObjectDetectorOptions(
|
||||
supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams
|
||||
options = object_detector.ObjectDetectorOptions(
|
||||
supported_model=object_detector.SupportedModels.MOBILENET_V2,
|
||||
hparams=hparams,
|
||||
)
|
||||
# Test `create``
|
||||
model = object_detector.ObjectDetector.create(
|
||||
|
@ -79,7 +76,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
||||
|
||||
# Test `quantization_aware_training`
|
||||
qat_hparams = hyperparameters.QATHParams(
|
||||
qat_hparams = object_detector.QATHParams(
|
||||
learning_rate=0.9,
|
||||
batch_size=2,
|
||||
epochs=1,
|
||||
|
|
Loading…
Reference in New Issue
Block a user