Internal change - cleanup

PiperOrigin-RevId: 486721059
This commit is contained in:
MediaPipe Team 2022-11-07 11:40:43 -08:00 committed by Copybara-Service
parent 2371051e17
commit 6b54cae34c
3 changed files with 3 additions and 14 deletions

View File

@ -29,21 +29,16 @@ from mediapipe.model_maker.python.core.tasks import custom_model
class Classifier(custom_model.CustomModel): class Classifier(custom_model.CustomModel):
"""An abstract base class that represents a TensorFlow classifier.""" """An abstract base class that represents a TensorFlow classifier."""
def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool, def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool):
full_train: bool):
"""Initilizes a classifier with its specifications. """Initilizes a classifier with its specifications.
Args: Args:
model_spec: Specification for the model. model_spec: Specification for the model.
label_names: A list of label names for the classes. label_names: A list of label names for the classes.
shuffle: Whether the dataset should be shuffled. shuffle: Whether the dataset should be shuffled.
full_train: If true, train the model end-to-end including the backbone
and the classification layers on top. Otherwise, only train the top
classification layers.
""" """
super(Classifier, self).__init__(model_spec, shuffle) super(Classifier, self).__init__(model_spec, shuffle)
self._label_names = label_names self._label_names = label_names
self._full_train = full_train
self._num_classes = len(label_names) self._num_classes = len(label_names)
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any: def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:

View File

@ -38,10 +38,7 @@ class ClassifierTest(tf.test.TestCase):
super(ClassifierTest, self).setUp() super(ClassifierTest, self).setUp()
label_names = ['cat', 'dog'] label_names = ['cat', 'dog']
self.model = MockClassifier( self.model = MockClassifier(
model_spec=None, model_spec=None, label_names=label_names, shuffle=False)
label_names=label_names,
shuffle=False,
full_train=False)
self.model.model = test_util.build_model(input_shape=[4], num_classes=2) self.model.model = test_util.build_model(input_shape=[4], num_classes=2)
def _check_nonempty_file(self, filepath): def _check_nonempty_file(self, filepath):

View File

@ -44,10 +44,7 @@ class ImageClassifier(classifier.Classifier):
hparams: The hyperparameters for training image classifier. hparams: The hyperparameters for training image classifier.
""" """
super().__init__( super().__init__(
model_spec=model_spec, model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
label_names=label_names,
shuffle=hparams.shuffle,
full_train=hparams.do_fine_tuning)
self._hparams = hparams self._hparams = hparams
self._preprocess = image_preprocessing.Preprocessor( self._preprocess = image_preprocessing.Preprocessor(
input_shape=self._model_spec.input_image_shape, input_shape=self._model_spec.input_image_shape,