From 0067a1b5c233895ced9579efdf3ed02c5e2fb338 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 5 Apr 2023 22:35:14 -0700 Subject: [PATCH] Internal changes PiperOrigin-RevId: 522248624 --- .../python/vision/object_detector/object_detector_test.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py b/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py index 3feb75f2e..df6b58a07 100644 --- a/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py +++ b/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py @@ -14,7 +14,6 @@ import os import tempfile -import unittest # pylint:disable=unused-import from unittest import mock as unittest_mock from absl.testing import parameterized @@ -28,7 +27,6 @@ from mediapipe.model_maker.python.vision.object_detector import object_detector_ from mediapipe.tasks.python.test import test_utils as task_test_utils -@unittest.skip('b/275624089') class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): def setUp(self): @@ -51,7 +49,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): def test_object_detector(self): hparams = hyperparameters.HParams( - epochs=10, + epochs=1, batch_size=2, learning_rate=0.9, shuffle=False, @@ -75,7 +73,6 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): output_tflite_file = os.path.join( options.hparams.export_dir, 'model.tflite' ) - print('ASDF float', os.path.getsize(output_tflite_file)) self.assertTrue(os.path.exists(output_tflite_file)) self.assertGreater(os.path.getsize(output_tflite_file), 0) self.assertTrue(os.path.exists(output_metadata_file)) @@ -85,7 +82,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): qat_hparams = hyperparameters.QATHParams( learning_rate=0.9, batch_size=2, - epochs=5, + epochs=1, decay_steps=6, decay_rate=0.96, ) @@ -101,7 +98,6 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): output_tflite_file = os.path.join( options.hparams.export_dir, 'model_qat.tflite' ) - print('ASDF qat', os.path.getsize(output_tflite_file)) self.assertTrue(os.path.exists(output_tflite_file)) self.assertGreater(os.path.getsize(output_tflite_file), 0) self.assertLess(os.path.getsize(output_tflite_file), 3500000)