Internal Changes

PiperOrigin-RevId: 525180095
This commit is contained in:
MediaPipe Team 2023-04-18 10:10:58 -07:00 committed by Copybara-Service
parent 88a10de345
commit 3e0ed2ced0
2 changed files with 8 additions and 15 deletions

View File

@ -175,11 +175,7 @@ py_test(
data = [":testdata"],
tags = ["requires-net:external"],
deps = [
":dataset",
":hyperparameters",
":model_spec",
":object_detector",
":object_detector_options",
":object_detector_import",
"//mediapipe/tasks/python/test:test_utils",
],
)

View File

@ -19,11 +19,7 @@ from unittest import mock as unittest_mock
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.vision.object_detector import dataset
from mediapipe.model_maker.python.vision.object_detector import hyperparameters
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
from mediapipe.model_maker.python.vision.object_detector import object_detector
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
from mediapipe.model_maker.python.vision import object_detector
from mediapipe.tasks.python.test import test_utils as task_test_utils
@ -33,7 +29,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
super().setUp()
dataset_folder = task_test_utils.get_test_data_path('coco_data')
cache_dir = self.create_tempdir()
self.data = dataset.Dataset.from_coco_folder(
self.data = object_detector.Dataset.from_coco_folder(
dataset_folder, cache_dir=cache_dir
)
# Mock tempfile.gettempdir() to be unique for each test to avoid race
@ -48,15 +44,16 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
self.addCleanup(mock_gettempdir.stop)
def test_object_detector(self):
hparams = hyperparameters.HParams(
hparams = object_detector.HParams(
epochs=1,
batch_size=2,
learning_rate=0.9,
shuffle=False,
export_dir=self.create_tempdir(),
)
options = object_detector_options.ObjectDetectorOptions(
supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams
options = object_detector.ObjectDetectorOptions(
supported_model=object_detector.SupportedModels.MOBILENET_V2,
hparams=hparams,
)
# Test `create``
model = object_detector.ObjectDetector.create(
@ -79,7 +76,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
self.assertGreater(os.path.getsize(output_metadata_file), 0)
# Test `quantization_aware_training`
qat_hparams = hyperparameters.QATHParams(
qat_hparams = object_detector.QATHParams(
learning_rate=0.9,
batch_size=2,
epochs=1,