Model maker gesture recognizer test changes
PiperOrigin-RevId: 488702055
This commit is contained in:
parent
ebba119f15
commit
f14645cb06
|
@ -14,6 +14,7 @@
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
from unittest import mock as unittest_mock
|
from unittest import mock as unittest_mock
|
||||||
import zipfile
|
import zipfile
|
||||||
|
|
||||||
|
@ -40,30 +41,35 @@ class GestureRecognizerTest(tf.test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self._model_options = gesture_recognizer.ModelOptions()
|
|
||||||
self._hparams = gesture_recognizer.HParams(epochs=2)
|
|
||||||
self._gesture_recognizer_options = (
|
|
||||||
gesture_recognizer.GestureRecognizerOptions(
|
|
||||||
model_options=self._model_options, hparams=self._hparams))
|
|
||||||
all_data = self._load_data()
|
all_data = self._load_data()
|
||||||
# Splits data, 90% data for training, 10% for testing
|
# Splits data, 90% data for training, 10% for testing
|
||||||
self._train_data, self._test_data = all_data.split(0.9)
|
self._train_data, self._test_data = all_data.split(0.9)
|
||||||
|
|
||||||
def test_gesture_recognizer_model(self):
|
def test_gesture_recognizer_model(self):
|
||||||
|
model_options = gesture_recognizer.ModelOptions()
|
||||||
|
hparams = gesture_recognizer.HParams(
|
||||||
|
export_dir=tempfile.mkdtemp(), epochs=2)
|
||||||
|
gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
|
||||||
|
model_options=model_options, hparams=hparams)
|
||||||
model = gesture_recognizer.GestureRecognizer.create(
|
model = gesture_recognizer.GestureRecognizer.create(
|
||||||
train_data=self._train_data,
|
train_data=self._train_data,
|
||||||
validation_data=self._test_data,
|
validation_data=self._test_data,
|
||||||
options=self._gesture_recognizer_options)
|
options=gesture_recognizer_options)
|
||||||
|
|
||||||
self._test_accuracy(model)
|
self._test_accuracy(model)
|
||||||
|
|
||||||
def test_export_gesture_recognizer_model(self):
|
def test_export_gesture_recognizer_model(self):
|
||||||
|
model_options = gesture_recognizer.ModelOptions()
|
||||||
|
hparams = gesture_recognizer.HParams(
|
||||||
|
export_dir=tempfile.mkdtemp(), epochs=2)
|
||||||
|
gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
|
||||||
|
model_options=model_options, hparams=hparams)
|
||||||
model = gesture_recognizer.GestureRecognizer.create(
|
model = gesture_recognizer.GestureRecognizer.create(
|
||||||
train_data=self._train_data,
|
train_data=self._train_data,
|
||||||
validation_data=self._test_data,
|
validation_data=self._test_data,
|
||||||
options=self._gesture_recognizer_options)
|
options=gesture_recognizer_options)
|
||||||
model.export_model()
|
model.export_model()
|
||||||
model_bundle_file = os.path.join(self._hparams.export_dir,
|
model_bundle_file = os.path.join(hparams.export_dir,
|
||||||
'gesture_recognizer.task')
|
'gesture_recognizer.task')
|
||||||
with zipfile.ZipFile(model_bundle_file) as zf:
|
with zipfile.ZipFile(model_bundle_file) as zf:
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
@ -102,7 +108,7 @@ class GestureRecognizerTest(tf.test.TestCase):
|
||||||
'GestureRecognizerModelOptions',
|
'GestureRecognizerModelOptions',
|
||||||
autospec=True,
|
autospec=True,
|
||||||
return_value=gesture_recognizer.ModelOptions())
|
return_value=gesture_recognizer.ModelOptions())
|
||||||
def test_create_hparams_and_model_options_if_none_in_image_classifier_options(
|
def test_create_hparams_and_model_options_if_none_in_gesture_recognizer_options(
|
||||||
self, mock_hparams, mock_model_options):
|
self, mock_hparams, mock_model_options):
|
||||||
options = gesture_recognizer.GestureRecognizerOptions()
|
options = gesture_recognizer.GestureRecognizerOptions()
|
||||||
gesture_recognizer.GestureRecognizer.create(
|
gesture_recognizer.GestureRecognizer.create(
|
||||||
|
@ -113,16 +119,21 @@ class GestureRecognizerTest(tf.test.TestCase):
|
||||||
mock_model_options.assert_called_once()
|
mock_model_options.assert_called_once()
|
||||||
|
|
||||||
def test_continual_training_by_loading_checkpoint(self):
|
def test_continual_training_by_loading_checkpoint(self):
|
||||||
|
model_options = gesture_recognizer.ModelOptions()
|
||||||
|
hparams = gesture_recognizer.HParams(
|
||||||
|
export_dir=tempfile.mkdtemp(), epochs=2)
|
||||||
|
gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
|
||||||
|
model_options=model_options, hparams=hparams)
|
||||||
mock_stdout = io.StringIO()
|
mock_stdout = io.StringIO()
|
||||||
with mock.patch('sys.stdout', mock_stdout):
|
with mock.patch('sys.stdout', mock_stdout):
|
||||||
model = gesture_recognizer.GestureRecognizer.create(
|
model = gesture_recognizer.GestureRecognizer.create(
|
||||||
train_data=self._train_data,
|
train_data=self._train_data,
|
||||||
validation_data=self._test_data,
|
validation_data=self._test_data,
|
||||||
options=self._gesture_recognizer_options)
|
options=gesture_recognizer_options)
|
||||||
model = gesture_recognizer.GestureRecognizer.create(
|
model = gesture_recognizer.GestureRecognizer.create(
|
||||||
train_data=self._train_data,
|
train_data=self._train_data,
|
||||||
validation_data=self._test_data,
|
validation_data=self._test_data,
|
||||||
options=self._gesture_recognizer_options)
|
options=gesture_recognizer_options)
|
||||||
self._test_accuracy(model)
|
self._test_accuracy(model)
|
||||||
|
|
||||||
self.assertRegex(mock_stdout.getvalue(), 'Resuming from')
|
self.assertRegex(mock_stdout.getvalue(), 'Resuming from')
|
||||||
|
|
Loading…
Reference in New Issue
Block a user