Add a new from_image API to create face stylizer dataset from a single image. Also deprecate the from_folder API since we only support one-shot use case now.
PiperOrigin-RevId: 558912896
This commit is contained in:
parent
ae9e945e0c
commit
bbf168ddda
|
@ -51,71 +51,41 @@ class Dataset(classification_dataset.ClassificationDataset):
|
|||
"""Dataset library for face stylizer fine tuning."""
|
||||
|
||||
@classmethod
|
||||
def from_folder(
|
||||
cls, dirname: str
|
||||
def from_image(
|
||||
cls, filename: str
|
||||
) -> classification_dataset.ClassificationDataset:
|
||||
"""Loads images from the given directory.
|
||||
"""Creates a dataset from single image.
|
||||
|
||||
The style image dataset directory is expected to contain one subdirectory
|
||||
whose name represents the label of the style. There can be one or multiple
|
||||
images of the same style in that subdirectory. Supported input image formats
|
||||
include 'jpg', 'jpeg', 'png'.
|
||||
Supported input image formats include 'jpg', 'jpeg', 'png'.
|
||||
|
||||
Args:
|
||||
dirname: Name of the directory containing the image files.
|
||||
filename: Name of the image file.
|
||||
|
||||
Returns:
|
||||
Dataset containing images and labels and other related info.
|
||||
Raises:
|
||||
ValueError: if the input data directory is empty.
|
||||
Dataset containing image and label and other related info.
|
||||
"""
|
||||
data_root = os.path.abspath(dirname)
|
||||
file_path = os.path.abspath(filename)
|
||||
image_filename = os.path.basename(filename)
|
||||
image_name, ext_name = os.path.splitext(image_filename)
|
||||
|
||||
# Assumes the image data of the same label are in the same subdirectory,
|
||||
# gets image path and label names.
|
||||
all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
|
||||
all_image_size = len(all_image_paths)
|
||||
if all_image_size == 0:
|
||||
raise ValueError('Invalid input data directory')
|
||||
if not any(
|
||||
fname.endswith(('.jpg', '.jpeg', '.png')) for fname in all_image_paths
|
||||
):
|
||||
raise ValueError('No images found under given directory')
|
||||
if not ext_name.endswith(('.jpg', '.jpeg', '.png')):
|
||||
raise ValueError('Unsupported image formats: %s' % ext_name)
|
||||
|
||||
image_data = _preprocess_face_dataset(all_image_paths)
|
||||
label_names = sorted(
|
||||
name
|
||||
for name in os.listdir(data_root)
|
||||
if os.path.isdir(os.path.join(data_root, name))
|
||||
)
|
||||
all_label_size = len(label_names)
|
||||
index_by_label = dict(
|
||||
(name, index) for index, name in enumerate(label_names)
|
||||
)
|
||||
# Get the style label from the subdirectory name.
|
||||
all_image_labels = [
|
||||
index_by_label[os.path.basename(os.path.dirname(path))]
|
||||
for path in all_image_paths
|
||||
]
|
||||
image_data = _preprocess_face_dataset([file_path])
|
||||
label_names = [image_name]
|
||||
|
||||
image_ds = tf.data.Dataset.from_tensor_slices(image_data)
|
||||
|
||||
# Load label
|
||||
label_ds = tf.data.Dataset.from_tensor_slices(
|
||||
tf.cast(all_image_labels, tf.int64)
|
||||
)
|
||||
label_ds = tf.data.Dataset.from_tensor_slices(tf.cast([0], tf.int64))
|
||||
|
||||
# Create a dataset of (image, label) pairs
|
||||
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
|
||||
|
||||
logging.info(
|
||||
'Load images dataset with size: %d, num_label: %d, labels: %s.',
|
||||
all_image_size,
|
||||
all_label_size,
|
||||
', '.join(label_names),
|
||||
)
|
||||
logging.info('Create dataset for style: %s.', image_name)
|
||||
|
||||
return Dataset(
|
||||
dataset=image_label_ds,
|
||||
label_names=label_names,
|
||||
size=all_image_size,
|
||||
size=1,
|
||||
)
|
||||
|
|
|
@ -25,24 +25,17 @@ class DatasetTest(tf.test.TestCase):
|
|||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
def test_from_folder(self):
|
||||
test_data_dirname = 'input/style'
|
||||
input_data_dir = test_utils.get_test_data_path(test_data_dirname)
|
||||
data = dataset.Dataset.from_folder(dirname=input_data_dir)
|
||||
self.assertEqual(data.num_classes, 2)
|
||||
self.assertEqual(data.label_names, ['cartoon', 'sketch'])
|
||||
self.assertLen(data, 2)
|
||||
def test_from_image(self):
|
||||
test_image_file = 'input/style/cartoon/cartoon.jpg'
|
||||
input_data_dir = test_utils.get_test_data_path(test_image_file)
|
||||
data = dataset.Dataset.from_image(filename=input_data_dir)
|
||||
self.assertEqual(data.num_classes, 1)
|
||||
self.assertEqual(data.label_names, ['cartoon'])
|
||||
self.assertLen(data, 1)
|
||||
|
||||
def test_from_folder_raise_value_error_for_invalid_path(self):
|
||||
with self.assertRaisesRegex(ValueError, 'Invalid input data directory'):
|
||||
dataset.Dataset.from_folder(dirname='invalid')
|
||||
|
||||
def test_from_folder_raise_value_error_for_valid_no_data_path(self):
|
||||
input_data_dir = test_utils.get_test_data_path('face_stylizer')
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'No images found under given directory'
|
||||
):
|
||||
dataset.Dataset.from_folder(dirname=input_data_dir)
|
||||
def test_from_image_raise_value_error_for_invalid_path(self):
|
||||
with self.assertRaisesRegex(ValueError, 'Unsupported image formats: .zip'):
|
||||
dataset.Dataset.from_image(filename='input/style/cartoon/cartoon.zip')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -24,11 +24,13 @@ from mediapipe.tasks.python.test import test_utils
|
|||
|
||||
class FaceStylizerTest(tf.test.TestCase):
|
||||
|
||||
def _load_data(self):
|
||||
"""Loads training dataset."""
|
||||
input_data_dir = test_utils.get_test_data_path('input/style')
|
||||
def _create_training_dataset(self):
|
||||
"""Creates training dataset."""
|
||||
input_style_image_file = test_utils.get_test_data_path(
|
||||
'input/style/cartoon/cartoon.jpg'
|
||||
)
|
||||
|
||||
data = face_stylizer.Dataset.from_folder(dirname=input_data_dir)
|
||||
data = face_stylizer.Dataset.from_image(filename=input_style_image_file)
|
||||
return data
|
||||
|
||||
def _evaluate_saved_model(self, model: face_stylizer.FaceStylizer):
|
||||
|
@ -41,7 +43,7 @@ class FaceStylizerTest(tf.test.TestCase):
|
|||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._train_data = self._load_data()
|
||||
self._train_data = self._create_training_dataset()
|
||||
|
||||
def test_finetuning_face_stylizer_with_single_input_style_image(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
|
|
Loading…
Reference in New Issue
Block a user