Internal change - cleanup
PiperOrigin-RevId: 486721059
This commit is contained in:
parent
2371051e17
commit
6b54cae34c
|
@ -29,21 +29,16 @@ from mediapipe.model_maker.python.core.tasks import custom_model
|
|||
class Classifier(custom_model.CustomModel):
|
||||
"""An abstract base class that represents a TensorFlow classifier."""
|
||||
|
||||
def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool,
|
||||
full_train: bool):
|
||||
def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool):
|
||||
"""Initilizes a classifier with its specifications.
|
||||
|
||||
Args:
|
||||
model_spec: Specification for the model.
|
||||
label_names: A list of label names for the classes.
|
||||
shuffle: Whether the dataset should be shuffled.
|
||||
full_train: If true, train the model end-to-end including the backbone
|
||||
and the classification layers on top. Otherwise, only train the top
|
||||
classification layers.
|
||||
"""
|
||||
super(Classifier, self).__init__(model_spec, shuffle)
|
||||
self._label_names = label_names
|
||||
self._full_train = full_train
|
||||
self._num_classes = len(label_names)
|
||||
|
||||
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
|
||||
|
|
|
@ -38,10 +38,7 @@ class ClassifierTest(tf.test.TestCase):
|
|||
super(ClassifierTest, self).setUp()
|
||||
label_names = ['cat', 'dog']
|
||||
self.model = MockClassifier(
|
||||
model_spec=None,
|
||||
label_names=label_names,
|
||||
shuffle=False,
|
||||
full_train=False)
|
||||
model_spec=None, label_names=label_names, shuffle=False)
|
||||
self.model.model = test_util.build_model(input_shape=[4], num_classes=2)
|
||||
|
||||
def _check_nonempty_file(self, filepath):
|
||||
|
|
|
@ -44,10 +44,7 @@ class ImageClassifier(classifier.Classifier):
|
|||
hparams: The hyperparameters for training image classifier.
|
||||
"""
|
||||
super().__init__(
|
||||
model_spec=model_spec,
|
||||
label_names=label_names,
|
||||
shuffle=hparams.shuffle,
|
||||
full_train=hparams.do_fine_tuning)
|
||||
model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
|
||||
self._hparams = hparams
|
||||
self._preprocess = image_preprocessing.Preprocessor(
|
||||
input_shape=self._model_spec.input_image_shape,
|
||||
|
|
Loading…
Reference in New Issue
Block a user