Add a new from_image API to create a face stylizer dataset from a single image. Also deprecate the from_folder API, since only the one-shot use case is supported now.

PiperOrigin-RevId: 558912896
This commit is contained in:
MediaPipe Team 2023-08-21 15:12:23 -07:00 committed by Copybara-Service
parent ae9e945e0c
commit bbf168ddda
3 changed files with 34 additions and 69 deletions

View File

@ -51,71 +51,41 @@ class Dataset(classification_dataset.ClassificationDataset):
"""Dataset library for face stylizer fine tuning.""" """Dataset library for face stylizer fine tuning."""
@classmethod @classmethod
def from_folder( def from_image(
cls, dirname: str cls, filename: str
) -> classification_dataset.ClassificationDataset: ) -> classification_dataset.ClassificationDataset:
"""Loads images from the given directory. """Creates a dataset from single image.
The style image dataset directory is expected to contain one subdirectory Supported input image formats include 'jpg', 'jpeg', 'png'.
whose name represents the label of the style. There can be one or multiple
images of the same style in that subdirectory. Supported input image formats
include 'jpg', 'jpeg', 'png'.
Args: Args:
dirname: Name of the directory containing the image files. filename: Name of the image file.
Returns: Returns:
Dataset containing images and labels and other related info. Dataset containing image and label and other related info.
Raises:
ValueError: if the input data directory is empty.
""" """
data_root = os.path.abspath(dirname) file_path = os.path.abspath(filename)
image_filename = os.path.basename(filename)
image_name, ext_name = os.path.splitext(image_filename)
# Assumes the image data of the same label are in the same subdirectory, if not ext_name.endswith(('.jpg', '.jpeg', '.png')):
# gets image path and label names. raise ValueError('Unsupported image formats: %s' % ext_name)
all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
all_image_size = len(all_image_paths)
if all_image_size == 0:
raise ValueError('Invalid input data directory')
if not any(
fname.endswith(('.jpg', '.jpeg', '.png')) for fname in all_image_paths
):
raise ValueError('No images found under given directory')
image_data = _preprocess_face_dataset(all_image_paths) image_data = _preprocess_face_dataset([file_path])
label_names = sorted( label_names = [image_name]
name
for name in os.listdir(data_root)
if os.path.isdir(os.path.join(data_root, name))
)
all_label_size = len(label_names)
index_by_label = dict(
(name, index) for index, name in enumerate(label_names)
)
# Get the style label from the subdirectory name.
all_image_labels = [
index_by_label[os.path.basename(os.path.dirname(path))]
for path in all_image_paths
]
image_ds = tf.data.Dataset.from_tensor_slices(image_data) image_ds = tf.data.Dataset.from_tensor_slices(image_data)
# Load label # Load label
label_ds = tf.data.Dataset.from_tensor_slices( label_ds = tf.data.Dataset.from_tensor_slices(tf.cast([0], tf.int64))
tf.cast(all_image_labels, tf.int64)
)
# Create a dataset of (image, label) pairs # Create a dataset of (image, label) pairs
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds)) image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
logging.info( logging.info('Create dataset for style: %s.', image_name)
'Load images dataset with size: %d, num_label: %d, labels: %s.',
all_image_size,
all_label_size,
', '.join(label_names),
)
return Dataset( return Dataset(
dataset=image_label_ds, dataset=image_label_ds,
label_names=label_names, label_names=label_names,
size=all_image_size, size=1,
) )

View File

@ -25,24 +25,17 @@ class DatasetTest(tf.test.TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
def test_from_folder(self): def test_from_image(self):
test_data_dirname = 'input/style' test_image_file = 'input/style/cartoon/cartoon.jpg'
input_data_dir = test_utils.get_test_data_path(test_data_dirname) input_data_dir = test_utils.get_test_data_path(test_image_file)
data = dataset.Dataset.from_folder(dirname=input_data_dir) data = dataset.Dataset.from_image(filename=input_data_dir)
self.assertEqual(data.num_classes, 2) self.assertEqual(data.num_classes, 1)
self.assertEqual(data.label_names, ['cartoon', 'sketch']) self.assertEqual(data.label_names, ['cartoon'])
self.assertLen(data, 2) self.assertLen(data, 1)
def test_from_folder_raise_value_error_for_invalid_path(self): def test_from_image_raise_value_error_for_invalid_path(self):
with self.assertRaisesRegex(ValueError, 'Invalid input data directory'): with self.assertRaisesRegex(ValueError, 'Unsupported image formats: .zip'):
dataset.Dataset.from_folder(dirname='invalid') dataset.Dataset.from_image(filename='input/style/cartoon/cartoon.zip')
def test_from_folder_raise_value_error_for_valid_no_data_path(self):
input_data_dir = test_utils.get_test_data_path('face_stylizer')
with self.assertRaisesRegex(
ValueError, 'No images found under given directory'
):
dataset.Dataset.from_folder(dirname=input_data_dir)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -24,11 +24,13 @@ from mediapipe.tasks.python.test import test_utils
class FaceStylizerTest(tf.test.TestCase): class FaceStylizerTest(tf.test.TestCase):
def _load_data(self): def _create_training_dataset(self):
"""Loads training dataset.""" """Creates training dataset."""
input_data_dir = test_utils.get_test_data_path('input/style') input_style_image_file = test_utils.get_test_data_path(
'input/style/cartoon/cartoon.jpg'
)
data = face_stylizer.Dataset.from_folder(dirname=input_data_dir) data = face_stylizer.Dataset.from_image(filename=input_style_image_file)
return data return data
def _evaluate_saved_model(self, model: face_stylizer.FaceStylizer): def _evaluate_saved_model(self, model: face_stylizer.FaceStylizer):
@ -41,7 +43,7 @@ class FaceStylizerTest(tf.test.TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self._train_data = self._load_data() self._train_data = self._create_training_dataset()
def test_finetuning_face_stylizer_with_single_input_style_image(self): def test_finetuning_face_stylizer_with_single_input_style_image(self):
with self.test_session(use_gpu=True): with self.test_session(use_gpu=True):