Rename index_by_label
to label_names
.
PiperOrigin-RevId: 484956259
This commit is contained in:
parent
7bcf322625
commit
459214e6a3
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
"""Common classification dataset library."""
|
||||
|
||||
from typing import Any, Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
@ -21,19 +21,20 @@ from mediapipe.model_maker.python.core.data import dataset as ds
|
|||
|
||||
|
||||
class ClassificationDataset(ds.Dataset):
|
||||
"""DataLoader for classification models."""
|
||||
"""Dataset Loader for classification models."""
|
||||
|
||||
def __init__(self, dataset: tf.data.Dataset, size: int, index_by_label: Any):
|
||||
def __init__(self, dataset: tf.data.Dataset, size: int,
|
||||
label_names: List[str]):
|
||||
super().__init__(dataset, size)
|
||||
self._index_by_label = index_by_label
|
||||
self._label_names = label_names
|
||||
|
||||
@property
|
||||
def num_classes(self: ds._DatasetT) -> int:
|
||||
return len(self._index_by_label)
|
||||
return len(self._label_names)
|
||||
|
||||
@property
|
||||
def index_by_label(self: ds._DatasetT) -> Any:
|
||||
return self._index_by_label
|
||||
def label_names(self: ds._DatasetT) -> List[str]:
|
||||
return self._label_names
|
||||
|
||||
def split(self: ds._DatasetT,
|
||||
fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]:
|
||||
|
@ -48,4 +49,4 @@ class ClassificationDataset(ds.Dataset):
|
|||
Returns:
|
||||
The splitted two sub datasets.
|
||||
"""
|
||||
return self._split(fraction, self._index_by_label)
|
||||
return self._split(fraction, self._label_names)
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Tuple, TypeVar
|
||||
from typing import Any, List, Tuple, TypeVar
|
||||
|
||||
# Dependency imports
|
||||
|
||||
|
@ -37,26 +37,22 @@ class ClassificationDatasetTest(tf.test.TestCase):
|
|||
"""
|
||||
|
||||
def __init__(self, dataset: tf.data.Dataset, size: int,
|
||||
index_by_label: Any, value: Any):
|
||||
super().__init__(
|
||||
dataset=dataset, size=size, index_by_label=index_by_label)
|
||||
label_names: List[str], value: Any):
|
||||
super().__init__(dataset=dataset, size=size, label_names=label_names)
|
||||
self.value = value
|
||||
|
||||
def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
|
||||
return self._split(fraction, self.index_by_label, self.value)
|
||||
return self._split(fraction, self.label_names, self.value)
|
||||
|
||||
# Some dummy inputs.
|
||||
magic_value = 42
|
||||
num_classes = 2
|
||||
index_by_label = (False, True)
|
||||
label_names = ['foo', 'bar']
|
||||
|
||||
# Create data loader from sample data.
|
||||
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
|
||||
data = MagicClassificationDataset(
|
||||
dataset=ds,
|
||||
size=len(ds),
|
||||
index_by_label=index_by_label,
|
||||
value=magic_value)
|
||||
dataset=ds, size=len(ds), label_names=label_names, value=magic_value)
|
||||
|
||||
# Train/Test data split.
|
||||
fraction = .25
|
||||
|
@ -73,7 +69,7 @@ class ClassificationDatasetTest(tf.test.TestCase):
|
|||
|
||||
# Make sure attributes propagated correctly.
|
||||
self.assertEqual(train_data.num_classes, num_classes)
|
||||
self.assertEqual(test_data.index_by_label, index_by_label)
|
||||
self.assertEqual(test_data.label_names, label_names)
|
||||
self.assertEqual(train_data.value, magic_value)
|
||||
self.assertEqual(test_data.value, magic_value)
|
||||
|
||||
|
|
|
@ -29,22 +29,22 @@ from mediapipe.model_maker.python.core.tasks import custom_model
|
|||
class Classifier(custom_model.CustomModel):
|
||||
"""An abstract base class that represents a TensorFlow classifier."""
|
||||
|
||||
def __init__(self, model_spec: Any, index_by_label: List[str], shuffle: bool,
|
||||
def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool,
|
||||
full_train: bool):
|
||||
"""Initilizes a classifier with its specifications.
|
||||
|
||||
Args:
|
||||
model_spec: Specification for the model.
|
||||
index_by_label: A list that map from index to label class name.
|
||||
label_names: A list of label names for the classes.
|
||||
shuffle: Whether the dataset should be shuffled.
|
||||
full_train: If true, train the model end-to-end including the backbone
|
||||
and the classification layers on top. Otherwise, only train the top
|
||||
classification layers.
|
||||
"""
|
||||
super(Classifier, self).__init__(model_spec, shuffle)
|
||||
self._index_by_label = index_by_label
|
||||
self._label_names = label_names
|
||||
self._full_train = full_train
|
||||
self._num_classes = len(index_by_label)
|
||||
self._num_classes = len(label_names)
|
||||
|
||||
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
|
||||
"""Evaluates the classifier with the provided evaluation dataset.
|
||||
|
@ -74,4 +74,4 @@ class Classifier(custom_model.CustomModel):
|
|||
label_filepath = os.path.join(export_dir, label_filename)
|
||||
tf.compat.v1.logging.info('Saving labels in %s', label_filepath)
|
||||
with tf.io.gfile.GFile(label_filepath, 'w') as f:
|
||||
f.write('\n'.join(self._index_by_label))
|
||||
f.write('\n'.join(self._label_names))
|
||||
|
|
|
@ -36,10 +36,10 @@ class ClassifierTest(tf.test.TestCase):
|
|||
|
||||
def setUp(self):
|
||||
super(ClassifierTest, self).setUp()
|
||||
index_by_label = ['cat', 'dog']
|
||||
label_names = ['cat', 'dog']
|
||||
self.model = MockClassifier(
|
||||
model_spec=None,
|
||||
index_by_label=index_by_label,
|
||||
label_names=label_names,
|
||||
shuffle=False,
|
||||
full_train=False)
|
||||
self.model.model = test_util.build_model(input_shape=[4], num_classes=2)
|
||||
|
|
|
@ -106,4 +106,4 @@ class Dataset(classification_dataset.ClassificationDataset):
|
|||
'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
|
||||
all_label_size, ', '.join(label_names))
|
||||
return Dataset(
|
||||
dataset=image_label_ds, size=all_image_size, index_by_label=label_names)
|
||||
dataset=image_label_ds, size=all_image_size, label_names=label_names)
|
||||
|
|
|
@ -49,27 +49,27 @@ class DatasetTest(tf.test.TestCase):
|
|||
|
||||
def test_split(self):
|
||||
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
|
||||
data = dataset.Dataset(dataset=ds, size=4, index_by_label=['pos', 'neg'])
|
||||
data = dataset.Dataset(dataset=ds, size=4, label_names=['pos', 'neg'])
|
||||
train_data, test_data = data.split(fraction=0.5)
|
||||
|
||||
self.assertLen(train_data, 2)
|
||||
for i, elem in enumerate(train_data._dataset):
|
||||
self.assertTrue((elem.numpy() == np.array([i, 1])).all())
|
||||
self.assertEqual(train_data.num_classes, 2)
|
||||
self.assertEqual(train_data.index_by_label, ['pos', 'neg'])
|
||||
self.assertEqual(train_data.label_names, ['pos', 'neg'])
|
||||
|
||||
self.assertLen(test_data, 2)
|
||||
for i, elem in enumerate(test_data._dataset):
|
||||
self.assertTrue((elem.numpy() == np.array([i, 0])).all())
|
||||
self.assertEqual(test_data.num_classes, 2)
|
||||
self.assertEqual(test_data.index_by_label, ['pos', 'neg'])
|
||||
self.assertEqual(test_data.label_names, ['pos', 'neg'])
|
||||
|
||||
def test_from_folder(self):
|
||||
data = dataset.Dataset.from_folder(dirname=self.image_path)
|
||||
|
||||
self.assertLen(data, 2)
|
||||
self.assertEqual(data.num_classes, 2)
|
||||
self.assertEqual(data.index_by_label, ['daisy', 'tulips'])
|
||||
self.assertEqual(data.label_names, ['daisy', 'tulips'])
|
||||
for image, label in data.gen_tf_dataset():
|
||||
self.assertTrue(label.numpy() == 1 or label.numpy() == 0)
|
||||
if label.numpy() == 0:
|
||||
|
@ -88,19 +88,19 @@ class DatasetTest(tf.test.TestCase):
|
|||
self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset)
|
||||
self.assertLen(train_data, 1034)
|
||||
self.assertEqual(train_data.num_classes, 3)
|
||||
self.assertEqual(train_data.index_by_label,
|
||||
self.assertEqual(train_data.label_names,
|
||||
['angular_leaf_spot', 'bean_rust', 'healthy'])
|
||||
|
||||
self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset)
|
||||
self.assertLen(validation_data, 133)
|
||||
self.assertEqual(validation_data.num_classes, 3)
|
||||
self.assertEqual(validation_data.index_by_label,
|
||||
self.assertEqual(validation_data.label_names,
|
||||
['angular_leaf_spot', 'bean_rust', 'healthy'])
|
||||
|
||||
self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset)
|
||||
self.assertLen(test_data, 128)
|
||||
self.assertEqual(test_data.num_classes, 3)
|
||||
self.assertEqual(test_data.index_by_label,
|
||||
self.assertEqual(test_data.label_names,
|
||||
['angular_leaf_spot', 'bean_rust', 'healthy'])
|
||||
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
"""APIs to train image classifier model."""
|
||||
|
||||
from typing import Any, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow_hub as hub
|
||||
|
@ -31,18 +31,18 @@ from mediapipe.model_maker.python.vision.image_classifier import train_image_cla
|
|||
class ImageClassifier(classifier.Classifier):
|
||||
"""ImageClassifier for building image classification model."""
|
||||
|
||||
def __init__(self, model_spec: ms.ModelSpec, index_by_label: List[Any],
|
||||
def __init__(self, model_spec: ms.ModelSpec, label_names: List[str],
|
||||
hparams: hp.HParams):
|
||||
"""Initializes ImageClassifier class.
|
||||
|
||||
Args:
|
||||
model_spec: Specification for the model.
|
||||
index_by_label: A list that maps from index to label class name.
|
||||
label_names: A list of label names for the classes.
|
||||
hparams: The hyperparameters for training image classifier.
|
||||
"""
|
||||
super().__init__(
|
||||
model_spec=model_spec,
|
||||
index_by_label=index_by_label,
|
||||
label_names=label_names,
|
||||
shuffle=hparams.shuffle,
|
||||
full_train=hparams.do_fine_tuning)
|
||||
self._hparams = hparams
|
||||
|
@ -80,9 +80,7 @@ class ImageClassifier(classifier.Classifier):
|
|||
|
||||
spec = ms.SupportedModels.get(model_spec)
|
||||
image_classifier = cls(
|
||||
model_spec=spec,
|
||||
index_by_label=train_data.index_by_label,
|
||||
hparams=hparams)
|
||||
model_spec=spec, label_names=train_data.label_names, hparams=hparams)
|
||||
|
||||
image_classifier._create_model()
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user