Model Maker allows the core dataset library to handle datasets with unknown sizes.
PiperOrigin-RevId: 547268411
This commit is contained in:
parent 4788fddde9
commit 56bc019819
@@ -13,7 +13,7 @@
 # limitations under the License.
 """Common classification dataset library."""

-from typing import List, Tuple
+from typing import List, Optional, Tuple

 import tensorflow as tf

@@ -23,8 +23,12 @@ from mediapipe.model_maker.python.core.data import dataset as ds
 class ClassificationDataset(ds.Dataset):
   """Dataset Loader for classification models."""

-  def __init__(self, dataset: tf.data.Dataset, size: int,
-               label_names: List[str]):
+  def __init__(
+      self,
+      dataset: tf.data.Dataset,
+      label_names: List[str],
+      size: Optional[int] = None,
+  ):
     super().__init__(dataset, size)
     self._label_names = label_names

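
With the new signature, label_names moves ahead of size and size becomes an optional keyword, so a classification dataset can be built before its cardinality is known. A minimal sketch of the two call styles (the tensors and label names here are made-up test values):

import tensorflow as tf

from mediapipe.model_maker.python.core.data import classification_dataset

ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])

# Size known up front: pass it explicitly as a keyword.
data = classification_dataset.ClassificationDataset(
    dataset=ds, label_names=['pos', 'neg'], size=4
)

# Size unknown (e.g. a streamed or filtered source): omit it and let
# __len__ resolve the length lazily (see the dataset.py hunks below).
data = classification_dataset.ClassificationDataset(
    dataset=ds, label_names=['pos', 'neg']
)
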
@@ -36,9 +36,14 @@ class ClassificationDatasetTest(tf.test.TestCase):
         value: A value variable stored by the mock dataset class for testing.
       """

-      def __init__(self, dataset: tf.data.Dataset, size: int,
-                   label_names: List[str], value: Any):
-        super().__init__(dataset=dataset, size=size, label_names=label_names)
+      def __init__(
+          self,
+          dataset: tf.data.Dataset,
+          label_names: List[str],
+          value: Any,
+          size: int,
+      ):
+        super().__init__(dataset=dataset, label_names=label_names, size=size)
         self.value = value

       def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
@@ -52,7 +57,8 @@ class ClassificationDatasetTest(tf.test.TestCase):
     # Create data loader from sample data.
     ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
     data = MagicClassificationDataset(
-        dataset=ds, size=len(ds), label_names=label_names, value=magic_value)
+        dataset=ds, label_names=label_names, value=magic_value, size=len(ds)
+    )

     # Train/Test data split.
     fraction = .25
@@ -56,15 +56,14 @@ class Dataset(object):
   def size(self) -> Optional[int]:
     """Returns the size of the dataset.

-    Note that this function may return None becuase the exact size of the
-    dataset isn't a necessary parameter to create an instance of this class,
-    and tf.data.Dataset donesn't support a function to get the length directly
-    since it's lazy-loaded and may be infinite.
-    In most cases, however, when an instance of this class is created by helper
-    functions like 'from_folder', the size of the dataset will be preprocessed,
-    and this function can return an int representing the size of the dataset.
+    Same functionality as calling __len__. See the __len__ method definition
+    for more information.
+
+    Raises:
+      TypeError if self._size is not set and the cardinality of self._dataset
+        is INFINITE_CARDINALITY or UNKNOWN_CARDINALITY.
     """
-    return self._size
+    return self.__len__()

   def gen_tf_dataset(
       self,
@@ -116,8 +115,22 @@ class Dataset(object):
     # here.
     return dataset

-  def __len__(self):
-    """Returns the number of element of the dataset."""
+  def __len__(self) -> int:
+    """Returns the number of element of the dataset.
+
+    If size is not set, this method will fallback to using the __len__ method
+    of the tf.data.Dataset in self._dataset. Calling __len__ on a
+    tf.data.Dataset instance may throw a TypeError because the dataset may
+    be lazy-loaded with an unknown size or have infinite size.
+
+    In most cases, however, when an instance of this class is created by helper
+    functions like 'from_folder', the size of the dataset will be preprocessed,
+    and the _size instance variable will be already set.
+
+    Raises:
+      TypeError if self._size is not set and the cardinality of self._dataset
+        is INFINITE_CARDINALITY or UNKNOWN_CARDINALITY.
+    """
     if self._size is not None:
       return self._size
     else:
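
The else branch at the end of this hunk defers to tf.data's own __len__, which raises TypeError when the cardinality is infinite or unknown, exactly the cases the new docstring calls out. A standalone tf.data illustration of those cardinality cases (plain TensorFlow, independent of Model Maker):

import tensorflow as tf

finite = tf.data.Dataset.from_tensor_slices([1, 2, 3])
assert len(finite) == 3  # Cardinality is statically known.

infinite = finite.repeat()                 # INFINITE_CARDINALITY
filtered = finite.filter(lambda x: x > 1)  # UNKNOWN_CARDINALITY
for d in (infinite, filtered):
  try:
    len(d)
  except TypeError as e:
    print(e)  # Raised by tf.data and propagated by Dataset.__len__.
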
@@ -152,15 +165,25 @@ class Dataset(object):

     Returns:
       The splitted two sub datasets.

+    Raises:
+      ValueError: if the provided fraction is not between 0 and 1.
+      ValueError: if this dataset does not have a set size.
     """
-    assert (fraction > 0 and fraction < 1)
+    if not (fraction > 0 and fraction < 1):
+      raise ValueError(f'Fraction must be between 0 and 1. Got:{fraction}')
+    if not self._size:
+      raise ValueError(
+          'Dataset size unknown. Cannot split the dataset when '
+          'the size is unknown.'
+      )
+
     dataset = self._dataset

     train_size = int(self._size * fraction)
-    trainset = self.__class__(dataset.take(train_size), train_size, *args)
+    trainset = self.__class__(dataset.take(train_size), *args, size=train_size)

     test_size = self._size - train_size
-    testset = self.__class__(dataset.skip(train_size), test_size, *args)
+    testset = self.__class__(dataset.skip(train_size), *args, size=test_size)

     return trainset, testset
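
With the bare assert replaced by explicit checks, split now surfaces both failure modes as ValueError. A hedged usage sketch, assuming ClassificationDataset is instantiated directly and its split forwards to these base-class checks, as the tests in this commit do:

import tensorflow as tf

from mediapipe.model_maker.python.core.data import classification_dataset

ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd'])
data = classification_dataset.ClassificationDataset(
    dataset=ds, label_names=['pos', 'neg'], size=4
)
train_data, test_data = data.split(0.5)  # OK: two examples each.

try:
  data.split(1.5)  # Out-of-range fraction: ValueError, not AssertionError.
except ValueError as e:
  print(e)

sizeless = classification_dataset.ClassificationDataset(
    dataset=ds, label_names=['pos', 'neg']  # size omitted
)
try:
  sizeless.split(0.5)  # Unknown size: ValueError from the new check.
except ValueError as e:
  print(e)
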
@@ -85,4 +85,5 @@ class Dataset(classification_dataset.ClassificationDataset):
     text_label_ds = tf.data.Dataset.zip((text_ds, label_index_ds))

     return Dataset(
-        dataset=text_label_ds, size=len(texts), label_names=label_names)
+        dataset=text_label_ds, label_names=label_names, size=len(texts)
+    )
@@ -53,7 +53,7 @@ class DatasetTest(tf.test.TestCase):

   def test_split(self):
     ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd'])
-    data = dataset.Dataset(ds, 4, ['pos', 'neg'])
+    data = dataset.Dataset(ds, ['pos', 'neg'], 4)
     train_data, test_data = data.split(0.5)
     expected_train_data = [b'good', b'bad']
     expected_test_data = [b'neutral', b'odd']
@@ -115,5 +115,7 @@ class Dataset(classification_dataset.ClassificationDataset):
         ', '.join(label_names),
     )
     return Dataset(
-        dataset=image_label_ds, size=all_image_size, label_names=label_names
+        dataset=image_label_ds,
+        label_names=label_names,
+        size=all_image_size,
     )
@@ -249,5 +249,6 @@ class Dataset(classification_dataset.ClassificationDataset):
         len(valid_hand_data), len(label_names), ','.join(label_names)))
     return Dataset(
         dataset=hand_embedding_label_ds,
+        label_names=label_names,
         size=len(valid_hand_data),
-        label_names=label_names)
+    )
@@ -15,28 +15,12 @@

 import os
 import random

-from typing import List, Optional
 import tensorflow as tf
-import tensorflow_datasets as tfds

 from mediapipe.model_maker.python.core.data import classification_dataset
 from mediapipe.model_maker.python.vision.core import image_utils


-def _create_data(
-    name: str, data: tf.data.Dataset, info: tfds.core.DatasetInfo,
-    label_names: List[str]
-) -> Optional[classification_dataset.ClassificationDataset]:
-  """Creates a Dataset object from tfds data."""
-  if name not in data:
-    return None
-  data = data[name]
-  data = data.map(lambda a: (a['image'], a['label']))
-  size = info.splits[name].num_examples
-  return Dataset(data, size, label_names)
-
-
 class Dataset(classification_dataset.ClassificationDataset):
   """Dataset library for image classifier."""

@@ -99,4 +83,5 @@ class Dataset(classification_dataset.ClassificationDataset):
         'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
         all_label_size, ', '.join(label_names))
     return Dataset(
-        dataset=image_label_ds, size=all_image_size, label_names=label_names)
+        dataset=image_label_ds, label_names=label_names, size=all_image_size
+    )
@@ -41,7 +41,7 @@ class DatasetTest(tf.test.TestCase):

   def test_split(self):
     ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
-    data = dataset.Dataset(dataset=ds, size=4, label_names=['pos', 'neg'])
+    data = dataset.Dataset(dataset=ds, label_names=['pos', 'neg'], size=4)
     train_data, test_data = data.split(fraction=0.5)

     self.assertLen(train_data, 2)
@@ -52,8 +52,9 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
     ds = tf.data.Dataset.from_generator(
         self._gen, (tf.uint8, tf.int64), (tf.TensorShape(
             [self.IMAGE_SIZE, self.IMAGE_SIZE, 3]), tf.TensorShape([])))
-    data = image_classifier.Dataset(ds, self.IMAGES_PER_CLASS * 3,
-                                    ['cyan', 'magenta', 'yellow'])
+    data = image_classifier.Dataset(
+        ds, ['cyan', 'magenta', 'yellow'], self.IMAGES_PER_CLASS * 3
+    )
     return data

   def setUp(self):
@@ -176,5 +176,5 @@ class Dataset(classification_dataset.ClassificationDataset):
     label_names = [label_map[k] for k in sorted(label_map.keys())]

     return Dataset(
-        dataset=dataset, size=meta_data['size'], label_names=label_names
+        dataset=dataset, label_names=label_names, size=meta_data['size']
     )