Use train_data to evaluate accuracy of unit test for gesture_recognizer due to limited dataset size.
PiperOrigin-RevId: 488808942
This commit is contained in:
		
							parent
							
								
									b308c0dd5e
								
							
						
					
					
						commit
						2f77bf44e3
					
				| 
						 | 
					@ -42,8 +42,8 @@ class GestureRecognizerTest(tf.test.TestCase):
 | 
				
			||||||
  def setUp(self):
 | 
					  def setUp(self):
 | 
				
			||||||
    super().setUp()
 | 
					    super().setUp()
 | 
				
			||||||
    all_data = self._load_data()
 | 
					    all_data = self._load_data()
 | 
				
			||||||
    # Splits data, 90% data for training, 10% for testing
 | 
					    # Splits data, 90% data for training, 10% for validation
 | 
				
			||||||
    self._train_data, self._test_data = all_data.split(0.9)
 | 
					    self._train_data, self._validation_data = all_data.split(0.9)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def test_gesture_recognizer_model(self):
 | 
					  def test_gesture_recognizer_model(self):
 | 
				
			||||||
    model_options = gesture_recognizer.ModelOptions()
 | 
					    model_options = gesture_recognizer.ModelOptions()
 | 
				
			||||||
| 
						 | 
					@ -53,7 +53,7 @@ class GestureRecognizerTest(tf.test.TestCase):
 | 
				
			||||||
        model_options=model_options, hparams=hparams)
 | 
					        model_options=model_options, hparams=hparams)
 | 
				
			||||||
    model = gesture_recognizer.GestureRecognizer.create(
 | 
					    model = gesture_recognizer.GestureRecognizer.create(
 | 
				
			||||||
        train_data=self._train_data,
 | 
					        train_data=self._train_data,
 | 
				
			||||||
        validation_data=self._test_data,
 | 
					        validation_data=self._validation_data,
 | 
				
			||||||
        options=gesture_recognizer_options)
 | 
					        options=gesture_recognizer_options)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    self._test_accuracy(model)
 | 
					    self._test_accuracy(model)
 | 
				
			||||||
| 
						 | 
					@ -66,7 +66,7 @@ class GestureRecognizerTest(tf.test.TestCase):
 | 
				
			||||||
        model_options=model_options, hparams=hparams)
 | 
					        model_options=model_options, hparams=hparams)
 | 
				
			||||||
    model = gesture_recognizer.GestureRecognizer.create(
 | 
					    model = gesture_recognizer.GestureRecognizer.create(
 | 
				
			||||||
        train_data=self._train_data,
 | 
					        train_data=self._train_data,
 | 
				
			||||||
        validation_data=self._test_data,
 | 
					        validation_data=self._validation_data,
 | 
				
			||||||
        options=gesture_recognizer_options)
 | 
					        options=gesture_recognizer_options)
 | 
				
			||||||
    model.export_model()
 | 
					    model.export_model()
 | 
				
			||||||
    model_bundle_file = os.path.join(hparams.export_dir,
 | 
					    model_bundle_file = os.path.join(hparams.export_dir,
 | 
				
			||||||
| 
						 | 
					@ -94,8 +94,9 @@ class GestureRecognizerTest(tf.test.TestCase):
 | 
				
			||||||
        size=[1, model.embedding_size])
 | 
					        size=[1, model.embedding_size])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def _test_accuracy(self, model, threshold=0.5):
 | 
					  def _test_accuracy(self, model, threshold=0.5):
 | 
				
			||||||
    _, accuracy = model.evaluate(self._test_data)
 | 
					    # Test on _train_data because of our limited dataset size
 | 
				
			||||||
    tf.compat.v1.logging.info(f'accuracy: {accuracy}')
 | 
					    _, accuracy = model.evaluate(self._train_data)
 | 
				
			||||||
 | 
					    tf.compat.v1.logging.info(f'train accuracy: {accuracy}')
 | 
				
			||||||
    self.assertGreaterEqual(accuracy, threshold)
 | 
					    self.assertGreaterEqual(accuracy, threshold)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @unittest_mock.patch.object(
 | 
					  @unittest_mock.patch.object(
 | 
				
			||||||
| 
						 | 
					@ -113,7 +114,7 @@ class GestureRecognizerTest(tf.test.TestCase):
 | 
				
			||||||
    options = gesture_recognizer.GestureRecognizerOptions()
 | 
					    options = gesture_recognizer.GestureRecognizerOptions()
 | 
				
			||||||
    gesture_recognizer.GestureRecognizer.create(
 | 
					    gesture_recognizer.GestureRecognizer.create(
 | 
				
			||||||
        train_data=self._train_data,
 | 
					        train_data=self._train_data,
 | 
				
			||||||
        validation_data=self._test_data,
 | 
					        validation_data=self._validation_data,
 | 
				
			||||||
        options=options)
 | 
					        options=options)
 | 
				
			||||||
    mock_hparams.assert_called_once()
 | 
					    mock_hparams.assert_called_once()
 | 
				
			||||||
    mock_model_options.assert_called_once()
 | 
					    mock_model_options.assert_called_once()
 | 
				
			||||||
| 
						 | 
					@ -128,11 +129,11 @@ class GestureRecognizerTest(tf.test.TestCase):
 | 
				
			||||||
    with mock.patch('sys.stdout', mock_stdout):
 | 
					    with mock.patch('sys.stdout', mock_stdout):
 | 
				
			||||||
      model = gesture_recognizer.GestureRecognizer.create(
 | 
					      model = gesture_recognizer.GestureRecognizer.create(
 | 
				
			||||||
          train_data=self._train_data,
 | 
					          train_data=self._train_data,
 | 
				
			||||||
          validation_data=self._test_data,
 | 
					          validation_data=self._validation_data,
 | 
				
			||||||
          options=gesture_recognizer_options)
 | 
					          options=gesture_recognizer_options)
 | 
				
			||||||
      model = gesture_recognizer.GestureRecognizer.create(
 | 
					      model = gesture_recognizer.GestureRecognizer.create(
 | 
				
			||||||
          train_data=self._train_data,
 | 
					          train_data=self._train_data,
 | 
				
			||||||
          validation_data=self._test_data,
 | 
					          validation_data=self._validation_data,
 | 
				
			||||||
          options=gesture_recognizer_options)
 | 
					          options=gesture_recognizer_options)
 | 
				
			||||||
      self._test_accuracy(model)
 | 
					      self._test_accuracy(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user