Add a new from_image API to create face stylizer dataset from a single image. Also deprecate the from_folder API since we only support one-shot use case now.
PiperOrigin-RevId: 558912896
This commit is contained in:
parent
ae9e945e0c
commit
bbf168ddda
|
@ -51,71 +51,41 @@ class Dataset(classification_dataset.ClassificationDataset):
|
||||||
"""Dataset library for face stylizer fine tuning."""
|
"""Dataset library for face stylizer fine tuning."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_folder(
|
def from_image(
|
||||||
cls, dirname: str
|
cls, filename: str
|
||||||
) -> classification_dataset.ClassificationDataset:
|
) -> classification_dataset.ClassificationDataset:
|
||||||
"""Loads images from the given directory.
|
"""Creates a dataset from single image.
|
||||||
|
|
||||||
The style image dataset directory is expected to contain one subdirectory
|
Supported input image formats include 'jpg', 'jpeg', 'png'.
|
||||||
whose name represents the label of the style. There can be one or multiple
|
|
||||||
images of the same style in that subdirectory. Supported input image formats
|
|
||||||
include 'jpg', 'jpeg', 'png'.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dirname: Name of the directory containing the image files.
|
filename: Name of the image file.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dataset containing images and labels and other related info.
|
Dataset containing image and label and other related info.
|
||||||
Raises:
|
|
||||||
ValueError: if the input data directory is empty.
|
|
||||||
"""
|
"""
|
||||||
data_root = os.path.abspath(dirname)
|
file_path = os.path.abspath(filename)
|
||||||
|
image_filename = os.path.basename(filename)
|
||||||
|
image_name, ext_name = os.path.splitext(image_filename)
|
||||||
|
|
||||||
# Assumes the image data of the same label are in the same subdirectory,
|
if not ext_name.endswith(('.jpg', '.jpeg', '.png')):
|
||||||
# gets image path and label names.
|
raise ValueError('Unsupported image formats: %s' % ext_name)
|
||||||
all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
|
|
||||||
all_image_size = len(all_image_paths)
|
|
||||||
if all_image_size == 0:
|
|
||||||
raise ValueError('Invalid input data directory')
|
|
||||||
if not any(
|
|
||||||
fname.endswith(('.jpg', '.jpeg', '.png')) for fname in all_image_paths
|
|
||||||
):
|
|
||||||
raise ValueError('No images found under given directory')
|
|
||||||
|
|
||||||
image_data = _preprocess_face_dataset(all_image_paths)
|
image_data = _preprocess_face_dataset([file_path])
|
||||||
label_names = sorted(
|
label_names = [image_name]
|
||||||
name
|
|
||||||
for name in os.listdir(data_root)
|
|
||||||
if os.path.isdir(os.path.join(data_root, name))
|
|
||||||
)
|
|
||||||
all_label_size = len(label_names)
|
|
||||||
index_by_label = dict(
|
|
||||||
(name, index) for index, name in enumerate(label_names)
|
|
||||||
)
|
|
||||||
# Get the style label from the subdirectory name.
|
|
||||||
all_image_labels = [
|
|
||||||
index_by_label[os.path.basename(os.path.dirname(path))]
|
|
||||||
for path in all_image_paths
|
|
||||||
]
|
|
||||||
|
|
||||||
image_ds = tf.data.Dataset.from_tensor_slices(image_data)
|
image_ds = tf.data.Dataset.from_tensor_slices(image_data)
|
||||||
|
|
||||||
# Load label
|
# Load label
|
||||||
label_ds = tf.data.Dataset.from_tensor_slices(
|
label_ds = tf.data.Dataset.from_tensor_slices(tf.cast([0], tf.int64))
|
||||||
tf.cast(all_image_labels, tf.int64)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a dataset of (image, label) pairs
|
# Create a dataset of (image, label) pairs
|
||||||
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
|
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
|
||||||
|
|
||||||
logging.info(
|
logging.info('Create dataset for style: %s.', image_name)
|
||||||
'Load images dataset with size: %d, num_label: %d, labels: %s.',
|
|
||||||
all_image_size,
|
|
||||||
all_label_size,
|
|
||||||
', '.join(label_names),
|
|
||||||
)
|
|
||||||
return Dataset(
|
return Dataset(
|
||||||
dataset=image_label_ds,
|
dataset=image_label_ds,
|
||||||
label_names=label_names,
|
label_names=label_names,
|
||||||
size=all_image_size,
|
size=1,
|
||||||
)
|
)
|
||||||
|
|
|
@ -25,24 +25,17 @@ class DatasetTest(tf.test.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
|
||||||
def test_from_folder(self):
|
def test_from_image(self):
|
||||||
test_data_dirname = 'input/style'
|
test_image_file = 'input/style/cartoon/cartoon.jpg'
|
||||||
input_data_dir = test_utils.get_test_data_path(test_data_dirname)
|
input_data_dir = test_utils.get_test_data_path(test_image_file)
|
||||||
data = dataset.Dataset.from_folder(dirname=input_data_dir)
|
data = dataset.Dataset.from_image(filename=input_data_dir)
|
||||||
self.assertEqual(data.num_classes, 2)
|
self.assertEqual(data.num_classes, 1)
|
||||||
self.assertEqual(data.label_names, ['cartoon', 'sketch'])
|
self.assertEqual(data.label_names, ['cartoon'])
|
||||||
self.assertLen(data, 2)
|
self.assertLen(data, 1)
|
||||||
|
|
||||||
def test_from_folder_raise_value_error_for_invalid_path(self):
|
def test_from_image_raise_value_error_for_invalid_path(self):
|
||||||
with self.assertRaisesRegex(ValueError, 'Invalid input data directory'):
|
with self.assertRaisesRegex(ValueError, 'Unsupported image formats: .zip'):
|
||||||
dataset.Dataset.from_folder(dirname='invalid')
|
dataset.Dataset.from_image(filename='input/style/cartoon/cartoon.zip')
|
||||||
|
|
||||||
def test_from_folder_raise_value_error_for_valid_no_data_path(self):
|
|
||||||
input_data_dir = test_utils.get_test_data_path('face_stylizer')
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
ValueError, 'No images found under given directory'
|
|
||||||
):
|
|
||||||
dataset.Dataset.from_folder(dirname=input_data_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -24,11 +24,13 @@ from mediapipe.tasks.python.test import test_utils
|
||||||
|
|
||||||
class FaceStylizerTest(tf.test.TestCase):
|
class FaceStylizerTest(tf.test.TestCase):
|
||||||
|
|
||||||
def _load_data(self):
|
def _create_training_dataset(self):
|
||||||
"""Loads training dataset."""
|
"""Creates training dataset."""
|
||||||
input_data_dir = test_utils.get_test_data_path('input/style')
|
input_style_image_file = test_utils.get_test_data_path(
|
||||||
|
'input/style/cartoon/cartoon.jpg'
|
||||||
|
)
|
||||||
|
|
||||||
data = face_stylizer.Dataset.from_folder(dirname=input_data_dir)
|
data = face_stylizer.Dataset.from_image(filename=input_style_image_file)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def _evaluate_saved_model(self, model: face_stylizer.FaceStylizer):
|
def _evaluate_saved_model(self, model: face_stylizer.FaceStylizer):
|
||||||
|
@ -41,7 +43,7 @@ class FaceStylizerTest(tf.test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self._train_data = self._load_data()
|
self._train_data = self._create_training_dataset()
|
||||||
|
|
||||||
def test_finetuning_face_stylizer_with_single_input_style_image(self):
|
def test_finetuning_face_stylizer_with_single_input_style_image(self):
|
||||||
with self.test_session(use_gpu=True):
|
with self.test_session(use_gpu=True):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user