Internal changes
PiperOrigin-RevId: 522248624
This commit is contained in:
parent
5a1a9269e6
commit
0067a1b5c2
|
@ -14,7 +14,6 @@
|
|||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest # pylint:disable=unused-import
|
||||
from unittest import mock as unittest_mock
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
@ -28,7 +27,6 @@ from mediapipe.model_maker.python.vision.object_detector import object_detector_
|
|||
from mediapipe.tasks.python.test import test_utils as task_test_utils
|
||||
|
||||
|
||||
@unittest.skip('b/275624089')
|
||||
class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -51,7 +49,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
def test_object_detector(self):
|
||||
hparams = hyperparameters.HParams(
|
||||
epochs=10,
|
||||
epochs=1,
|
||||
batch_size=2,
|
||||
learning_rate=0.9,
|
||||
shuffle=False,
|
||||
|
@ -75,7 +73,6 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
|||
output_tflite_file = os.path.join(
|
||||
options.hparams.export_dir, 'model.tflite'
|
||||
)
|
||||
print('ASDF float', os.path.getsize(output_tflite_file))
|
||||
self.assertTrue(os.path.exists(output_tflite_file))
|
||||
self.assertGreater(os.path.getsize(output_tflite_file), 0)
|
||||
self.assertTrue(os.path.exists(output_metadata_file))
|
||||
|
@ -85,7 +82,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
|||
qat_hparams = hyperparameters.QATHParams(
|
||||
learning_rate=0.9,
|
||||
batch_size=2,
|
||||
epochs=5,
|
||||
epochs=1,
|
||||
decay_steps=6,
|
||||
decay_rate=0.96,
|
||||
)
|
||||
|
@ -101,7 +98,6 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
|||
output_tflite_file = os.path.join(
|
||||
options.hparams.export_dir, 'model_qat.tflite'
|
||||
)
|
||||
print('ASDF qat', os.path.getsize(output_tflite_file))
|
||||
self.assertTrue(os.path.exists(output_tflite_file))
|
||||
self.assertGreater(os.path.getsize(output_tflite_file), 0)
|
||||
self.assertLess(os.path.getsize(output_tflite_file), 3500000)
|
||||
|
|
Loading…
Reference in New Issue
Block a user