Model Maker: allow the core dataset library to handle datasets with unknown sizes.

PiperOrigin-RevId: 547268411
commit 56bc019819 (parent 4788fddde9)
MediaPipe Team, 2023-07-11 12:42:42 -07:00 (committed by Copybara-Service)
11 changed files with 68 additions and 45 deletions

View File

@@ -13,7 +13,7 @@
 # limitations under the License.

 """Common classification dataset library."""

-from typing import List, Tuple
+from typing import List, Optional, Tuple

 import tensorflow as tf
@@ -23,8 +23,12 @@ from mediapipe.model_maker.python.core.data import dataset as ds

 class ClassificationDataset(ds.Dataset):
   """Dataset Loader for classification models."""

-  def __init__(self, dataset: tf.data.Dataset, size: int,
-               label_names: List[str]):
+  def __init__(
+      self,
+      dataset: tf.data.Dataset,
+      label_names: List[str],
+      size: Optional[int] = None,
+  ):
     super().__init__(dataset, size)
     self._label_names = label_names
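
With size now optional and moved behind label_names, a ClassificationDataset can be constructed before its size is known. A minimal sketch of the new call pattern (the sample tensors and label names are illustrative, not taken from this commit):

    import tensorflow as tf

    from mediapipe.model_maker.python.core.data import classification_dataset

    ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 0]])

    # Size known up front: pass it by keyword.
    data = classification_dataset.ClassificationDataset(
        dataset=ds, label_names=['neg', 'pos'], size=2
    )

    # Size unknown (e.g. a generator-backed dataset): omit it; it defaults to None.
    data = classification_dataset.ClassificationDataset(
        dataset=ds, label_names=['neg', 'pos']
    )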

View File

@@ -36,9 +36,14 @@ class ClassificationDatasetTest(tf.test.TestCase):
     value: A value variable stored by the mock dataset class for testing.
   """

-  def __init__(self, dataset: tf.data.Dataset, size: int,
-               label_names: List[str], value: Any):
-    super().__init__(dataset=dataset, size=size, label_names=label_names)
+  def __init__(
+      self,
+      dataset: tf.data.Dataset,
+      label_names: List[str],
+      value: Any,
+      size: int,
+  ):
+    super().__init__(dataset=dataset, label_names=label_names, size=size)
     self.value = value

   def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
@@ -52,7 +57,8 @@ class ClassificationDatasetTest(tf.test.TestCase):
     # Create data loader from sample data.
     ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
     data = MagicClassificationDataset(
-        dataset=ds, size=len(ds), label_names=label_names, value=magic_value)
+        dataset=ds, label_names=label_names, value=magic_value, size=len(ds)
+    )

     # Train/Test data split.
     fraction = .25

View File

@@ -56,15 +56,14 @@ class Dataset(object):
   def size(self) -> Optional[int]:
     """Returns the size of the dataset.

-    Note that this function may return None because the exact size of the
-    dataset isn't a necessary parameter to create an instance of this class,
-    and tf.data.Dataset doesn't support a function to get the length directly
-    since it's lazy-loaded and may be infinite.
-    In most cases, however, when an instance of this class is created by helper
-    functions like 'from_folder', the size of the dataset will be preprocessed,
-    and this function can return an int representing the size of the dataset.
+    Same functionality as calling __len__. See the __len__ method definition
+    for more information.
+
+    Raises:
+      TypeError if self._size is not set and the cardinality of self._dataset
+        is INFINITE_CARDINALITY or UNKNOWN_CARDINALITY.
     """
-    return self._size
+    return self.__len__()

   def gen_tf_dataset(
       self,
@@ -116,8 +115,22 @@ class Dataset(object):
       # here.
       return dataset

-  def __len__(self):
-    """Returns the number of elements of the dataset."""
+  def __len__(self) -> int:
+    """Returns the number of elements of the dataset.
+
+    If size is not set, this method will fall back to using the __len__ method
+    of the tf.data.Dataset in self._dataset. Calling __len__ on a
+    tf.data.Dataset instance may throw a TypeError because the dataset may be
+    lazy-loaded with an unknown size or have infinite size.
+
+    In most cases, however, when an instance of this class is created by helper
+    functions like 'from_folder', the size of the dataset will be preprocessed,
+    and the _size instance variable will already be set.
+
+    Raises:
+      TypeError if self._size is not set and the cardinality of self._dataset
+        is INFINITE_CARDINALITY or UNKNOWN_CARDINALITY.
+    """
     if self._size is not None:
       return self._size
     else:
@@ -152,15 +165,25 @@ class Dataset(object):
     Returns:
       The two split sub-datasets.

+    Raises:
+      ValueError: if the provided fraction is not between 0 and 1.
+      ValueError: if this dataset does not have a set size.
     """
-    assert (fraction > 0 and fraction < 1)
+    if not (fraction > 0 and fraction < 1):
+      raise ValueError(f'Fraction must be between 0 and 1. Got: {fraction}')
+    if not self._size:
+      raise ValueError(
+          'Dataset size unknown. Cannot split the dataset when '
+          'the size is unknown.'
+      )

     dataset = self._dataset
     train_size = int(self._size * fraction)
-    trainset = self.__class__(dataset.take(train_size), train_size, *args)
+    trainset = self.__class__(dataset.take(train_size), *args, size=train_size)
     test_size = self._size - train_size
-    testset = self.__class__(dataset.skip(train_size), test_size, *args)
+    testset = self.__class__(dataset.skip(train_size), *args, size=test_size)
     return trainset, testset
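
Taken together, these hunks give __len__ a fallback and make split() validate instead of assert. A sketch of the resulting behavior, assuming the truncated else branch falls back to len(self._dataset) and that the core Dataset constructor accepts size as an optional keyword, consistent with ClassificationDataset above:

    import tensorflow as tf

    from mediapipe.model_maker.python.core.data import dataset

    # Finite dataset, no explicit size: __len__ falls back to the underlying
    # tf.data.Dataset, whose cardinality is known here.
    finite = dataset.Dataset(tf.data.Dataset.range(10))
    assert len(finite) == 10

    # Infinite dataset, no explicit size: tf.data raises TypeError on len().
    endless = dataset.Dataset(tf.data.Dataset.range(10).repeat())
    try:
      len(endless)
    except TypeError:
      pass  # INFINITE_CARDINALITY

    # split() now raises ValueError rather than failing an assert when the
    # size is unknown.
    try:
      endless.split(0.5)
    except ValueError:
      pass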

View File

@@ -85,4 +85,5 @@ class Dataset(classification_dataset.ClassificationDataset):
     text_label_ds = tf.data.Dataset.zip((text_ds, label_index_ds))
     return Dataset(
-        dataset=text_label_ds, size=len(texts), label_names=label_names)
+        dataset=text_label_ds, label_names=label_names, size=len(texts)
+    )

View File

@@ -53,7 +53,7 @@ class DatasetTest(tf.test.TestCase):
   def test_split(self):
     ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd'])
-    data = dataset.Dataset(ds, 4, ['pos', 'neg'])
+    data = dataset.Dataset(ds, ['pos', 'neg'], 4)
     train_data, test_data = data.split(0.5)
     expected_train_data = [b'good', b'bad']
     expected_test_data = [b'neutral', b'odd']

View File

@@ -115,5 +115,7 @@ class Dataset(classification_dataset.ClassificationDataset):
         ', '.join(label_names),
     )
     return Dataset(
-        dataset=image_label_ds, size=all_image_size, label_names=label_names
+        dataset=image_label_ds,
+        label_names=label_names,
+        size=all_image_size,
     )

View File

@@ -249,5 +249,6 @@ class Dataset(classification_dataset.ClassificationDataset):
             len(valid_hand_data), len(label_names), ','.join(label_names)))
     return Dataset(
         dataset=hand_embedding_label_ds,
+        label_names=label_names,
         size=len(valid_hand_data),
-        label_names=label_names)
+    )

View File

@@ -15,28 +15,12 @@

 import os
 import random
-from typing import List, Optional

 import tensorflow as tf
-import tensorflow_datasets as tfds

 from mediapipe.model_maker.python.core.data import classification_dataset
 from mediapipe.model_maker.python.vision.core import image_utils


-def _create_data(
-    name: str, data: tf.data.Dataset, info: tfds.core.DatasetInfo,
-    label_names: List[str]
-) -> Optional[classification_dataset.ClassificationDataset]:
-  """Creates a Dataset object from tfds data."""
-  if name not in data:
-    return None
-  data = data[name]
-  data = data.map(lambda a: (a['image'], a['label']))
-  size = info.splits[name].num_examples
-  return Dataset(data, size, label_names)
-
-
 class Dataset(classification_dataset.ClassificationDataset):
   """Dataset library for image classifier."""
@@ -99,4 +83,5 @@ class Dataset(classification_dataset.ClassificationDataset):
         'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
         all_label_size, ', '.join(label_names))
     return Dataset(
-        dataset=image_label_ds, size=all_image_size, label_names=label_names)
+        dataset=image_label_ds, label_names=label_names, size=all_image_size
+    )

View File

@@ -41,7 +41,7 @@ class DatasetTest(tf.test.TestCase):
   def test_split(self):
     ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
-    data = dataset.Dataset(dataset=ds, size=4, label_names=['pos', 'neg'])
+    data = dataset.Dataset(dataset=ds, label_names=['pos', 'neg'], size=4)
     train_data, test_data = data.split(fraction=0.5)
     self.assertLen(train_data, 2)

View File

@@ -52,8 +52,9 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
     ds = tf.data.Dataset.from_generator(
         self._gen, (tf.uint8, tf.int64), (tf.TensorShape(
             [self.IMAGE_SIZE, self.IMAGE_SIZE, 3]), tf.TensorShape([])))
-    data = image_classifier.Dataset(ds, self.IMAGES_PER_CLASS * 3,
-                                    ['cyan', 'magenta', 'yellow'])
+    data = image_classifier.Dataset(
+        ds, ['cyan', 'magenta', 'yellow'], self.IMAGES_PER_CLASS * 3
+    )
     return data

   def setUp(self):

View File

@@ -176,5 +176,5 @@ class Dataset(classification_dataset.ClassificationDataset):
     label_names = [label_map[k] for k in sorted(label_map.keys())]
     return Dataset(
-        dataset=dataset, size=meta_data['size'], label_names=label_names
+        dataset=dataset, label_names=label_names, size=meta_data['size']
     )
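
In short, every call site in this commit moves size after label_names and, where possible, passes it by keyword. Because the two leading positional slots swapped, a caller still using the old positional order would silently bind its size to label_names, so keywords are the safer migration. A before/after sketch (names illustrative):

    # Old signature: Dataset(dataset, size, label_names)
    data = Dataset(dataset=ds, size=4, label_names=['pos', 'neg'])

    # New signature: Dataset(dataset, label_names, size=None)
    data = Dataset(dataset=ds, label_names=['pos', 'neg'], size=4)
    data = Dataset(dataset=ds, label_names=['pos', 'neg'])  # size may be omitted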