Add a new from_image API to create face stylizer dataset from a single image. Also deprecate the from_folder API since we only support one-shot use case now.

PiperOrigin-RevId: 558912896
This commit is contained in:
MediaPipe Team 2023-08-21 15:12:23 -07:00 committed by Copybara-Service
parent ae9e945e0c
commit bbf168ddda
3 changed files with 34 additions and 69 deletions

View File

@ -51,71 +51,41 @@ class Dataset(classification_dataset.ClassificationDataset):
"""Dataset library for face stylizer fine tuning."""
@classmethod
def from_folder(
cls, dirname: str
def from_image(
cls, filename: str
) -> classification_dataset.ClassificationDataset:
"""Loads images from the given directory.
"""Creates a dataset from single image.
The style image dataset directory is expected to contain one subdirectory
whose name represents the label of the style. There can be one or multiple
images of the same style in that subdirectory. Supported input image formats
include 'jpg', 'jpeg', 'png'.
Supported input image formats include 'jpg', 'jpeg', 'png'.
Args:
dirname: Name of the directory containing the image files.
filename: Name of the image file.
Returns:
Dataset containing images and labels and other related info.
Raises:
ValueError: if the input data directory is empty.
Dataset containing image and label and other related info.
"""
data_root = os.path.abspath(dirname)
file_path = os.path.abspath(filename)
image_filename = os.path.basename(filename)
image_name, ext_name = os.path.splitext(image_filename)
# Assumes the image data of the same label are in the same subdirectory,
# gets image path and label names.
all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
all_image_size = len(all_image_paths)
if all_image_size == 0:
raise ValueError('Invalid input data directory')
if not any(
fname.endswith(('.jpg', '.jpeg', '.png')) for fname in all_image_paths
):
raise ValueError('No images found under given directory')
if not ext_name.endswith(('.jpg', '.jpeg', '.png')):
raise ValueError('Unsupported image formats: %s' % ext_name)
image_data = _preprocess_face_dataset(all_image_paths)
label_names = sorted(
name
for name in os.listdir(data_root)
if os.path.isdir(os.path.join(data_root, name))
)
all_label_size = len(label_names)
index_by_label = dict(
(name, index) for index, name in enumerate(label_names)
)
# Get the style label from the subdirectory name.
all_image_labels = [
index_by_label[os.path.basename(os.path.dirname(path))]
for path in all_image_paths
]
image_data = _preprocess_face_dataset([file_path])
label_names = [image_name]
image_ds = tf.data.Dataset.from_tensor_slices(image_data)
# Load label
label_ds = tf.data.Dataset.from_tensor_slices(
tf.cast(all_image_labels, tf.int64)
)
label_ds = tf.data.Dataset.from_tensor_slices(tf.cast([0], tf.int64))
# Create a dataset of (image, label) pairs
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
logging.info(
'Load images dataset with size: %d, num_label: %d, labels: %s.',
all_image_size,
all_label_size,
', '.join(label_names),
)
logging.info('Create dataset for style: %s.', image_name)
return Dataset(
dataset=image_label_ds,
label_names=label_names,
size=all_image_size,
size=1,
)

View File

@ -25,24 +25,17 @@ class DatasetTest(tf.test.TestCase):
def setUp(self):
super().setUp()
def test_from_folder(self):
test_data_dirname = 'input/style'
input_data_dir = test_utils.get_test_data_path(test_data_dirname)
data = dataset.Dataset.from_folder(dirname=input_data_dir)
self.assertEqual(data.num_classes, 2)
self.assertEqual(data.label_names, ['cartoon', 'sketch'])
self.assertLen(data, 2)
def test_from_image(self):
test_image_file = 'input/style/cartoon/cartoon.jpg'
input_data_dir = test_utils.get_test_data_path(test_image_file)
data = dataset.Dataset.from_image(filename=input_data_dir)
self.assertEqual(data.num_classes, 1)
self.assertEqual(data.label_names, ['cartoon'])
self.assertLen(data, 1)
def test_from_folder_raise_value_error_for_invalid_path(self):
with self.assertRaisesRegex(ValueError, 'Invalid input data directory'):
dataset.Dataset.from_folder(dirname='invalid')
def test_from_folder_raise_value_error_for_valid_no_data_path(self):
input_data_dir = test_utils.get_test_data_path('face_stylizer')
with self.assertRaisesRegex(
ValueError, 'No images found under given directory'
):
dataset.Dataset.from_folder(dirname=input_data_dir)
def test_from_image_raise_value_error_for_invalid_path(self):
with self.assertRaisesRegex(ValueError, 'Unsupported image formats: .zip'):
dataset.Dataset.from_image(filename='input/style/cartoon/cartoon.zip')
if __name__ == '__main__':

View File

@ -24,11 +24,13 @@ from mediapipe.tasks.python.test import test_utils
class FaceStylizerTest(tf.test.TestCase):
def _load_data(self):
"""Loads training dataset."""
input_data_dir = test_utils.get_test_data_path('input/style')
def _create_training_dataset(self):
"""Creates training dataset."""
input_style_image_file = test_utils.get_test_data_path(
'input/style/cartoon/cartoon.jpg'
)
data = face_stylizer.Dataset.from_folder(dirname=input_data_dir)
data = face_stylizer.Dataset.from_image(filename=input_style_image_file)
return data
def _evaluate_saved_model(self, model: face_stylizer.FaceStylizer):
@ -41,7 +43,7 @@ class FaceStylizerTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self._train_data = self._load_data()
self._train_data = self._create_training_dataset()
def test_finetuning_face_stylizer_with_single_input_style_image(self):
with self.test_session(use_gpu=True):