Rename index_by_label to label_names.

PiperOrigin-RevId: 484956259
This commit is contained in:
MediaPipe Team 2022-10-30 22:24:19 -07:00 committed by Copybara-Service
parent 7bcf322625
commit 459214e6a3
7 changed files with 36 additions and 41 deletions

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""Common classification dataset library.""" """Common classification dataset library."""
from typing import Any, Tuple from typing import List, Tuple
import tensorflow as tf import tensorflow as tf
@ -21,19 +21,20 @@ from mediapipe.model_maker.python.core.data import dataset as ds
class ClassificationDataset(ds.Dataset): class ClassificationDataset(ds.Dataset):
"""DataLoader for classification models.""" """Dataset Loader for classification models."""
def __init__(self, dataset: tf.data.Dataset, size: int, index_by_label: Any): def __init__(self, dataset: tf.data.Dataset, size: int,
label_names: List[str]):
super().__init__(dataset, size) super().__init__(dataset, size)
self._index_by_label = index_by_label self._label_names = label_names
@property @property
def num_classes(self: ds._DatasetT) -> int: def num_classes(self: ds._DatasetT) -> int:
return len(self._index_by_label) return len(self._label_names)
@property @property
def index_by_label(self: ds._DatasetT) -> Any: def label_names(self: ds._DatasetT) -> List[str]:
return self._index_by_label return self._label_names
def split(self: ds._DatasetT, def split(self: ds._DatasetT,
fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]: fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]:
@ -48,4 +49,4 @@ class ClassificationDataset(ds.Dataset):
Returns: Returns:
The two split sub-datasets. The two split sub-datasets.
""" """
return self._split(fraction, self._index_by_label) return self._split(fraction, self._label_names)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Tuple, TypeVar from typing import Any, List, Tuple, TypeVar
# Dependency imports # Dependency imports
@ -37,26 +37,22 @@ class ClassificationDatasetTest(tf.test.TestCase):
""" """
def __init__(self, dataset: tf.data.Dataset, size: int, def __init__(self, dataset: tf.data.Dataset, size: int,
index_by_label: Any, value: Any): label_names: List[str], value: Any):
super().__init__( super().__init__(dataset=dataset, size=size, label_names=label_names)
dataset=dataset, size=size, index_by_label=index_by_label)
self.value = value self.value = value
def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]: def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
return self._split(fraction, self.index_by_label, self.value) return self._split(fraction, self.label_names, self.value)
# Some dummy inputs. # Some dummy inputs.
magic_value = 42 magic_value = 42
num_classes = 2 num_classes = 2
index_by_label = (False, True) label_names = ['foo', 'bar']
# Create data loader from sample data. # Create data loader from sample data.
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]) ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = MagicClassificationDataset( data = MagicClassificationDataset(
dataset=ds, dataset=ds, size=len(ds), label_names=label_names, value=magic_value)
size=len(ds),
index_by_label=index_by_label,
value=magic_value)
# Train/Test data split. # Train/Test data split.
fraction = .25 fraction = .25
@ -73,7 +69,7 @@ class ClassificationDatasetTest(tf.test.TestCase):
# Make sure attributes propagated correctly. # Make sure attributes propagated correctly.
self.assertEqual(train_data.num_classes, num_classes) self.assertEqual(train_data.num_classes, num_classes)
self.assertEqual(test_data.index_by_label, index_by_label) self.assertEqual(test_data.label_names, label_names)
self.assertEqual(train_data.value, magic_value) self.assertEqual(train_data.value, magic_value)
self.assertEqual(test_data.value, magic_value) self.assertEqual(test_data.value, magic_value)

View File

@ -29,22 +29,22 @@ from mediapipe.model_maker.python.core.tasks import custom_model
class Classifier(custom_model.CustomModel): class Classifier(custom_model.CustomModel):
"""An abstract base class that represents a TensorFlow classifier.""" """An abstract base class that represents a TensorFlow classifier."""
def __init__(self, model_spec: Any, index_by_label: List[str], shuffle: bool, def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool,
full_train: bool): full_train: bool):
"""Initializes a classifier with its specifications. """Initializes a classifier with its specifications.
Args: Args:
model_spec: Specification for the model. model_spec: Specification for the model.
index_by_label: A list that map from index to label class name. label_names: A list of label names for the classes.
shuffle: Whether the dataset should be shuffled. shuffle: Whether the dataset should be shuffled.
full_train: If true, train the model end-to-end including the backbone full_train: If true, train the model end-to-end including the backbone
and the classification layers on top. Otherwise, only train the top and the classification layers on top. Otherwise, only train the top
classification layers. classification layers.
""" """
super(Classifier, self).__init__(model_spec, shuffle) super(Classifier, self).__init__(model_spec, shuffle)
self._index_by_label = index_by_label self._label_names = label_names
self._full_train = full_train self._full_train = full_train
self._num_classes = len(index_by_label) self._num_classes = len(label_names)
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any: def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
"""Evaluates the classifier with the provided evaluation dataset. """Evaluates the classifier with the provided evaluation dataset.
@ -74,4 +74,4 @@ class Classifier(custom_model.CustomModel):
label_filepath = os.path.join(export_dir, label_filename) label_filepath = os.path.join(export_dir, label_filename)
tf.compat.v1.logging.info('Saving labels in %s', label_filepath) tf.compat.v1.logging.info('Saving labels in %s', label_filepath)
with tf.io.gfile.GFile(label_filepath, 'w') as f: with tf.io.gfile.GFile(label_filepath, 'w') as f:
f.write('\n'.join(self._index_by_label)) f.write('\n'.join(self._label_names))

View File

@ -36,10 +36,10 @@ class ClassifierTest(tf.test.TestCase):
def setUp(self): def setUp(self):
super(ClassifierTest, self).setUp() super(ClassifierTest, self).setUp()
index_by_label = ['cat', 'dog'] label_names = ['cat', 'dog']
self.model = MockClassifier( self.model = MockClassifier(
model_spec=None, model_spec=None,
index_by_label=index_by_label, label_names=label_names,
shuffle=False, shuffle=False,
full_train=False) full_train=False)
self.model.model = test_util.build_model(input_shape=[4], num_classes=2) self.model.model = test_util.build_model(input_shape=[4], num_classes=2)

View File

@ -106,4 +106,4 @@ class Dataset(classification_dataset.ClassificationDataset):
'Load image with size: %d, num_label: %d, labels: %s.', all_image_size, 'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
all_label_size, ', '.join(label_names)) all_label_size, ', '.join(label_names))
return Dataset( return Dataset(
dataset=image_label_ds, size=all_image_size, index_by_label=label_names) dataset=image_label_ds, size=all_image_size, label_names=label_names)

View File

@ -49,27 +49,27 @@ class DatasetTest(tf.test.TestCase):
def test_split(self): def test_split(self):
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]) ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = dataset.Dataset(dataset=ds, size=4, index_by_label=['pos', 'neg']) data = dataset.Dataset(dataset=ds, size=4, label_names=['pos', 'neg'])
train_data, test_data = data.split(fraction=0.5) train_data, test_data = data.split(fraction=0.5)
self.assertLen(train_data, 2) self.assertLen(train_data, 2)
for i, elem in enumerate(train_data._dataset): for i, elem in enumerate(train_data._dataset):
self.assertTrue((elem.numpy() == np.array([i, 1])).all()) self.assertTrue((elem.numpy() == np.array([i, 1])).all())
self.assertEqual(train_data.num_classes, 2) self.assertEqual(train_data.num_classes, 2)
self.assertEqual(train_data.index_by_label, ['pos', 'neg']) self.assertEqual(train_data.label_names, ['pos', 'neg'])
self.assertLen(test_data, 2) self.assertLen(test_data, 2)
for i, elem in enumerate(test_data._dataset): for i, elem in enumerate(test_data._dataset):
self.assertTrue((elem.numpy() == np.array([i, 0])).all()) self.assertTrue((elem.numpy() == np.array([i, 0])).all())
self.assertEqual(test_data.num_classes, 2) self.assertEqual(test_data.num_classes, 2)
self.assertEqual(test_data.index_by_label, ['pos', 'neg']) self.assertEqual(test_data.label_names, ['pos', 'neg'])
def test_from_folder(self): def test_from_folder(self):
data = dataset.Dataset.from_folder(dirname=self.image_path) data = dataset.Dataset.from_folder(dirname=self.image_path)
self.assertLen(data, 2) self.assertLen(data, 2)
self.assertEqual(data.num_classes, 2) self.assertEqual(data.num_classes, 2)
self.assertEqual(data.index_by_label, ['daisy', 'tulips']) self.assertEqual(data.label_names, ['daisy', 'tulips'])
for image, label in data.gen_tf_dataset(): for image, label in data.gen_tf_dataset():
self.assertTrue(label.numpy() == 1 or label.numpy() == 0) self.assertTrue(label.numpy() == 1 or label.numpy() == 0)
if label.numpy() == 0: if label.numpy() == 0:
@ -88,19 +88,19 @@ class DatasetTest(tf.test.TestCase):
self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset) self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset)
self.assertLen(train_data, 1034) self.assertLen(train_data, 1034)
self.assertEqual(train_data.num_classes, 3) self.assertEqual(train_data.num_classes, 3)
self.assertEqual(train_data.index_by_label, self.assertEqual(train_data.label_names,
['angular_leaf_spot', 'bean_rust', 'healthy']) ['angular_leaf_spot', 'bean_rust', 'healthy'])
self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset) self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset)
self.assertLen(validation_data, 133) self.assertLen(validation_data, 133)
self.assertEqual(validation_data.num_classes, 3) self.assertEqual(validation_data.num_classes, 3)
self.assertEqual(validation_data.index_by_label, self.assertEqual(validation_data.label_names,
['angular_leaf_spot', 'bean_rust', 'healthy']) ['angular_leaf_spot', 'bean_rust', 'healthy'])
self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset) self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset)
self.assertLen(test_data, 128) self.assertLen(test_data, 128)
self.assertEqual(test_data.num_classes, 3) self.assertEqual(test_data.num_classes, 3)
self.assertEqual(test_data.index_by_label, self.assertEqual(test_data.label_names,
['angular_leaf_spot', 'bean_rust', 'healthy']) ['angular_leaf_spot', 'bean_rust', 'healthy'])

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""APIs to train image classifier model.""" """APIs to train image classifier model."""
from typing import Any, List, Optional from typing import List, Optional
import tensorflow as tf import tensorflow as tf
import tensorflow_hub as hub import tensorflow_hub as hub
@ -31,18 +31,18 @@ from mediapipe.model_maker.python.vision.image_classifier import train_image_cla
class ImageClassifier(classifier.Classifier): class ImageClassifier(classifier.Classifier):
"""ImageClassifier for building image classification model.""" """ImageClassifier for building image classification model."""
def __init__(self, model_spec: ms.ModelSpec, index_by_label: List[Any], def __init__(self, model_spec: ms.ModelSpec, label_names: List[str],
hparams: hp.HParams): hparams: hp.HParams):
"""Initializes ImageClassifier class. """Initializes ImageClassifier class.
Args: Args:
model_spec: Specification for the model. model_spec: Specification for the model.
index_by_label: A list that maps from index to label class name. label_names: A list of label names for the classes.
hparams: The hyperparameters for training image classifier. hparams: The hyperparameters for training image classifier.
""" """
super().__init__( super().__init__(
model_spec=model_spec, model_spec=model_spec,
index_by_label=index_by_label, label_names=label_names,
shuffle=hparams.shuffle, shuffle=hparams.shuffle,
full_train=hparams.do_fine_tuning) full_train=hparams.do_fine_tuning)
self._hparams = hparams self._hparams = hparams
@ -80,9 +80,7 @@ class ImageClassifier(classifier.Classifier):
spec = ms.SupportedModels.get(model_spec) spec = ms.SupportedModels.get(model_spec)
image_classifier = cls( image_classifier = cls(
model_spec=spec, model_spec=spec, label_names=train_data.label_names, hparams=hparams)
index_by_label=train_data.index_by_label,
hparams=hparams)
image_classifier._create_model() image_classifier._create_model()