Model Maker: allow the core dataset library to handle datasets with unknown sizes.

PiperOrigin-RevId: 547268411
MediaPipe Team 2023-07-11 12:42:42 -07:00 committed by Copybara-Service
parent 4788fddde9
commit 56bc019819
11 changed files with 68 additions and 45 deletions

View File

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 """Common classification dataset library."""
 
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import tensorflow as tf
@@ -23,8 +23,12 @@ from mediapipe.model_maker.python.core.data import dataset as ds
 class ClassificationDataset(ds.Dataset):
   """Dataset Loader for classification models."""
 
-  def __init__(self, dataset: tf.data.Dataset, size: int,
-               label_names: List[str]):
+  def __init__(
+      self,
+      dataset: tf.data.Dataset,
+      label_names: List[str],
+      size: Optional[int] = None,
+  ):
     super().__init__(dataset, size)
     self._label_names = label_names
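
With `size` now optional, a ClassificationDataset can wrap a lazily generated tf.data.Dataset whose length is unknown up front. A minimal sketch of the post-change constructor, assuming the signature above (the generator, shapes, and label names are illustrative, not from this commit):

import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset

# Illustrative generator: yields (image, label) pairs whose total count is
# not known until the stream is exhausted.
def gen():
  for label in (0, 1):
    yield tf.zeros([8, 8, 3], dtype=tf.float32), label

tf_ds = tf.data.Dataset.from_generator(
    gen,
    output_signature=(
        tf.TensorSpec(shape=[8, 8, 3], dtype=tf.float32),
        tf.TensorSpec(shape=[], dtype=tf.int64),
    ),
)

# `size` is omitted and defaults to None; len(data) then falls back to the
# underlying tf.data.Dataset (see the core dataset changes below).
data = classification_dataset.ClassificationDataset(
    dataset=tf_ds, label_names=['neg', 'pos']
)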

View File

@@ -36,9 +36,14 @@ class ClassificationDatasetTest(tf.test.TestCase):
         value: A value variable stored by the mock dataset class for testing.
       """
 
-      def __init__(self, dataset: tf.data.Dataset, size: int,
-                   label_names: List[str], value: Any):
-        super().__init__(dataset=dataset, size=size, label_names=label_names)
+      def __init__(
+          self,
+          dataset: tf.data.Dataset,
+          label_names: List[str],
+          value: Any,
+          size: int,
+      ):
+        super().__init__(dataset=dataset, label_names=label_names, size=size)
         self.value = value
 
       def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
@@ -52,7 +57,8 @@ class ClassificationDatasetTest(tf.test.TestCase):
     # Create data loader from sample data.
     ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
     data = MagicClassificationDataset(
-        dataset=ds, size=len(ds), label_names=label_names, value=magic_value)
+        dataset=ds, label_names=label_names, value=magic_value, size=len(ds)
+    )
 
     # Train/Test data split.
     fraction = .25
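
For orientation, a hedged sketch of how the reordered mock is exercised, mirroring the assertions in this test (`data`, `fraction`, and `magic_value` are the test's own locals):

# split() preserves the subclass type; the extra args (label_names, value)
# travel via *args, while size is now passed by keyword (see dataset.py).
train_data, test_data = data.split(fraction)
assert len(train_data) == 1    # int(4 * 0.25) elements taken
assert len(test_data) == 3     # the remaining elements
assert train_data.value == magic_value
assert test_data.value == magic_value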

View File

@@ -56,15 +56,14 @@ class Dataset(object):
   def size(self) -> Optional[int]:
     """Returns the size of the dataset.
 
-    Note that this function may return None becuase the exact size of the
-    dataset isn't a necessary parameter to create an instance of this class,
-    and tf.data.Dataset donesn't support a function to get the length directly
-    since it's lazy-loaded and may be infinite.
-    In most cases, however, when an instance of this class is created by helper
-    functions like 'from_folder', the size of the dataset will be preprocessed,
-    and this function can return an int representing the size of the dataset.
+    Same functionality as calling __len__. See the __len__ method definition for
+    more information.
+
+    Raises:
+      TypeError if self._size is not set and the cardinality of self._dataset
+        is INFINITE_CARDINALITY or UNKNOWN_CARDINALITY.
     """
-    return self._size
+    return self.__len__()
 
   def gen_tf_dataset(
       self,
@@ -116,8 +115,22 @@ class Dataset(object):
       # here.
     return dataset
 
-  def __len__(self):
-    """Returns the number of element of the dataset."""
+  def __len__(self) -> int:
+    """Returns the number of element of the dataset.
+
+    If size is not set, this method will fallback to using the __len__ method
+    of the tf.data.Dataset in self._dataset. Calling __len__ on a
+    tf.data.Dataset instance may throw a TypeError because the dataset may
+    be lazy-loaded with an unknown size or have infinite size.
+
+    In most cases, however, when an instance of this class is created by helper
+    functions like 'from_folder', the size of the dataset will be preprocessed,
+    and the _size instance variable will be already set.
+
+    Raises:
+      TypeError if self._size is not set and the cardinality of self._dataset
+        is INFINITE_CARDINALITY or UNKNOWN_CARDINALITY.
+    """
     if self._size is not None:
       return self._size
     else:
@@ -152,15 +165,25 @@ class Dataset(object):
     Returns:
       The splitted two sub datasets.
 
+    Raises:
+      ValueError: if the provided fraction is not between 0 and 1.
+      ValueError: if this dataset does not have a set size.
     """
-    assert (fraction > 0 and fraction < 1)
+    if not (fraction > 0 and fraction < 1):
+      raise ValueError(f'Fraction must be between 0 and 1. Got:{fraction}')
+    if not self._size:
+      raise ValueError(
+          'Dataset size unknown. Cannot split the dataset when '
+          'the size is unknown.'
+      )
+
     dataset = self._dataset
 
     train_size = int(self._size * fraction)
-    trainset = self.__class__(dataset.take(train_size), train_size, *args)
+    trainset = self.__class__(dataset.take(train_size), *args, size=train_size)
 
     test_size = self._size - train_size
-    testset = self.__class__(dataset.skip(train_size), test_size, *args)
+    testset = self.__class__(dataset.skip(train_size), *args, size=test_size)
 
     return trainset, testset
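
The rewritten __len__ leans on TensorFlow's own length semantics: calling len() on a tf.data.Dataset raises TypeError when its cardinality is infinite or unknown, which is exactly what the new Raises clauses document. A standalone illustration of that TensorFlow behavior (plain TF, not MediaPipe code):

import tensorflow as tf

finite = tf.data.Dataset.range(4)
print(len(finite))  # 4: cardinality is statically known

infinite = tf.data.Dataset.range(4).repeat()  # INFINITE_CARDINALITY
try:
  len(infinite)
except TypeError as err:
  print('infinite:', err)

# filter() makes the element count unknowable without running the pipeline.
unknown = finite.filter(lambda x: x > 1)  # UNKNOWN_CARDINALITY
try:
  len(unknown)
except TypeError as err:
  print('unknown:', err)

Correspondingly, the split logic above now fails fast with a ValueError rather than an assert when the size is unset, since the take()/skip() arithmetic needs a concrete length.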

View File

@@ -85,4 +85,5 @@ class Dataset(classification_dataset.ClassificationDataset):
     text_label_ds = tf.data.Dataset.zip((text_ds, label_index_ds))
     return Dataset(
-        dataset=text_label_ds, size=len(texts), label_names=label_names)
+        dataset=text_label_ds, label_names=label_names, size=len(texts)
+    )

View File

@@ -53,7 +53,7 @@ class DatasetTest(tf.test.TestCase):
   def test_split(self):
     ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd'])
-    data = dataset.Dataset(ds, 4, ['pos', 'neg'])
+    data = dataset.Dataset(ds, ['pos', 'neg'], 4)
     train_data, test_data = data.split(0.5)
     expected_train_data = [b'good', b'bad']
     expected_test_data = [b'neutral', b'odd']

View File

@@ -115,5 +115,7 @@ class Dataset(classification_dataset.ClassificationDataset):
         ', '.join(label_names),
     )
     return Dataset(
-        dataset=image_label_ds, size=all_image_size, label_names=label_names
+        dataset=image_label_ds,
+        label_names=label_names,
+        size=all_image_size,
     )

View File

@@ -249,5 +249,6 @@ class Dataset(classification_dataset.ClassificationDataset):
             len(valid_hand_data), len(label_names), ','.join(label_names)))
     return Dataset(
         dataset=hand_embedding_label_ds,
+        label_names=label_names,
         size=len(valid_hand_data),
-        label_names=label_names)
+    )

View File

@@ -15,28 +15,12 @@
 
 import os
 import random
-from typing import List, Optional
 
 import tensorflow as tf
-import tensorflow_datasets as tfds
 
 from mediapipe.model_maker.python.core.data import classification_dataset
 from mediapipe.model_maker.python.vision.core import image_utils
 
-
-def _create_data(
-    name: str, data: tf.data.Dataset, info: tfds.core.DatasetInfo,
-    label_names: List[str]
-) -> Optional[classification_dataset.ClassificationDataset]:
-  """Creates a Dataset object from tfds data."""
-  if name not in data:
-    return None
-  data = data[name]
-  data = data.map(lambda a: (a['image'], a['label']))
-  size = info.splits[name].num_examples
-  return Dataset(data, size, label_names)
-
-
 class Dataset(classification_dataset.ClassificationDataset):
   """Dataset library for image classifier."""
@@ -99,4 +83,5 @@ class Dataset(classification_dataset.ClassificationDataset):
         'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
         all_label_size, ', '.join(label_names))
     return Dataset(
-        dataset=image_label_ds, size=all_image_size, label_names=label_names)
+        dataset=image_label_ds, label_names=label_names, size=all_image_size
+    )

View File

@@ -41,7 +41,7 @@ class DatasetTest(tf.test.TestCase):
   def test_split(self):
     ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
-    data = dataset.Dataset(dataset=ds, size=4, label_names=['pos', 'neg'])
+    data = dataset.Dataset(dataset=ds, label_names=['pos', 'neg'], size=4)
     train_data, test_data = data.split(fraction=0.5)
 
     self.assertLen(train_data, 2)

View File

@@ -52,8 +52,9 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
     ds = tf.data.Dataset.from_generator(
         self._gen, (tf.uint8, tf.int64), (tf.TensorShape(
            [self.IMAGE_SIZE, self.IMAGE_SIZE, 3]), tf.TensorShape([])))
-    data = image_classifier.Dataset(ds, self.IMAGES_PER_CLASS * 3,
-                                    ['cyan', 'magenta', 'yellow'])
+    data = image_classifier.Dataset(
+        ds, ['cyan', 'magenta', 'yellow'], self.IMAGES_PER_CLASS * 3
+    )
     return data
 
   def setUp(self):

View File

@@ -176,5 +176,5 @@ class Dataset(classification_dataset.ClassificationDataset):
     label_names = [label_map[k] for k in sorted(label_map.keys())]
     return Dataset(
-        dataset=dataset, size=meta_data['size'], label_names=label_names
+        dataset=dataset, label_names=label_names, size=meta_data['size']
     )