Rename index_by_label to label_names.

PiperOrigin-RevId: 484956259
This commit is contained in:
MediaPipe Team 2022-10-30 22:24:19 -07:00 committed by Copybara-Service
parent 7bcf322625
commit 459214e6a3
7 changed files with 36 additions and 41 deletions

View File

@ -13,7 +13,7 @@
# limitations under the License.
"""Common classification dataset library."""
from typing import Any, Tuple
from typing import List, Tuple
import tensorflow as tf
@ -21,19 +21,20 @@ from mediapipe.model_maker.python.core.data import dataset as ds
class ClassificationDataset(ds.Dataset):
"""DataLoader for classification models."""
"""Dataset Loader for classification models."""
def __init__(self, dataset: tf.data.Dataset, size: int, index_by_label: Any):
def __init__(self, dataset: tf.data.Dataset, size: int,
label_names: List[str]):
super().__init__(dataset, size)
self._index_by_label = index_by_label
self._label_names = label_names
@property
def num_classes(self: ds._DatasetT) -> int:
return len(self._index_by_label)
return len(self._label_names)
@property
def index_by_label(self: ds._DatasetT) -> Any:
return self._index_by_label
def label_names(self: ds._DatasetT) -> List[str]:
return self._label_names
def split(self: ds._DatasetT,
fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]:
@ -48,4 +49,4 @@ class ClassificationDataset(ds.Dataset):
Returns:
The splitted two sub datasets.
"""
return self._split(fraction, self._index_by_label)
return self._split(fraction, self._label_names)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Tuple, TypeVar
from typing import Any, List, Tuple, TypeVar
# Dependency imports
@ -37,26 +37,22 @@ class ClassificationDatasetTest(tf.test.TestCase):
"""
def __init__(self, dataset: tf.data.Dataset, size: int,
index_by_label: Any, value: Any):
super().__init__(
dataset=dataset, size=size, index_by_label=index_by_label)
label_names: List[str], value: Any):
super().__init__(dataset=dataset, size=size, label_names=label_names)
self.value = value
def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
return self._split(fraction, self.index_by_label, self.value)
return self._split(fraction, self.label_names, self.value)
# Some dummy inputs.
magic_value = 42
num_classes = 2
index_by_label = (False, True)
label_names = ['foo', 'bar']
# Create data loader from sample data.
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = MagicClassificationDataset(
dataset=ds,
size=len(ds),
index_by_label=index_by_label,
value=magic_value)
dataset=ds, size=len(ds), label_names=label_names, value=magic_value)
# Train/Test data split.
fraction = .25
@ -73,7 +69,7 @@ class ClassificationDatasetTest(tf.test.TestCase):
# Make sure attributes propagated correctly.
self.assertEqual(train_data.num_classes, num_classes)
self.assertEqual(test_data.index_by_label, index_by_label)
self.assertEqual(test_data.label_names, label_names)
self.assertEqual(train_data.value, magic_value)
self.assertEqual(test_data.value, magic_value)

View File

@ -29,22 +29,22 @@ from mediapipe.model_maker.python.core.tasks import custom_model
class Classifier(custom_model.CustomModel):
"""An abstract base class that represents a TensorFlow classifier."""
def __init__(self, model_spec: Any, index_by_label: List[str], shuffle: bool,
def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool,
full_train: bool):
"""Initilizes a classifier with its specifications.
Args:
model_spec: Specification for the model.
index_by_label: A list that map from index to label class name.
label_names: A list of label names for the classes.
shuffle: Whether the dataset should be shuffled.
full_train: If true, train the model end-to-end including the backbone
and the classification layers on top. Otherwise, only train the top
classification layers.
"""
super(Classifier, self).__init__(model_spec, shuffle)
self._index_by_label = index_by_label
self._label_names = label_names
self._full_train = full_train
self._num_classes = len(index_by_label)
self._num_classes = len(label_names)
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
"""Evaluates the classifier with the provided evaluation dataset.
@ -74,4 +74,4 @@ class Classifier(custom_model.CustomModel):
label_filepath = os.path.join(export_dir, label_filename)
tf.compat.v1.logging.info('Saving labels in %s', label_filepath)
with tf.io.gfile.GFile(label_filepath, 'w') as f:
f.write('\n'.join(self._index_by_label))
f.write('\n'.join(self._label_names))

View File

@ -36,10 +36,10 @@ class ClassifierTest(tf.test.TestCase):
def setUp(self):
super(ClassifierTest, self).setUp()
index_by_label = ['cat', 'dog']
label_names = ['cat', 'dog']
self.model = MockClassifier(
model_spec=None,
index_by_label=index_by_label,
label_names=label_names,
shuffle=False,
full_train=False)
self.model.model = test_util.build_model(input_shape=[4], num_classes=2)

View File

@ -106,4 +106,4 @@ class Dataset(classification_dataset.ClassificationDataset):
'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
all_label_size, ', '.join(label_names))
return Dataset(
dataset=image_label_ds, size=all_image_size, index_by_label=label_names)
dataset=image_label_ds, size=all_image_size, label_names=label_names)

View File

@ -49,27 +49,27 @@ class DatasetTest(tf.test.TestCase):
def test_split(self):
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = dataset.Dataset(dataset=ds, size=4, index_by_label=['pos', 'neg'])
data = dataset.Dataset(dataset=ds, size=4, label_names=['pos', 'neg'])
train_data, test_data = data.split(fraction=0.5)
self.assertLen(train_data, 2)
for i, elem in enumerate(train_data._dataset):
self.assertTrue((elem.numpy() == np.array([i, 1])).all())
self.assertEqual(train_data.num_classes, 2)
self.assertEqual(train_data.index_by_label, ['pos', 'neg'])
self.assertEqual(train_data.label_names, ['pos', 'neg'])
self.assertLen(test_data, 2)
for i, elem in enumerate(test_data._dataset):
self.assertTrue((elem.numpy() == np.array([i, 0])).all())
self.assertEqual(test_data.num_classes, 2)
self.assertEqual(test_data.index_by_label, ['pos', 'neg'])
self.assertEqual(test_data.label_names, ['pos', 'neg'])
def test_from_folder(self):
data = dataset.Dataset.from_folder(dirname=self.image_path)
self.assertLen(data, 2)
self.assertEqual(data.num_classes, 2)
self.assertEqual(data.index_by_label, ['daisy', 'tulips'])
self.assertEqual(data.label_names, ['daisy', 'tulips'])
for image, label in data.gen_tf_dataset():
self.assertTrue(label.numpy() == 1 or label.numpy() == 0)
if label.numpy() == 0:
@ -88,19 +88,19 @@ class DatasetTest(tf.test.TestCase):
self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset)
self.assertLen(train_data, 1034)
self.assertEqual(train_data.num_classes, 3)
self.assertEqual(train_data.index_by_label,
self.assertEqual(train_data.label_names,
['angular_leaf_spot', 'bean_rust', 'healthy'])
self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset)
self.assertLen(validation_data, 133)
self.assertEqual(validation_data.num_classes, 3)
self.assertEqual(validation_data.index_by_label,
self.assertEqual(validation_data.label_names,
['angular_leaf_spot', 'bean_rust', 'healthy'])
self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset)
self.assertLen(test_data, 128)
self.assertEqual(test_data.num_classes, 3)
self.assertEqual(test_data.index_by_label,
self.assertEqual(test_data.label_names,
['angular_leaf_spot', 'bean_rust', 'healthy'])

View File

@ -13,7 +13,7 @@
# limitations under the License.
"""APIs to train image classifier model."""
from typing import Any, List, Optional
from typing import List, Optional
import tensorflow as tf
import tensorflow_hub as hub
@ -31,18 +31,18 @@ from mediapipe.model_maker.python.vision.image_classifier import train_image_cla
class ImageClassifier(classifier.Classifier):
"""ImageClassifier for building image classification model."""
def __init__(self, model_spec: ms.ModelSpec, index_by_label: List[Any],
def __init__(self, model_spec: ms.ModelSpec, label_names: List[str],
hparams: hp.HParams):
"""Initializes ImageClassifier class.
Args:
model_spec: Specification for the model.
index_by_label: A list that maps from index to label class name.
label_names: A list of label names for the classes.
hparams: The hyperparameters for training image classifier.
"""
super().__init__(
model_spec=model_spec,
index_by_label=index_by_label,
label_names=label_names,
shuffle=hparams.shuffle,
full_train=hparams.do_fine_tuning)
self._hparams = hparams
@ -80,9 +80,7 @@ class ImageClassifier(classifier.Classifier):
spec = ms.SupportedModels.get(model_spec)
image_classifier = cls(
model_spec=spec,
index_by_label=train_data.index_by_label,
hparams=hparams)
model_spec=spec, label_names=train_data.label_names, hparams=hparams)
image_classifier._create_model()