Internal change - cleanup
PiperOrigin-RevId: 486721059
This commit is contained in:
		
							parent
							
								
									2371051e17
								
							
						
					
					
						commit
						6b54cae34c
					
				|  | @ -29,21 +29,16 @@ from mediapipe.model_maker.python.core.tasks import custom_model | ||||||
| class Classifier(custom_model.CustomModel): | class Classifier(custom_model.CustomModel): | ||||||
|   """An abstract base class that represents a TensorFlow classifier.""" |   """An abstract base class that represents a TensorFlow classifier.""" | ||||||
| 
 | 
 | ||||||
|   def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool, |   def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool): | ||||||
|                full_train: bool): |  | ||||||
|     """Initilizes a classifier with its specifications. |     """Initilizes a classifier with its specifications. | ||||||
| 
 | 
 | ||||||
|     Args: |     Args: | ||||||
|         model_spec: Specification for the model. |         model_spec: Specification for the model. | ||||||
|         label_names: A list of label names for the classes. |         label_names: A list of label names for the classes. | ||||||
|         shuffle: Whether the dataset should be shuffled. |         shuffle: Whether the dataset should be shuffled. | ||||||
|         full_train: If true, train the model end-to-end including the backbone |  | ||||||
|           and the classification layers on top. Otherwise, only train the top |  | ||||||
|           classification layers. |  | ||||||
|     """ |     """ | ||||||
|     super(Classifier, self).__init__(model_spec, shuffle) |     super(Classifier, self).__init__(model_spec, shuffle) | ||||||
|     self._label_names = label_names |     self._label_names = label_names | ||||||
|     self._full_train = full_train |  | ||||||
|     self._num_classes = len(label_names) |     self._num_classes = len(label_names) | ||||||
| 
 | 
 | ||||||
|   def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any: |   def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any: | ||||||
|  |  | ||||||
|  | @ -38,10 +38,7 @@ class ClassifierTest(tf.test.TestCase): | ||||||
|     super(ClassifierTest, self).setUp() |     super(ClassifierTest, self).setUp() | ||||||
|     label_names = ['cat', 'dog'] |     label_names = ['cat', 'dog'] | ||||||
|     self.model = MockClassifier( |     self.model = MockClassifier( | ||||||
|         model_spec=None, |         model_spec=None, label_names=label_names, shuffle=False) | ||||||
|         label_names=label_names, |  | ||||||
|         shuffle=False, |  | ||||||
|         full_train=False) |  | ||||||
|     self.model.model = test_util.build_model(input_shape=[4], num_classes=2) |     self.model.model = test_util.build_model(input_shape=[4], num_classes=2) | ||||||
| 
 | 
 | ||||||
|   def _check_nonempty_file(self, filepath): |   def _check_nonempty_file(self, filepath): | ||||||
|  |  | ||||||
|  | @ -44,10 +44,7 @@ class ImageClassifier(classifier.Classifier): | ||||||
|       hparams: The hyperparameters for training image classifier. |       hparams: The hyperparameters for training image classifier. | ||||||
|     """ |     """ | ||||||
|     super().__init__( |     super().__init__( | ||||||
|         model_spec=model_spec, |         model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle) | ||||||
|         label_names=label_names, |  | ||||||
|         shuffle=hparams.shuffle, |  | ||||||
|         full_train=hparams.do_fine_tuning) |  | ||||||
|     self._hparams = hparams |     self._hparams = hparams | ||||||
|     self._preprocess = image_preprocessing.Preprocessor( |     self._preprocess = image_preprocessing.Preprocessor( | ||||||
|         input_shape=self._model_spec.input_image_shape, |         input_shape=self._model_spec.input_image_shape, | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user