diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py index f83d4059a..5d0fbd066 100644 --- a/mediapipe/model_maker/python/core/tasks/classifier.py +++ b/mediapipe/model_maker/python/core/tasks/classifier.py @@ -29,21 +29,16 @@ from mediapipe.model_maker.python.core.tasks import custom_model class Classifier(custom_model.CustomModel): """An abstract base class that represents a TensorFlow classifier.""" - def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool, - full_train: bool): + def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool): """Initilizes a classifier with its specifications. Args: model_spec: Specification for the model. label_names: A list of label names for the classes. shuffle: Whether the dataset should be shuffled. - full_train: If true, train the model end-to-end including the backbone - and the classification layers on top. Otherwise, only train the top - classification layers. """ super(Classifier, self).__init__(model_spec, shuffle) self._label_names = label_names - self._full_train = full_train self._num_classes = len(label_names) def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any: diff --git a/mediapipe/model_maker/python/core/tasks/classifier_test.py b/mediapipe/model_maker/python/core/tasks/classifier_test.py index 52a3b97db..6bf3b7a2e 100644 --- a/mediapipe/model_maker/python/core/tasks/classifier_test.py +++ b/mediapipe/model_maker/python/core/tasks/classifier_test.py @@ -38,10 +38,7 @@ class ClassifierTest(tf.test.TestCase): super(ClassifierTest, self).setUp() label_names = ['cat', 'dog'] self.model = MockClassifier( - model_spec=None, - label_names=label_names, - shuffle=False, - full_train=False) + model_spec=None, label_names=label_names, shuffle=False) self.model.model = test_util.build_model(input_shape=[4], num_classes=2) def _check_nonempty_file(self, filepath): diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py index 61e7c7152..569138df7 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py @@ -44,10 +44,7 @@ class ImageClassifier(classifier.Classifier): hparams: The hyperparameters for training image classifier. """ super().__init__( - model_spec=model_spec, - label_names=label_names, - shuffle=hparams.shuffle, - full_train=hparams.do_fine_tuning) + model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle) self._hparams = hparams self._preprocess = image_preprocessing.Preprocessor( input_shape=self._model_spec.input_image_shape,