Internal change

PiperOrigin-RevId: 482697999
This commit is contained in:
MediaPipe Team 2022-10-20 23:12:58 -07:00 committed by Copybara-Service
parent 9e2a9bb4be
commit 086fc442fd
7 changed files with 60 additions and 42 deletions

View File

@ -23,13 +23,17 @@ from mediapipe.model_maker.python.core.data import dataset as ds
class ClassificationDataset(ds.Dataset): class ClassificationDataset(ds.Dataset):
"""DataLoader for classification models.""" """DataLoader for classification models."""
def __init__(self, dataset: tf.data.Dataset, size: int, index_to_label: Any): def __init__(self, dataset: tf.data.Dataset, size: int, index_by_label: Any):
super().__init__(dataset, size) super().__init__(dataset, size)
self.index_to_label = index_to_label self._index_by_label = index_by_label
@property @property
def num_classes(self: ds._DatasetT) -> int: def num_classes(self: ds._DatasetT) -> int:
return len(self.index_to_label) return len(self._index_by_label)
@property
def index_by_label(self: ds._DatasetT) -> Any:
return self._index_by_label
def split(self: ds._DatasetT, def split(self: ds._DatasetT,
fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]: fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]:
@ -44,4 +48,4 @@ class ClassificationDataset(ds.Dataset):
Returns: Returns:
The splitted two sub datasets. The splitted two sub datasets.
""" """
return self._split(fraction, self.index_to_label) return self._split(fraction, self._index_by_label)

View File

@ -12,45 +12,59 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Tuple, TypeVar
# Dependency imports # Dependency imports
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset from mediapipe.model_maker.python.core.data import classification_dataset
_DatasetT = TypeVar(
'_DatasetT', bound='ClassificationDatasetTest.MagicClassificationDataset')
class ClassificationDataLoaderTest(tf.test.TestCase):
class ClassificationDatasetTest(tf.test.TestCase):
def test_split(self): def test_split(self):
class MagicClassificationDataLoader( class MagicClassificationDataset(
classification_dataset.ClassificationDataset): classification_dataset.ClassificationDataset):
"""A mock classification dataset class for testing purpose.
def __init__(self, dataset, size, index_to_label, value): Attributes:
super(MagicClassificationDataLoader, value: A value variable stored by the mock dataset class for testing.
self).__init__(dataset, size, index_to_label) """
def __init__(self, dataset: tf.data.Dataset, size: int,
index_by_label: Any, value: Any):
super().__init__(
dataset=dataset, size=size, index_by_label=index_by_label)
self.value = value self.value = value
def split(self, fraction): def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
return self._split(fraction, self.index_to_label, self.value) return self._split(fraction, self.index_by_label, self.value)
# Some dummy inputs. # Some dummy inputs.
magic_value = 42 magic_value = 42
num_classes = 2 num_classes = 2
index_to_label = (False, True) index_by_label = (False, True)
# Create data loader from sample data. # Create data loader from sample data.
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]) ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = MagicClassificationDataLoader(ds, len(ds), index_to_label, data = MagicClassificationDataset(
magic_value) dataset=ds,
size=len(ds),
index_by_label=index_by_label,
value=magic_value)
# Train/Test data split. # Train/Test data split.
fraction = .25 fraction = .25
train_data, test_data = data.split(fraction) train_data, test_data = data.split(fraction=fraction)
# `split` should return instances of child DataLoader. # `split` should return instances of child DataLoader.
self.assertIsInstance(train_data, MagicClassificationDataLoader) self.assertIsInstance(train_data, MagicClassificationDataset)
self.assertIsInstance(test_data, MagicClassificationDataLoader) self.assertIsInstance(test_data, MagicClassificationDataset)
# Make sure number of entries are right. # Make sure number of entries are right.
self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data)) self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data))
@ -59,7 +73,7 @@ class ClassificationDataLoaderTest(tf.test.TestCase):
# Make sure attributes propagated correctly. # Make sure attributes propagated correctly.
self.assertEqual(train_data.num_classes, num_classes) self.assertEqual(train_data.num_classes, num_classes)
self.assertEqual(test_data.index_to_label, index_to_label) self.assertEqual(test_data.index_by_label, index_by_label)
self.assertEqual(train_data.value, magic_value) self.assertEqual(train_data.value, magic_value)
self.assertEqual(test_data.value, magic_value) self.assertEqual(test_data.value, magic_value)

View File

@ -29,22 +29,22 @@ from mediapipe.model_maker.python.core.tasks import custom_model
class Classifier(custom_model.CustomModel): class Classifier(custom_model.CustomModel):
"""An abstract base class that represents a TensorFlow classifier.""" """An abstract base class that represents a TensorFlow classifier."""
def __init__(self, model_spec: Any, index_to_label: List[str], shuffle: bool, def __init__(self, model_spec: Any, index_by_label: List[str], shuffle: bool,
full_train: bool): full_train: bool):
"""Initilizes a classifier with its specifications. """Initilizes a classifier with its specifications.
Args: Args:
model_spec: Specification for the model. model_spec: Specification for the model.
index_to_label: A list that map from index to label class name. index_by_label: A list that map from index to label class name.
shuffle: Whether the dataset should be shuffled. shuffle: Whether the dataset should be shuffled.
full_train: If true, train the model end-to-end including the backbone full_train: If true, train the model end-to-end including the backbone
and the classification layers on top. Otherwise, only train the top and the classification layers on top. Otherwise, only train the top
classification layers. classification layers.
""" """
super(Classifier, self).__init__(model_spec, shuffle) super(Classifier, self).__init__(model_spec, shuffle)
self._index_to_label = index_to_label self._index_by_label = index_by_label
self._full_train = full_train self._full_train = full_train
self._num_classes = len(index_to_label) self._num_classes = len(index_by_label)
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any: def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
"""Evaluates the classifier with the provided evaluation dataset. """Evaluates the classifier with the provided evaluation dataset.
@ -74,4 +74,4 @@ class Classifier(custom_model.CustomModel):
label_filepath = os.path.join(export_dir, label_filename) label_filepath = os.path.join(export_dir, label_filename)
tf.compat.v1.logging.info('Saving labels in %s', label_filepath) tf.compat.v1.logging.info('Saving labels in %s', label_filepath)
with tf.io.gfile.GFile(label_filepath, 'w') as f: with tf.io.gfile.GFile(label_filepath, 'w') as f:
f.write('\n'.join(self._index_to_label)) f.write('\n'.join(self._index_by_label))

View File

@ -36,10 +36,10 @@ class ClassifierTest(tf.test.TestCase):
def setUp(self): def setUp(self):
super(ClassifierTest, self).setUp() super(ClassifierTest, self).setUp()
index_to_label = ['cat', 'dog'] index_by_label = ['cat', 'dog']
self.model = MockClassifier( self.model = MockClassifier(
model_spec=None, model_spec=None,
index_to_label=index_to_label, index_by_label=index_by_label,
shuffle=False, shuffle=False,
full_train=False) full_train=False)
self.model.model = test_util.build_model(input_shape=[4], num_classes=2) self.model.model = test_util.build_model(input_shape=[4], num_classes=2)

View File

@ -84,10 +84,10 @@ class Dataset(classification_dataset.ClassificationDataset):
name for name in os.listdir(data_root) name for name in os.listdir(data_root)
if os.path.isdir(os.path.join(data_root, name))) if os.path.isdir(os.path.join(data_root, name)))
all_label_size = len(label_names) all_label_size = len(label_names)
label_to_index = dict( index_by_label = dict(
(name, index) for index, name in enumerate(label_names)) (name, index) for index, name in enumerate(label_names))
all_image_labels = [ all_image_labels = [
label_to_index[os.path.basename(os.path.dirname(path))] index_by_label[os.path.basename(os.path.dirname(path))]
for path in all_image_paths for path in all_image_paths
] ]
@ -106,7 +106,7 @@ class Dataset(classification_dataset.ClassificationDataset):
'Load image with size: %d, num_label: %d, labels: %s.', all_image_size, 'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
all_label_size, ', '.join(label_names)) all_label_size, ', '.join(label_names))
return Dataset( return Dataset(
dataset=image_label_ds, size=all_image_size, index_to_label=label_names) dataset=image_label_ds, size=all_image_size, index_by_label=label_names)
@classmethod @classmethod
def load_tf_dataset( def load_tf_dataset(

View File

@ -49,27 +49,27 @@ class DatasetTest(tf.test.TestCase):
def test_split(self): def test_split(self):
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]) ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = dataset.Dataset(ds, 4, ['pos', 'neg']) data = dataset.Dataset(dataset=ds, size=4, index_by_label=['pos', 'neg'])
train_data, test_data = data.split(0.5) train_data, test_data = data.split(fraction=0.5)
self.assertLen(train_data, 2) self.assertLen(train_data, 2)
for i, elem in enumerate(train_data._dataset): for i, elem in enumerate(train_data._dataset):
self.assertTrue((elem.numpy() == np.array([i, 1])).all()) self.assertTrue((elem.numpy() == np.array([i, 1])).all())
self.assertEqual(train_data.num_classes, 2) self.assertEqual(train_data.num_classes, 2)
self.assertEqual(train_data.index_to_label, ['pos', 'neg']) self.assertEqual(train_data.index_by_label, ['pos', 'neg'])
self.assertLen(test_data, 2) self.assertLen(test_data, 2)
for i, elem in enumerate(test_data._dataset): for i, elem in enumerate(test_data._dataset):
self.assertTrue((elem.numpy() == np.array([i, 0])).all()) self.assertTrue((elem.numpy() == np.array([i, 0])).all())
self.assertEqual(test_data.num_classes, 2) self.assertEqual(test_data.num_classes, 2)
self.assertEqual(test_data.index_to_label, ['pos', 'neg']) self.assertEqual(test_data.index_by_label, ['pos', 'neg'])
def test_from_folder(self): def test_from_folder(self):
data = dataset.Dataset.from_folder(self.image_path) data = dataset.Dataset.from_folder(dirname=self.image_path)
self.assertLen(data, 2) self.assertLen(data, 2)
self.assertEqual(data.num_classes, 2) self.assertEqual(data.num_classes, 2)
self.assertEqual(data.index_to_label, ['daisy', 'tulips']) self.assertEqual(data.index_by_label, ['daisy', 'tulips'])
for image, label in data.gen_tf_dataset(): for image, label in data.gen_tf_dataset():
self.assertTrue(label.numpy() == 1 or label.numpy() == 0) self.assertTrue(label.numpy() == 1 or label.numpy() == 0)
if label.numpy() == 0: if label.numpy() == 0:
@ -88,19 +88,19 @@ class DatasetTest(tf.test.TestCase):
self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset) self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset)
self.assertLen(train_data, 1034) self.assertLen(train_data, 1034)
self.assertEqual(train_data.num_classes, 3) self.assertEqual(train_data.num_classes, 3)
self.assertEqual(train_data.index_to_label, self.assertEqual(train_data.index_by_label,
['angular_leaf_spot', 'bean_rust', 'healthy']) ['angular_leaf_spot', 'bean_rust', 'healthy'])
self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset) self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset)
self.assertLen(validation_data, 133) self.assertLen(validation_data, 133)
self.assertEqual(validation_data.num_classes, 3) self.assertEqual(validation_data.num_classes, 3)
self.assertEqual(validation_data.index_to_label, self.assertEqual(validation_data.index_by_label,
['angular_leaf_spot', 'bean_rust', 'healthy']) ['angular_leaf_spot', 'bean_rust', 'healthy'])
self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset) self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset)
self.assertLen(test_data, 128) self.assertLen(test_data, 128)
self.assertEqual(test_data.num_classes, 3) self.assertEqual(test_data.num_classes, 3)
self.assertEqual(test_data.index_to_label, self.assertEqual(test_data.index_by_label,
['angular_leaf_spot', 'bean_rust', 'healthy']) ['angular_leaf_spot', 'bean_rust', 'healthy'])

View File

@ -31,18 +31,18 @@ from mediapipe.model_maker.python.vision.image_classifier import train_image_cla
class ImageClassifier(classifier.Classifier): class ImageClassifier(classifier.Classifier):
"""ImageClassifier for building image classification model.""" """ImageClassifier for building image classification model."""
def __init__(self, model_spec: ms.ModelSpec, index_to_label: List[Any], def __init__(self, model_spec: ms.ModelSpec, index_by_label: List[Any],
hparams: hp.HParams): hparams: hp.HParams):
"""Initializes ImageClassifier class. """Initializes ImageClassifier class.
Args: Args:
model_spec: Specification for the model. model_spec: Specification for the model.
index_to_label: A list that maps from index to label class name. index_by_label: A list that maps from index to label class name.
hparams: The hyperparameters for training image classifier. hparams: The hyperparameters for training image classifier.
""" """
super(ImageClassifier, self).__init__( super().__init__(
model_spec=model_spec, model_spec=model_spec,
index_to_label=index_to_label, index_by_label=index_by_label,
shuffle=hparams.shuffle, shuffle=hparams.shuffle,
full_train=hparams.do_fine_tuning) full_train=hparams.do_fine_tuning)
self._hparams = hparams self._hparams = hparams
@ -81,7 +81,7 @@ class ImageClassifier(classifier.Classifier):
spec = ms.SupportedModels.get(model_spec) spec = ms.SupportedModels.get(model_spec)
image_classifier = cls( image_classifier = cls(
model_spec=spec, model_spec=spec,
index_to_label=train_data.index_to_label, index_by_label=train_data.index_by_label,
hparams=hparams) hparams=hparams)
image_classifier._create_model() image_classifier._create_model()