Internal changes
PiperOrigin-RevId: 522248624
This commit is contained in:
parent
5a1a9269e6
commit
0067a1b5c2
|
@ -14,7 +14,6 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest # pylint:disable=unused-import
|
|
||||||
from unittest import mock as unittest_mock
|
from unittest import mock as unittest_mock
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
@ -28,7 +27,6 @@ from mediapipe.model_maker.python.vision.object_detector import object_detector_
|
||||||
from mediapipe.tasks.python.test import test_utils as task_test_utils
|
from mediapipe.tasks.python.test import test_utils as task_test_utils
|
||||||
|
|
||||||
|
|
||||||
@unittest.skip('b/275624089')
|
|
||||||
class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
@ -51,7 +49,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def test_object_detector(self):
|
def test_object_detector(self):
|
||||||
hparams = hyperparameters.HParams(
|
hparams = hyperparameters.HParams(
|
||||||
epochs=10,
|
epochs=1,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
learning_rate=0.9,
|
learning_rate=0.9,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
|
@ -75,7 +73,6 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
output_tflite_file = os.path.join(
|
output_tflite_file = os.path.join(
|
||||||
options.hparams.export_dir, 'model.tflite'
|
options.hparams.export_dir, 'model.tflite'
|
||||||
)
|
)
|
||||||
print('ASDF float', os.path.getsize(output_tflite_file))
|
|
||||||
self.assertTrue(os.path.exists(output_tflite_file))
|
self.assertTrue(os.path.exists(output_tflite_file))
|
||||||
self.assertGreater(os.path.getsize(output_tflite_file), 0)
|
self.assertGreater(os.path.getsize(output_tflite_file), 0)
|
||||||
self.assertTrue(os.path.exists(output_metadata_file))
|
self.assertTrue(os.path.exists(output_metadata_file))
|
||||||
|
@ -85,7 +82,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
qat_hparams = hyperparameters.QATHParams(
|
qat_hparams = hyperparameters.QATHParams(
|
||||||
learning_rate=0.9,
|
learning_rate=0.9,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
epochs=5,
|
epochs=1,
|
||||||
decay_steps=6,
|
decay_steps=6,
|
||||||
decay_rate=0.96,
|
decay_rate=0.96,
|
||||||
)
|
)
|
||||||
|
@ -101,7 +98,6 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
output_tflite_file = os.path.join(
|
output_tflite_file = os.path.join(
|
||||||
options.hparams.export_dir, 'model_qat.tflite'
|
options.hparams.export_dir, 'model_qat.tflite'
|
||||||
)
|
)
|
||||||
print('ASDF qat', os.path.getsize(output_tflite_file))
|
|
||||||
self.assertTrue(os.path.exists(output_tflite_file))
|
self.assertTrue(os.path.exists(output_tflite_file))
|
||||||
self.assertGreater(os.path.getsize(output_tflite_file), 0)
|
self.assertGreater(os.path.getsize(output_tflite_file), 0)
|
||||||
self.assertLess(os.path.getsize(output_tflite_file), 3500000)
|
self.assertLess(os.path.getsize(output_tflite_file), 3500000)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user