Internal change
PiperOrigin-RevId: 482697999
This commit is contained in:
parent
9e2a9bb4be
commit
086fc442fd
|
@ -23,13 +23,17 @@ from mediapipe.model_maker.python.core.data import dataset as ds
|
||||||
class ClassificationDataset(ds.Dataset):
|
class ClassificationDataset(ds.Dataset):
|
||||||
"""DataLoader for classification models."""
|
"""DataLoader for classification models."""
|
||||||
|
|
||||||
def __init__(self, dataset: tf.data.Dataset, size: int, index_to_label: Any):
|
def __init__(self, dataset: tf.data.Dataset, size: int, index_by_label: Any):
|
||||||
super().__init__(dataset, size)
|
super().__init__(dataset, size)
|
||||||
self.index_to_label = index_to_label
|
self._index_by_label = index_by_label
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_classes(self: ds._DatasetT) -> int:
|
def num_classes(self: ds._DatasetT) -> int:
|
||||||
return len(self.index_to_label)
|
return len(self._index_by_label)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def index_by_label(self: ds._DatasetT) -> Any:
|
||||||
|
return self._index_by_label
|
||||||
|
|
||||||
def split(self: ds._DatasetT,
|
def split(self: ds._DatasetT,
|
||||||
fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]:
|
fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]:
|
||||||
|
@ -44,4 +48,4 @@ class ClassificationDataset(ds.Dataset):
|
||||||
Returns:
|
Returns:
|
||||||
The two split sub-datasets.
|
The two split sub-datasets.
|
||||||
"""
|
"""
|
||||||
return self._split(fraction, self.index_to_label)
|
return self._split(fraction, self._index_by_label)
|
||||||
|
|
|
@ -12,45 +12,59 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any, Tuple, TypeVar
|
||||||
|
|
||||||
# Dependency imports
|
# Dependency imports
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core.data import classification_dataset
|
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||||
|
|
||||||
|
_DatasetT = TypeVar(
|
||||||
|
'_DatasetT', bound='ClassificationDatasetTest.MagicClassificationDataset')
|
||||||
|
|
||||||
class ClassificationDataLoaderTest(tf.test.TestCase):
|
|
||||||
|
class ClassificationDatasetTest(tf.test.TestCase):
|
||||||
|
|
||||||
def test_split(self):
|
def test_split(self):
|
||||||
|
|
||||||
class MagicClassificationDataLoader(
|
class MagicClassificationDataset(
|
||||||
classification_dataset.ClassificationDataset):
|
classification_dataset.ClassificationDataset):
|
||||||
|
"""A mock classification dataset class for testing purpose.
|
||||||
|
|
||||||
def __init__(self, dataset, size, index_to_label, value):
|
Attributes:
|
||||||
super(MagicClassificationDataLoader,
|
value: A value variable stored by the mock dataset class for testing.
|
||||||
self).__init__(dataset, size, index_to_label)
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dataset: tf.data.Dataset, size: int,
|
||||||
|
index_by_label: Any, value: Any):
|
||||||
|
super().__init__(
|
||||||
|
dataset=dataset, size=size, index_by_label=index_by_label)
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
def split(self, fraction):
|
def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
|
||||||
return self._split(fraction, self.index_to_label, self.value)
|
return self._split(fraction, self.index_by_label, self.value)
|
||||||
|
|
||||||
# Some dummy inputs.
|
# Some dummy inputs.
|
||||||
magic_value = 42
|
magic_value = 42
|
||||||
num_classes = 2
|
num_classes = 2
|
||||||
index_to_label = (False, True)
|
index_by_label = (False, True)
|
||||||
|
|
||||||
# Create data loader from sample data.
|
# Create data loader from sample data.
|
||||||
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
|
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
|
||||||
data = MagicClassificationDataLoader(ds, len(ds), index_to_label,
|
data = MagicClassificationDataset(
|
||||||
magic_value)
|
dataset=ds,
|
||||||
|
size=len(ds),
|
||||||
|
index_by_label=index_by_label,
|
||||||
|
value=magic_value)
|
||||||
|
|
||||||
# Train/Test data split.
|
# Train/Test data split.
|
||||||
fraction = .25
|
fraction = .25
|
||||||
train_data, test_data = data.split(fraction)
|
train_data, test_data = data.split(fraction=fraction)
|
||||||
|
|
||||||
# `split` should return instances of child DataLoader.
|
# `split` should return instances of child DataLoader.
|
||||||
self.assertIsInstance(train_data, MagicClassificationDataLoader)
|
self.assertIsInstance(train_data, MagicClassificationDataset)
|
||||||
self.assertIsInstance(test_data, MagicClassificationDataLoader)
|
self.assertIsInstance(test_data, MagicClassificationDataset)
|
||||||
|
|
||||||
# Make sure the number of entries is right.
|
# Make sure the number of entries is right.
|
||||||
self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data))
|
self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data))
|
||||||
|
@ -59,7 +73,7 @@ class ClassificationDataLoaderTest(tf.test.TestCase):
|
||||||
|
|
||||||
# Make sure attributes propagated correctly.
|
# Make sure attributes propagated correctly.
|
||||||
self.assertEqual(train_data.num_classes, num_classes)
|
self.assertEqual(train_data.num_classes, num_classes)
|
||||||
self.assertEqual(test_data.index_to_label, index_to_label)
|
self.assertEqual(test_data.index_by_label, index_by_label)
|
||||||
self.assertEqual(train_data.value, magic_value)
|
self.assertEqual(train_data.value, magic_value)
|
||||||
self.assertEqual(test_data.value, magic_value)
|
self.assertEqual(test_data.value, magic_value)
|
||||||
|
|
||||||
|
|
|
@ -29,22 +29,22 @@ from mediapipe.model_maker.python.core.tasks import custom_model
|
||||||
class Classifier(custom_model.CustomModel):
|
class Classifier(custom_model.CustomModel):
|
||||||
"""An abstract base class that represents a TensorFlow classifier."""
|
"""An abstract base class that represents a TensorFlow classifier."""
|
||||||
|
|
||||||
def __init__(self, model_spec: Any, index_to_label: List[str], shuffle: bool,
|
def __init__(self, model_spec: Any, index_by_label: List[str], shuffle: bool,
|
||||||
full_train: bool):
|
full_train: bool):
|
||||||
"""Initializes a classifier with its specifications.
|
"""Initializes a classifier with its specifications.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_spec: Specification for the model.
|
model_spec: Specification for the model.
|
||||||
index_to_label: A list that maps from index to label class name.
|
index_by_label: A list that maps from index to label class name.
|
||||||
shuffle: Whether the dataset should be shuffled.
|
shuffle: Whether the dataset should be shuffled.
|
||||||
full_train: If true, train the model end-to-end including the backbone
|
full_train: If true, train the model end-to-end including the backbone
|
||||||
and the classification layers on top. Otherwise, only train the top
|
and the classification layers on top. Otherwise, only train the top
|
||||||
classification layers.
|
classification layers.
|
||||||
"""
|
"""
|
||||||
super(Classifier, self).__init__(model_spec, shuffle)
|
super(Classifier, self).__init__(model_spec, shuffle)
|
||||||
self._index_to_label = index_to_label
|
self._index_by_label = index_by_label
|
||||||
self._full_train = full_train
|
self._full_train = full_train
|
||||||
self._num_classes = len(index_to_label)
|
self._num_classes = len(index_by_label)
|
||||||
|
|
||||||
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
|
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
|
||||||
"""Evaluates the classifier with the provided evaluation dataset.
|
"""Evaluates the classifier with the provided evaluation dataset.
|
||||||
|
@ -74,4 +74,4 @@ class Classifier(custom_model.CustomModel):
|
||||||
label_filepath = os.path.join(export_dir, label_filename)
|
label_filepath = os.path.join(export_dir, label_filename)
|
||||||
tf.compat.v1.logging.info('Saving labels in %s', label_filepath)
|
tf.compat.v1.logging.info('Saving labels in %s', label_filepath)
|
||||||
with tf.io.gfile.GFile(label_filepath, 'w') as f:
|
with tf.io.gfile.GFile(label_filepath, 'w') as f:
|
||||||
f.write('\n'.join(self._index_to_label))
|
f.write('\n'.join(self._index_by_label))
|
||||||
|
|
|
@ -36,10 +36,10 @@ class ClassifierTest(tf.test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(ClassifierTest, self).setUp()
|
super(ClassifierTest, self).setUp()
|
||||||
index_to_label = ['cat', 'dog']
|
index_by_label = ['cat', 'dog']
|
||||||
self.model = MockClassifier(
|
self.model = MockClassifier(
|
||||||
model_spec=None,
|
model_spec=None,
|
||||||
index_to_label=index_to_label,
|
index_by_label=index_by_label,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
full_train=False)
|
full_train=False)
|
||||||
self.model.model = test_util.build_model(input_shape=[4], num_classes=2)
|
self.model.model = test_util.build_model(input_shape=[4], num_classes=2)
|
||||||
|
|
|
@ -84,10 +84,10 @@ class Dataset(classification_dataset.ClassificationDataset):
|
||||||
name for name in os.listdir(data_root)
|
name for name in os.listdir(data_root)
|
||||||
if os.path.isdir(os.path.join(data_root, name)))
|
if os.path.isdir(os.path.join(data_root, name)))
|
||||||
all_label_size = len(label_names)
|
all_label_size = len(label_names)
|
||||||
label_to_index = dict(
|
index_by_label = dict(
|
||||||
(name, index) for index, name in enumerate(label_names))
|
(name, index) for index, name in enumerate(label_names))
|
||||||
all_image_labels = [
|
all_image_labels = [
|
||||||
label_to_index[os.path.basename(os.path.dirname(path))]
|
index_by_label[os.path.basename(os.path.dirname(path))]
|
||||||
for path in all_image_paths
|
for path in all_image_paths
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -106,7 +106,7 @@ class Dataset(classification_dataset.ClassificationDataset):
|
||||||
'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
|
'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
|
||||||
all_label_size, ', '.join(label_names))
|
all_label_size, ', '.join(label_names))
|
||||||
return Dataset(
|
return Dataset(
|
||||||
dataset=image_label_ds, size=all_image_size, index_to_label=label_names)
|
dataset=image_label_ds, size=all_image_size, index_by_label=label_names)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_tf_dataset(
|
def load_tf_dataset(
|
||||||
|
|
|
@ -49,27 +49,27 @@ class DatasetTest(tf.test.TestCase):
|
||||||
|
|
||||||
def test_split(self):
|
def test_split(self):
|
||||||
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
|
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
|
||||||
data = dataset.Dataset(ds, 4, ['pos', 'neg'])
|
data = dataset.Dataset(dataset=ds, size=4, index_by_label=['pos', 'neg'])
|
||||||
train_data, test_data = data.split(0.5)
|
train_data, test_data = data.split(fraction=0.5)
|
||||||
|
|
||||||
self.assertLen(train_data, 2)
|
self.assertLen(train_data, 2)
|
||||||
for i, elem in enumerate(train_data._dataset):
|
for i, elem in enumerate(train_data._dataset):
|
||||||
self.assertTrue((elem.numpy() == np.array([i, 1])).all())
|
self.assertTrue((elem.numpy() == np.array([i, 1])).all())
|
||||||
self.assertEqual(train_data.num_classes, 2)
|
self.assertEqual(train_data.num_classes, 2)
|
||||||
self.assertEqual(train_data.index_to_label, ['pos', 'neg'])
|
self.assertEqual(train_data.index_by_label, ['pos', 'neg'])
|
||||||
|
|
||||||
self.assertLen(test_data, 2)
|
self.assertLen(test_data, 2)
|
||||||
for i, elem in enumerate(test_data._dataset):
|
for i, elem in enumerate(test_data._dataset):
|
||||||
self.assertTrue((elem.numpy() == np.array([i, 0])).all())
|
self.assertTrue((elem.numpy() == np.array([i, 0])).all())
|
||||||
self.assertEqual(test_data.num_classes, 2)
|
self.assertEqual(test_data.num_classes, 2)
|
||||||
self.assertEqual(test_data.index_to_label, ['pos', 'neg'])
|
self.assertEqual(test_data.index_by_label, ['pos', 'neg'])
|
||||||
|
|
||||||
def test_from_folder(self):
|
def test_from_folder(self):
|
||||||
data = dataset.Dataset.from_folder(self.image_path)
|
data = dataset.Dataset.from_folder(dirname=self.image_path)
|
||||||
|
|
||||||
self.assertLen(data, 2)
|
self.assertLen(data, 2)
|
||||||
self.assertEqual(data.num_classes, 2)
|
self.assertEqual(data.num_classes, 2)
|
||||||
self.assertEqual(data.index_to_label, ['daisy', 'tulips'])
|
self.assertEqual(data.index_by_label, ['daisy', 'tulips'])
|
||||||
for image, label in data.gen_tf_dataset():
|
for image, label in data.gen_tf_dataset():
|
||||||
self.assertTrue(label.numpy() == 1 or label.numpy() == 0)
|
self.assertTrue(label.numpy() == 1 or label.numpy() == 0)
|
||||||
if label.numpy() == 0:
|
if label.numpy() == 0:
|
||||||
|
@ -88,19 +88,19 @@ class DatasetTest(tf.test.TestCase):
|
||||||
self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset)
|
self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset)
|
||||||
self.assertLen(train_data, 1034)
|
self.assertLen(train_data, 1034)
|
||||||
self.assertEqual(train_data.num_classes, 3)
|
self.assertEqual(train_data.num_classes, 3)
|
||||||
self.assertEqual(train_data.index_to_label,
|
self.assertEqual(train_data.index_by_label,
|
||||||
['angular_leaf_spot', 'bean_rust', 'healthy'])
|
['angular_leaf_spot', 'bean_rust', 'healthy'])
|
||||||
|
|
||||||
self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset)
|
self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset)
|
||||||
self.assertLen(validation_data, 133)
|
self.assertLen(validation_data, 133)
|
||||||
self.assertEqual(validation_data.num_classes, 3)
|
self.assertEqual(validation_data.num_classes, 3)
|
||||||
self.assertEqual(validation_data.index_to_label,
|
self.assertEqual(validation_data.index_by_label,
|
||||||
['angular_leaf_spot', 'bean_rust', 'healthy'])
|
['angular_leaf_spot', 'bean_rust', 'healthy'])
|
||||||
|
|
||||||
self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset)
|
self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset)
|
||||||
self.assertLen(test_data, 128)
|
self.assertLen(test_data, 128)
|
||||||
self.assertEqual(test_data.num_classes, 3)
|
self.assertEqual(test_data.num_classes, 3)
|
||||||
self.assertEqual(test_data.index_to_label,
|
self.assertEqual(test_data.index_by_label,
|
||||||
['angular_leaf_spot', 'bean_rust', 'healthy'])
|
['angular_leaf_spot', 'bean_rust', 'healthy'])
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -31,18 +31,18 @@ from mediapipe.model_maker.python.vision.image_classifier import train_image_cla
|
||||||
class ImageClassifier(classifier.Classifier):
|
class ImageClassifier(classifier.Classifier):
|
||||||
"""ImageClassifier for building image classification model."""
|
"""ImageClassifier for building image classification model."""
|
||||||
|
|
||||||
def __init__(self, model_spec: ms.ModelSpec, index_to_label: List[Any],
|
def __init__(self, model_spec: ms.ModelSpec, index_by_label: List[Any],
|
||||||
hparams: hp.HParams):
|
hparams: hp.HParams):
|
||||||
"""Initializes ImageClassifier class.
|
"""Initializes ImageClassifier class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_spec: Specification for the model.
|
model_spec: Specification for the model.
|
||||||
index_to_label: A list that maps from index to label class name.
|
index_by_label: A list that maps from index to label class name.
|
||||||
hparams: The hyperparameters for training image classifier.
|
hparams: The hyperparameters for training image classifier.
|
||||||
"""
|
"""
|
||||||
super(ImageClassifier, self).__init__(
|
super().__init__(
|
||||||
model_spec=model_spec,
|
model_spec=model_spec,
|
||||||
index_to_label=index_to_label,
|
index_by_label=index_by_label,
|
||||||
shuffle=hparams.shuffle,
|
shuffle=hparams.shuffle,
|
||||||
full_train=hparams.do_fine_tuning)
|
full_train=hparams.do_fine_tuning)
|
||||||
self._hparams = hparams
|
self._hparams = hparams
|
||||||
|
@ -81,7 +81,7 @@ class ImageClassifier(classifier.Classifier):
|
||||||
spec = ms.SupportedModels.get(model_spec)
|
spec = ms.SupportedModels.get(model_spec)
|
||||||
image_classifier = cls(
|
image_classifier = cls(
|
||||||
model_spec=spec,
|
model_spec=spec,
|
||||||
index_to_label=train_data.index_to_label,
|
index_by_label=train_data.index_by_label,
|
||||||
hparams=hparams)
|
hparams=hparams)
|
||||||
|
|
||||||
image_classifier._create_model()
|
image_classifier._create_model()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user