Use train_data to evaluate accuracy of unit test for gesture_recognizer due to limited dataset size.
PiperOrigin-RevId: 488808942
This commit is contained in:
parent
b308c0dd5e
commit
2f77bf44e3
|
@ -42,8 +42,8 @@ class GestureRecognizerTest(tf.test.TestCase):
|
|||
def setUp(self):
|
||||
super().setUp()
|
||||
all_data = self._load_data()
|
||||
# Splits data, 90% data for training, 10% for testing
|
||||
self._train_data, self._test_data = all_data.split(0.9)
|
||||
# Splits data, 90% data for training, 10% for validation
|
||||
self._train_data, self._validation_data = all_data.split(0.9)
|
||||
|
||||
def test_gesture_recognizer_model(self):
|
||||
model_options = gesture_recognizer.ModelOptions()
|
||||
|
@ -53,7 +53,7 @@ class GestureRecognizerTest(tf.test.TestCase):
|
|||
model_options=model_options, hparams=hparams)
|
||||
model = gesture_recognizer.GestureRecognizer.create(
|
||||
train_data=self._train_data,
|
||||
validation_data=self._test_data,
|
||||
validation_data=self._validation_data,
|
||||
options=gesture_recognizer_options)
|
||||
|
||||
self._test_accuracy(model)
|
||||
|
@ -66,7 +66,7 @@ class GestureRecognizerTest(tf.test.TestCase):
|
|||
model_options=model_options, hparams=hparams)
|
||||
model = gesture_recognizer.GestureRecognizer.create(
|
||||
train_data=self._train_data,
|
||||
validation_data=self._test_data,
|
||||
validation_data=self._validation_data,
|
||||
options=gesture_recognizer_options)
|
||||
model.export_model()
|
||||
model_bundle_file = os.path.join(hparams.export_dir,
|
||||
|
@ -94,8 +94,9 @@ class GestureRecognizerTest(tf.test.TestCase):
|
|||
size=[1, model.embedding_size])
|
||||
|
||||
def _test_accuracy(self, model, threshold=0.5):
|
||||
_, accuracy = model.evaluate(self._test_data)
|
||||
tf.compat.v1.logging.info(f'accuracy: {accuracy}')
|
||||
# Test on _train_data because of our limited dataset size
|
||||
_, accuracy = model.evaluate(self._train_data)
|
||||
tf.compat.v1.logging.info(f'train accuracy: {accuracy}')
|
||||
self.assertGreaterEqual(accuracy, threshold)
|
||||
|
||||
@unittest_mock.patch.object(
|
||||
|
@ -113,7 +114,7 @@ class GestureRecognizerTest(tf.test.TestCase):
|
|||
options = gesture_recognizer.GestureRecognizerOptions()
|
||||
gesture_recognizer.GestureRecognizer.create(
|
||||
train_data=self._train_data,
|
||||
validation_data=self._test_data,
|
||||
validation_data=self._validation_data,
|
||||
options=options)
|
||||
mock_hparams.assert_called_once()
|
||||
mock_model_options.assert_called_once()
|
||||
|
@ -128,11 +129,11 @@ class GestureRecognizerTest(tf.test.TestCase):
|
|||
with mock.patch('sys.stdout', mock_stdout):
|
||||
model = gesture_recognizer.GestureRecognizer.create(
|
||||
train_data=self._train_data,
|
||||
validation_data=self._test_data,
|
||||
validation_data=self._validation_data,
|
||||
options=gesture_recognizer_options)
|
||||
model = gesture_recognizer.GestureRecognizer.create(
|
||||
train_data=self._train_data,
|
||||
validation_data=self._test_data,
|
||||
validation_data=self._validation_data,
|
||||
options=gesture_recognizer_options)
|
||||
self._test_accuracy(model)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user