Use train_data to evaluate accuracy of unit test for gesture_recognizer due to limited dataset size.
PiperOrigin-RevId: 488808942
This commit is contained in:
parent
b308c0dd5e
commit
2f77bf44e3
|
@ -42,8 +42,8 @@ class GestureRecognizerTest(tf.test.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
all_data = self._load_data()
|
all_data = self._load_data()
|
||||||
# Splits data, 90% data for training, 10% for testing
|
# Splits data, 90% data for training, 10% for validation
|
||||||
self._train_data, self._test_data = all_data.split(0.9)
|
self._train_data, self._validation_data = all_data.split(0.9)
|
||||||
|
|
||||||
def test_gesture_recognizer_model(self):
|
def test_gesture_recognizer_model(self):
|
||||||
model_options = gesture_recognizer.ModelOptions()
|
model_options = gesture_recognizer.ModelOptions()
|
||||||
|
@ -53,7 +53,7 @@ class GestureRecognizerTest(tf.test.TestCase):
|
||||||
model_options=model_options, hparams=hparams)
|
model_options=model_options, hparams=hparams)
|
||||||
model = gesture_recognizer.GestureRecognizer.create(
|
model = gesture_recognizer.GestureRecognizer.create(
|
||||||
train_data=self._train_data,
|
train_data=self._train_data,
|
||||||
validation_data=self._test_data,
|
validation_data=self._validation_data,
|
||||||
options=gesture_recognizer_options)
|
options=gesture_recognizer_options)
|
||||||
|
|
||||||
self._test_accuracy(model)
|
self._test_accuracy(model)
|
||||||
|
@ -66,7 +66,7 @@ class GestureRecognizerTest(tf.test.TestCase):
|
||||||
model_options=model_options, hparams=hparams)
|
model_options=model_options, hparams=hparams)
|
||||||
model = gesture_recognizer.GestureRecognizer.create(
|
model = gesture_recognizer.GestureRecognizer.create(
|
||||||
train_data=self._train_data,
|
train_data=self._train_data,
|
||||||
validation_data=self._test_data,
|
validation_data=self._validation_data,
|
||||||
options=gesture_recognizer_options)
|
options=gesture_recognizer_options)
|
||||||
model.export_model()
|
model.export_model()
|
||||||
model_bundle_file = os.path.join(hparams.export_dir,
|
model_bundle_file = os.path.join(hparams.export_dir,
|
||||||
|
@ -94,8 +94,9 @@ class GestureRecognizerTest(tf.test.TestCase):
|
||||||
size=[1, model.embedding_size])
|
size=[1, model.embedding_size])
|
||||||
|
|
||||||
def _test_accuracy(self, model, threshold=0.5):
|
def _test_accuracy(self, model, threshold=0.5):
|
||||||
_, accuracy = model.evaluate(self._test_data)
|
# Test on _train_data because of our limited dataset size
|
||||||
tf.compat.v1.logging.info(f'accuracy: {accuracy}')
|
_, accuracy = model.evaluate(self._train_data)
|
||||||
|
tf.compat.v1.logging.info(f'train accuracy: {accuracy}')
|
||||||
self.assertGreaterEqual(accuracy, threshold)
|
self.assertGreaterEqual(accuracy, threshold)
|
||||||
|
|
||||||
@unittest_mock.patch.object(
|
@unittest_mock.patch.object(
|
||||||
|
@ -113,7 +114,7 @@ class GestureRecognizerTest(tf.test.TestCase):
|
||||||
options = gesture_recognizer.GestureRecognizerOptions()
|
options = gesture_recognizer.GestureRecognizerOptions()
|
||||||
gesture_recognizer.GestureRecognizer.create(
|
gesture_recognizer.GestureRecognizer.create(
|
||||||
train_data=self._train_data,
|
train_data=self._train_data,
|
||||||
validation_data=self._test_data,
|
validation_data=self._validation_data,
|
||||||
options=options)
|
options=options)
|
||||||
mock_hparams.assert_called_once()
|
mock_hparams.assert_called_once()
|
||||||
mock_model_options.assert_called_once()
|
mock_model_options.assert_called_once()
|
||||||
|
@ -128,11 +129,11 @@ class GestureRecognizerTest(tf.test.TestCase):
|
||||||
with mock.patch('sys.stdout', mock_stdout):
|
with mock.patch('sys.stdout', mock_stdout):
|
||||||
model = gesture_recognizer.GestureRecognizer.create(
|
model = gesture_recognizer.GestureRecognizer.create(
|
||||||
train_data=self._train_data,
|
train_data=self._train_data,
|
||||||
validation_data=self._test_data,
|
validation_data=self._validation_data,
|
||||||
options=gesture_recognizer_options)
|
options=gesture_recognizer_options)
|
||||||
model = gesture_recognizer.GestureRecognizer.create(
|
model = gesture_recognizer.GestureRecognizer.create(
|
||||||
train_data=self._train_data,
|
train_data=self._train_data,
|
||||||
validation_data=self._test_data,
|
validation_data=self._validation_data,
|
||||||
options=gesture_recognizer_options)
|
options=gesture_recognizer_options)
|
||||||
self._test_accuracy(model)
|
self._test_accuracy(model)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user