Internal change - cleanup
PiperOrigin-RevId: 486721059
This commit is contained in:
parent
2371051e17
commit
6b54cae34c
|
@ -29,21 +29,16 @@ from mediapipe.model_maker.python.core.tasks import custom_model
|
||||||
class Classifier(custom_model.CustomModel):
|
class Classifier(custom_model.CustomModel):
|
||||||
"""An abstract base class that represents a TensorFlow classifier."""
|
"""An abstract base class that represents a TensorFlow classifier."""
|
||||||
|
|
||||||
def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool,
|
def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool):
|
||||||
full_train: bool):
|
|
||||||
"""Initilizes a classifier with its specifications.
|
"""Initilizes a classifier with its specifications.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_spec: Specification for the model.
|
model_spec: Specification for the model.
|
||||||
label_names: A list of label names for the classes.
|
label_names: A list of label names for the classes.
|
||||||
shuffle: Whether the dataset should be shuffled.
|
shuffle: Whether the dataset should be shuffled.
|
||||||
full_train: If true, train the model end-to-end including the backbone
|
|
||||||
and the classification layers on top. Otherwise, only train the top
|
|
||||||
classification layers.
|
|
||||||
"""
|
"""
|
||||||
super(Classifier, self).__init__(model_spec, shuffle)
|
super(Classifier, self).__init__(model_spec, shuffle)
|
||||||
self._label_names = label_names
|
self._label_names = label_names
|
||||||
self._full_train = full_train
|
|
||||||
self._num_classes = len(label_names)
|
self._num_classes = len(label_names)
|
||||||
|
|
||||||
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
|
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
|
||||||
|
|
|
@ -38,10 +38,7 @@ class ClassifierTest(tf.test.TestCase):
|
||||||
super(ClassifierTest, self).setUp()
|
super(ClassifierTest, self).setUp()
|
||||||
label_names = ['cat', 'dog']
|
label_names = ['cat', 'dog']
|
||||||
self.model = MockClassifier(
|
self.model = MockClassifier(
|
||||||
model_spec=None,
|
model_spec=None, label_names=label_names, shuffle=False)
|
||||||
label_names=label_names,
|
|
||||||
shuffle=False,
|
|
||||||
full_train=False)
|
|
||||||
self.model.model = test_util.build_model(input_shape=[4], num_classes=2)
|
self.model.model = test_util.build_model(input_shape=[4], num_classes=2)
|
||||||
|
|
||||||
def _check_nonempty_file(self, filepath):
|
def _check_nonempty_file(self, filepath):
|
||||||
|
|
|
@ -44,10 +44,7 @@ class ImageClassifier(classifier.Classifier):
|
||||||
hparams: The hyperparameters for training image classifier.
|
hparams: The hyperparameters for training image classifier.
|
||||||
"""
|
"""
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_spec=model_spec,
|
model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
|
||||||
label_names=label_names,
|
|
||||||
shuffle=hparams.shuffle,
|
|
||||||
full_train=hparams.do_fine_tuning)
|
|
||||||
self._hparams = hparams
|
self._hparams = hparams
|
||||||
self._preprocess = image_preprocessing.Preprocessor(
|
self._preprocess = image_preprocessing.Preprocessor(
|
||||||
input_shape=self._model_spec.input_image_shape,
|
input_shape=self._model_spec.input_image_shape,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user