Internal Changes

PiperOrigin-RevId: 525180095
This commit is contained in:
MediaPipe Team 2023-04-18 10:10:58 -07:00 committed by Copybara-Service
parent 88a10de345
commit 3e0ed2ced0
2 changed files with 8 additions and 15 deletions

View File

@ -175,11 +175,7 @@ py_test(
data = [":testdata"], data = [":testdata"],
tags = ["requires-net:external"], tags = ["requires-net:external"],
deps = [ deps = [
":dataset", ":object_detector_import",
":hyperparameters",
":model_spec",
":object_detector",
":object_detector_options",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
], ],
) )

View File

@ -19,11 +19,7 @@ from unittest import mock as unittest_mock
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.vision.object_detector import dataset from mediapipe.model_maker.python.vision import object_detector
from mediapipe.model_maker.python.vision.object_detector import hyperparameters
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
from mediapipe.model_maker.python.vision.object_detector import object_detector
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
from mediapipe.tasks.python.test import test_utils as task_test_utils from mediapipe.tasks.python.test import test_utils as task_test_utils
@ -33,7 +29,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
super().setUp() super().setUp()
dataset_folder = task_test_utils.get_test_data_path('coco_data') dataset_folder = task_test_utils.get_test_data_path('coco_data')
cache_dir = self.create_tempdir() cache_dir = self.create_tempdir()
self.data = dataset.Dataset.from_coco_folder( self.data = object_detector.Dataset.from_coco_folder(
dataset_folder, cache_dir=cache_dir dataset_folder, cache_dir=cache_dir
) )
# Mock tempfile.gettempdir() to be unique for each test to avoid race # Mock tempfile.gettempdir() to be unique for each test to avoid race
@ -48,15 +44,16 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
self.addCleanup(mock_gettempdir.stop) self.addCleanup(mock_gettempdir.stop)
def test_object_detector(self): def test_object_detector(self):
hparams = hyperparameters.HParams( hparams = object_detector.HParams(
epochs=1, epochs=1,
batch_size=2, batch_size=2,
learning_rate=0.9, learning_rate=0.9,
shuffle=False, shuffle=False,
export_dir=self.create_tempdir(), export_dir=self.create_tempdir(),
) )
options = object_detector_options.ObjectDetectorOptions( options = object_detector.ObjectDetectorOptions(
supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams supported_model=object_detector.SupportedModels.MOBILENET_V2,
hparams=hparams,
) )
# Test `create`` # Test `create``
model = object_detector.ObjectDetector.create( model = object_detector.ObjectDetector.create(
@ -79,7 +76,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
self.assertGreater(os.path.getsize(output_metadata_file), 0) self.assertGreater(os.path.getsize(output_metadata_file), 0)
# Test `quantization_aware_training` # Test `quantization_aware_training`
qat_hparams = hyperparameters.QATHParams( qat_hparams = object_detector.QATHParams(
learning_rate=0.9, learning_rate=0.9,
batch_size=2, batch_size=2,
epochs=1, epochs=1,