From f14645cb06376cd1a6818a6155118ad0667d2d84 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 15 Nov 2022 10:48:41 -0800 Subject: [PATCH] Model maker gesture recognizer test changes PiperOrigin-RevId: 488702055 --- .../gesture_recognizer_test.py | 33 ++++++++++++------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index eb2b1d171..7e7a1ca30 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -14,6 +14,7 @@ import io import os +import tempfile from unittest import mock as unittest_mock import zipfile @@ -40,30 +41,35 @@ class GestureRecognizerTest(tf.test.TestCase): def setUp(self): super().setUp() - self._model_options = gesture_recognizer.ModelOptions() - self._hparams = gesture_recognizer.HParams(epochs=2) - self._gesture_recognizer_options = ( - gesture_recognizer.GestureRecognizerOptions( - model_options=self._model_options, hparams=self._hparams)) all_data = self._load_data() # Splits data, 90% data for training, 10% for testing self._train_data, self._test_data = all_data.split(0.9) def test_gesture_recognizer_model(self): + model_options = gesture_recognizer.ModelOptions() + hparams = gesture_recognizer.HParams( + export_dir=tempfile.mkdtemp(), epochs=2) + gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( + model_options=model_options, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, validation_data=self._test_data, - options=self._gesture_recognizer_options) + options=gesture_recognizer_options) self._test_accuracy(model) def test_export_gesture_recognizer_model(self): + model_options = gesture_recognizer.ModelOptions() + hparams = gesture_recognizer.HParams( + export_dir=tempfile.mkdtemp(), epochs=2) + gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( + model_options=model_options, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, validation_data=self._test_data, - options=self._gesture_recognizer_options) + options=gesture_recognizer_options) model.export_model() - model_bundle_file = os.path.join(self._hparams.export_dir, + model_bundle_file = os.path.join(hparams.export_dir, 'gesture_recognizer.task') with zipfile.ZipFile(model_bundle_file) as zf: self.assertEqual( @@ -102,7 +108,7 @@ class GestureRecognizerTest(tf.test.TestCase): 'GestureRecognizerModelOptions', autospec=True, return_value=gesture_recognizer.ModelOptions()) - def test_create_hparams_and_model_options_if_none_in_image_classifier_options( + def test_create_hparams_and_model_options_if_none_in_gesture_recognizer_options( self, mock_hparams, mock_model_options): options = gesture_recognizer.GestureRecognizerOptions() gesture_recognizer.GestureRecognizer.create( @@ -113,16 +119,21 @@ class GestureRecognizerTest(tf.test.TestCase): mock_model_options.assert_called_once() def test_continual_training_by_loading_checkpoint(self): + model_options = gesture_recognizer.ModelOptions() + hparams = gesture_recognizer.HParams( + export_dir=tempfile.mkdtemp(), epochs=2) + gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( + model_options=model_options, hparams=hparams) mock_stdout = io.StringIO() with mock.patch('sys.stdout', mock_stdout): model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, validation_data=self._test_data, - options=self._gesture_recognizer_options) + options=gesture_recognizer_options) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, validation_data=self._test_data, - options=self._gesture_recognizer_options) + options=gesture_recognizer_options) self._test_accuracy(model) self.assertRegex(mock_stdout.getvalue(), 'Resuming from')