Use train_data to evaluate accuracy of unit test for gesture_recognizer due to limited dataset size.

PiperOrigin-RevId: 488808942
This commit is contained in:
MediaPipe Team 2022-11-15 18:08:31 -08:00 committed by Copybara-Service
parent b308c0dd5e
commit 2f77bf44e3

View File

@ -42,8 +42,8 @@ class GestureRecognizerTest(tf.test.TestCase):
def setUp(self):
super().setUp()
all_data = self._load_data()
# Splits data, 90% data for training, 10% for testing
self._train_data, self._test_data = all_data.split(0.9)
# Splits data, 90% data for training, 10% for validation
self._train_data, self._validation_data = all_data.split(0.9)
def test_gesture_recognizer_model(self):
model_options = gesture_recognizer.ModelOptions()
@ -53,7 +53,7 @@ class GestureRecognizerTest(tf.test.TestCase):
model_options=model_options, hparams=hparams)
model = gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data,
validation_data=self._test_data,
validation_data=self._validation_data,
options=gesture_recognizer_options)
self._test_accuracy(model)
@ -66,7 +66,7 @@ class GestureRecognizerTest(tf.test.TestCase):
model_options=model_options, hparams=hparams)
model = gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data,
validation_data=self._test_data,
validation_data=self._validation_data,
options=gesture_recognizer_options)
model.export_model()
model_bundle_file = os.path.join(hparams.export_dir,
@ -94,8 +94,9 @@ class GestureRecognizerTest(tf.test.TestCase):
size=[1, model.embedding_size])
def _test_accuracy(self, model, threshold=0.5):
_, accuracy = model.evaluate(self._test_data)
tf.compat.v1.logging.info(f'accuracy: {accuracy}')
# Test on _train_data because of our limited dataset size
_, accuracy = model.evaluate(self._train_data)
tf.compat.v1.logging.info(f'train accuracy: {accuracy}')
self.assertGreaterEqual(accuracy, threshold)
@unittest_mock.patch.object(
@ -113,7 +114,7 @@ class GestureRecognizerTest(tf.test.TestCase):
options = gesture_recognizer.GestureRecognizerOptions()
gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data,
validation_data=self._test_data,
validation_data=self._validation_data,
options=options)
mock_hparams.assert_called_once()
mock_model_options.assert_called_once()
@ -128,11 +129,11 @@ class GestureRecognizerTest(tf.test.TestCase):
with mock.patch('sys.stdout', mock_stdout):
model = gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data,
validation_data=self._test_data,
validation_data=self._validation_data,
options=gesture_recognizer_options)
model = gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data,
validation_data=self._test_data,
validation_data=self._validation_data,
options=gesture_recognizer_options)
self._test_accuracy(model)