Model maker gesture recognizer test changes
PiperOrigin-RevId: 488702055
This commit is contained in:
parent
ebba119f15
commit
f14645cb06
|
@ -14,6 +14,7 @@
|
|||
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
from unittest import mock as unittest_mock
|
||||
import zipfile
|
||||
|
||||
|
@ -40,30 +41,35 @@ class GestureRecognizerTest(tf.test.TestCase):
|
|||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._model_options = gesture_recognizer.ModelOptions()
|
||||
self._hparams = gesture_recognizer.HParams(epochs=2)
|
||||
self._gesture_recognizer_options = (
|
||||
gesture_recognizer.GestureRecognizerOptions(
|
||||
model_options=self._model_options, hparams=self._hparams))
|
||||
all_data = self._load_data()
|
||||
# Splits data, 90% data for training, 10% for testing
|
||||
self._train_data, self._test_data = all_data.split(0.9)
|
||||
|
||||
def test_gesture_recognizer_model(self):
|
||||
model_options = gesture_recognizer.ModelOptions()
|
||||
hparams = gesture_recognizer.HParams(
|
||||
export_dir=tempfile.mkdtemp(), epochs=2)
|
||||
gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
|
||||
model_options=model_options, hparams=hparams)
|
||||
model = gesture_recognizer.GestureRecognizer.create(
|
||||
train_data=self._train_data,
|
||||
validation_data=self._test_data,
|
||||
options=self._gesture_recognizer_options)
|
||||
options=gesture_recognizer_options)
|
||||
|
||||
self._test_accuracy(model)
|
||||
|
||||
def test_export_gesture_recognizer_model(self):
|
||||
model_options = gesture_recognizer.ModelOptions()
|
||||
hparams = gesture_recognizer.HParams(
|
||||
export_dir=tempfile.mkdtemp(), epochs=2)
|
||||
gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
|
||||
model_options=model_options, hparams=hparams)
|
||||
model = gesture_recognizer.GestureRecognizer.create(
|
||||
train_data=self._train_data,
|
||||
validation_data=self._test_data,
|
||||
options=self._gesture_recognizer_options)
|
||||
options=gesture_recognizer_options)
|
||||
model.export_model()
|
||||
model_bundle_file = os.path.join(self._hparams.export_dir,
|
||||
model_bundle_file = os.path.join(hparams.export_dir,
|
||||
'gesture_recognizer.task')
|
||||
with zipfile.ZipFile(model_bundle_file) as zf:
|
||||
self.assertEqual(
|
||||
|
@ -102,7 +108,7 @@ class GestureRecognizerTest(tf.test.TestCase):
|
|||
'GestureRecognizerModelOptions',
|
||||
autospec=True,
|
||||
return_value=gesture_recognizer.ModelOptions())
|
||||
def test_create_hparams_and_model_options_if_none_in_image_classifier_options(
|
||||
def test_create_hparams_and_model_options_if_none_in_gesture_recognizer_options(
|
||||
self, mock_hparams, mock_model_options):
|
||||
options = gesture_recognizer.GestureRecognizerOptions()
|
||||
gesture_recognizer.GestureRecognizer.create(
|
||||
|
@ -113,16 +119,21 @@ class GestureRecognizerTest(tf.test.TestCase):
|
|||
mock_model_options.assert_called_once()
|
||||
|
||||
def test_continual_training_by_loading_checkpoint(self):
|
||||
model_options = gesture_recognizer.ModelOptions()
|
||||
hparams = gesture_recognizer.HParams(
|
||||
export_dir=tempfile.mkdtemp(), epochs=2)
|
||||
gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
|
||||
model_options=model_options, hparams=hparams)
|
||||
mock_stdout = io.StringIO()
|
||||
with mock.patch('sys.stdout', mock_stdout):
|
||||
model = gesture_recognizer.GestureRecognizer.create(
|
||||
train_data=self._train_data,
|
||||
validation_data=self._test_data,
|
||||
options=self._gesture_recognizer_options)
|
||||
options=gesture_recognizer_options)
|
||||
model = gesture_recognizer.GestureRecognizer.create(
|
||||
train_data=self._train_data,
|
||||
validation_data=self._test_data,
|
||||
options=self._gesture_recognizer_options)
|
||||
options=gesture_recognizer_options)
|
||||
self._test_accuracy(model)
|
||||
|
||||
self.assertRegex(mock_stdout.getvalue(), 'Resuming from')
|
||||
|
|
Loading…
Reference in New Issue
Block a user