Model maker gesture recognizer test changes

PiperOrigin-RevId: 488702055
This commit is contained in:
MediaPipe Team 2022-11-15 10:48:41 -08:00 committed by Copybara-Service
parent ebba119f15
commit f14645cb06

View File

@ -14,6 +14,7 @@
import io import io
import os import os
import tempfile
from unittest import mock as unittest_mock from unittest import mock as unittest_mock
import zipfile import zipfile
@ -40,30 +41,35 @@ class GestureRecognizerTest(tf.test.TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self._model_options = gesture_recognizer.ModelOptions()
self._hparams = gesture_recognizer.HParams(epochs=2)
self._gesture_recognizer_options = (
gesture_recognizer.GestureRecognizerOptions(
model_options=self._model_options, hparams=self._hparams))
all_data = self._load_data() all_data = self._load_data()
# Splits data, 90% data for training, 10% for testing # Splits data, 90% data for training, 10% for testing
self._train_data, self._test_data = all_data.split(0.9) self._train_data, self._test_data = all_data.split(0.9)
def test_gesture_recognizer_model(self): def test_gesture_recognizer_model(self):
model_options = gesture_recognizer.ModelOptions()
hparams = gesture_recognizer.HParams(
export_dir=tempfile.mkdtemp(), epochs=2)
gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
model_options=model_options, hparams=hparams)
model = gesture_recognizer.GestureRecognizer.create( model = gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data, train_data=self._train_data,
validation_data=self._test_data, validation_data=self._test_data,
options=self._gesture_recognizer_options) options=gesture_recognizer_options)
self._test_accuracy(model) self._test_accuracy(model)
def test_export_gesture_recognizer_model(self): def test_export_gesture_recognizer_model(self):
model_options = gesture_recognizer.ModelOptions()
hparams = gesture_recognizer.HParams(
export_dir=tempfile.mkdtemp(), epochs=2)
gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
model_options=model_options, hparams=hparams)
model = gesture_recognizer.GestureRecognizer.create( model = gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data, train_data=self._train_data,
validation_data=self._test_data, validation_data=self._test_data,
options=self._gesture_recognizer_options) options=gesture_recognizer_options)
model.export_model() model.export_model()
model_bundle_file = os.path.join(self._hparams.export_dir, model_bundle_file = os.path.join(hparams.export_dir,
'gesture_recognizer.task') 'gesture_recognizer.task')
with zipfile.ZipFile(model_bundle_file) as zf: with zipfile.ZipFile(model_bundle_file) as zf:
self.assertEqual( self.assertEqual(
@ -102,7 +108,7 @@ class GestureRecognizerTest(tf.test.TestCase):
'GestureRecognizerModelOptions', 'GestureRecognizerModelOptions',
autospec=True, autospec=True,
return_value=gesture_recognizer.ModelOptions()) return_value=gesture_recognizer.ModelOptions())
def test_create_hparams_and_model_options_if_none_in_image_classifier_options( def test_create_hparams_and_model_options_if_none_in_gesture_recognizer_options(
self, mock_hparams, mock_model_options): self, mock_hparams, mock_model_options):
options = gesture_recognizer.GestureRecognizerOptions() options = gesture_recognizer.GestureRecognizerOptions()
gesture_recognizer.GestureRecognizer.create( gesture_recognizer.GestureRecognizer.create(
@ -113,16 +119,21 @@ class GestureRecognizerTest(tf.test.TestCase):
mock_model_options.assert_called_once() mock_model_options.assert_called_once()
def test_continual_training_by_loading_checkpoint(self): def test_continual_training_by_loading_checkpoint(self):
model_options = gesture_recognizer.ModelOptions()
hparams = gesture_recognizer.HParams(
export_dir=tempfile.mkdtemp(), epochs=2)
gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
model_options=model_options, hparams=hparams)
mock_stdout = io.StringIO() mock_stdout = io.StringIO()
with mock.patch('sys.stdout', mock_stdout): with mock.patch('sys.stdout', mock_stdout):
model = gesture_recognizer.GestureRecognizer.create( model = gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data, train_data=self._train_data,
validation_data=self._test_data, validation_data=self._test_data,
options=self._gesture_recognizer_options) options=gesture_recognizer_options)
model = gesture_recognizer.GestureRecognizer.create( model = gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data, train_data=self._train_data,
validation_data=self._test_data, validation_data=self._test_data,
options=self._gesture_recognizer_options) options=gesture_recognizer_options)
self._test_accuracy(model) self._test_accuracy(model)
self.assertRegex(mock_stdout.getvalue(), 'Resuming from') self.assertRegex(mock_stdout.getvalue(), 'Resuming from')