From 2f77bf44e3f3a53ff187bd9a39f9cbc413b4e413 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 15 Nov 2022 18:08:31 -0800 Subject: [PATCH] Use train_data to evaluate accuracy of unit test for gesture_recognizer due to limited dataset size. PiperOrigin-RevId: 488808942 --- .../gesture_recognizer_test.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index 7e7a1ca30..9bac22133 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -42,8 +42,8 @@ class GestureRecognizerTest(tf.test.TestCase): def setUp(self): super().setUp() all_data = self._load_data() - # Splits data, 90% data for training, 10% for testing - self._train_data, self._test_data = all_data.split(0.9) + # Splits data, 90% data for training, 10% for validation + self._train_data, self._validation_data = all_data.split(0.9) def test_gesture_recognizer_model(self): model_options = gesture_recognizer.ModelOptions() @@ -53,7 +53,7 @@ class GestureRecognizerTest(tf.test.TestCase): model_options=model_options, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, + validation_data=self._validation_data, options=gesture_recognizer_options) self._test_accuracy(model) @@ -66,7 +66,7 @@ class GestureRecognizerTest(tf.test.TestCase): model_options=model_options, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, + validation_data=self._validation_data, options=gesture_recognizer_options) model.export_model() model_bundle_file = os.path.join(hparams.export_dir, @@ -94,8 +94,9 @@ class GestureRecognizerTest(tf.test.TestCase): size=[1, model.embedding_size]) def _test_accuracy(self, model, threshold=0.5): - _, accuracy = model.evaluate(self._test_data) - tf.compat.v1.logging.info(f'accuracy: {accuracy}') + # Test on _train_data because of our limited dataset size + _, accuracy = model.evaluate(self._train_data) + tf.compat.v1.logging.info(f'train accuracy: {accuracy}') self.assertGreaterEqual(accuracy, threshold) @unittest_mock.patch.object( @@ -113,7 +114,7 @@ class GestureRecognizerTest(tf.test.TestCase): options = gesture_recognizer.GestureRecognizerOptions() gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, + validation_data=self._validation_data, options=options) mock_hparams.assert_called_once() mock_model_options.assert_called_once() @@ -128,11 +129,11 @@ class GestureRecognizerTest(tf.test.TestCase): with mock.patch('sys.stdout', mock_stdout): model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, + validation_data=self._validation_data, options=gesture_recognizer_options) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, + validation_data=self._validation_data, options=gesture_recognizer_options) self._test_accuracy(model)