From 3e0ed2ced053faae2f078dd2d53052c2f578120f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 18 Apr 2023 10:10:58 -0700 Subject: [PATCH] Internal Changes PiperOrigin-RevId: 525180095 --- .../python/vision/object_detector/BUILD | 6 +----- .../object_detector/object_detector_test.py | 17 +++++++---------- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/mediapipe/model_maker/python/vision/object_detector/BUILD b/mediapipe/model_maker/python/vision/object_detector/BUILD index f3d4407d8..b97d215da 100644 --- a/mediapipe/model_maker/python/vision/object_detector/BUILD +++ b/mediapipe/model_maker/python/vision/object_detector/BUILD @@ -175,11 +175,7 @@ py_test( data = [":testdata"], tags = ["requires-net:external"], deps = [ - ":dataset", - ":hyperparameters", - ":model_spec", - ":object_detector", - ":object_detector_options", + ":object_detector_import", "//mediapipe/tasks/python/test:test_utils", ], ) diff --git a/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py b/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py index df6b58a07..02f773e69 100644 --- a/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py +++ b/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py @@ -19,11 +19,7 @@ from unittest import mock as unittest_mock from absl.testing import parameterized import tensorflow as tf -from mediapipe.model_maker.python.vision.object_detector import dataset -from mediapipe.model_maker.python.vision.object_detector import hyperparameters -from mediapipe.model_maker.python.vision.object_detector import model_spec as ms -from mediapipe.model_maker.python.vision.object_detector import object_detector -from mediapipe.model_maker.python.vision.object_detector import object_detector_options +from mediapipe.model_maker.python.vision import object_detector from mediapipe.tasks.python.test import test_utils as task_test_utils @@ -33,7 +29,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): super().setUp() dataset_folder = task_test_utils.get_test_data_path('coco_data') cache_dir = self.create_tempdir() - self.data = dataset.Dataset.from_coco_folder( + self.data = object_detector.Dataset.from_coco_folder( dataset_folder, cache_dir=cache_dir ) # Mock tempfile.gettempdir() to be unique for each test to avoid race @@ -48,15 +44,16 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): self.addCleanup(mock_gettempdir.stop) def test_object_detector(self): - hparams = hyperparameters.HParams( + hparams = object_detector.HParams( epochs=1, batch_size=2, learning_rate=0.9, shuffle=False, export_dir=self.create_tempdir(), ) - options = object_detector_options.ObjectDetectorOptions( - supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams + options = object_detector.ObjectDetectorOptions( + supported_model=object_detector.SupportedModels.MOBILENET_V2, + hparams=hparams, ) # Test `create`` model = object_detector.ObjectDetector.create( @@ -79,7 +76,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): self.assertGreater(os.path.getsize(output_metadata_file), 0) # Test `quantization_aware_training` - qat_hparams = hyperparameters.QATHParams( + qat_hparams = object_detector.QATHParams( learning_rate=0.9, batch_size=2, epochs=1,