Update gesture_recognizer test
PiperOrigin-RevId: 489301508
This commit is contained in:
parent
3ccf7308e0
commit
1fb0902aa0
|
@ -14,6 +14,7 @@
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
from unittest import mock as unittest_mock
|
from unittest import mock as unittest_mock
|
||||||
import zipfile
|
import zipfile
|
||||||
|
@ -41,6 +42,7 @@ class GestureRecognizerTest(tf.test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
random.seed(1234)
|
||||||
all_data = self._load_data()
|
all_data = self._load_data()
|
||||||
# Splits data, 90% data for training, 10% for validation
|
# Splits data, 90% data for training, 10% for validation
|
||||||
self._train_data, self._validation_data = all_data.split(0.9)
|
self._train_data, self._validation_data = all_data.split(0.9)
|
||||||
|
@ -93,11 +95,11 @@ class GestureRecognizerTest(tf.test.TestCase):
|
||||||
tflite_file=gesture_classifier_tflite_file,
|
tflite_file=gesture_classifier_tflite_file,
|
||||||
size=[1, model.embedding_size])
|
size=[1, model.embedding_size])
|
||||||
|
|
||||||
def _test_accuracy(self, model, threshold=0.25):
|
def _test_accuracy(self, model, threshold=0.0):
|
||||||
# Test on _train_data because of our limited dataset size
|
# Test on _train_data because of our limited dataset size
|
||||||
_, accuracy = model.evaluate(self._train_data)
|
_, accuracy = model.evaluate(self._train_data)
|
||||||
tf.compat.v1.logging.info(f'train accuracy: {accuracy}')
|
tf.compat.v1.logging.info(f'train accuracy: {accuracy}')
|
||||||
self.assertGreaterEqual(accuracy, threshold)
|
self.assertGreater(accuracy, threshold)
|
||||||
|
|
||||||
@unittest_mock.patch.object(
|
@unittest_mock.patch.object(
|
||||||
gesture_recognizer.hyperparameters,
|
gesture_recognizer.hyperparameters,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user