diff --git a/mediapipe/model_maker/python/vision/face_stylizer/dataset.py b/mediapipe/model_maker/python/vision/face_stylizer/dataset.py index 85802f908..fd86df960 100644 --- a/mediapipe/model_maker/python/vision/face_stylizer/dataset.py +++ b/mediapipe/model_maker/python/vision/face_stylizer/dataset.py @@ -51,71 +51,41 @@ class Dataset(classification_dataset.ClassificationDataset): """Dataset library for face stylizer fine tuning.""" @classmethod - def from_folder( - cls, dirname: str + def from_image( + cls, filename: str ) -> classification_dataset.ClassificationDataset: - """Loads images from the given directory. + """Creates a dataset from single image. - The style image dataset directory is expected to contain one subdirectory - whose name represents the label of the style. There can be one or multiple - images of the same style in that subdirectory. Supported input image formats - include 'jpg', 'jpeg', 'png'. + Supported input image formats include 'jpg', 'jpeg', 'png'. Args: - dirname: Name of the directory containing the image files. + filename: Name of the image file. Returns: - Dataset containing images and labels and other related info. - Raises: - ValueError: if the input data directory is empty. + Dataset containing image and label and other related info. """ - data_root = os.path.abspath(dirname) + file_path = os.path.abspath(filename) + image_filename = os.path.basename(filename) + image_name, ext_name = os.path.splitext(image_filename) - # Assumes the image data of the same label are in the same subdirectory, - # gets image path and label names. - all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*')) - all_image_size = len(all_image_paths) - if all_image_size == 0: - raise ValueError('Invalid input data directory') - if not any( - fname.endswith(('.jpg', '.jpeg', '.png')) for fname in all_image_paths - ): - raise ValueError('No images found under given directory') + if not ext_name.endswith(('.jpg', '.jpeg', '.png')): + raise ValueError('Unsupported image formats: %s' % ext_name) - image_data = _preprocess_face_dataset(all_image_paths) - label_names = sorted( - name - for name in os.listdir(data_root) - if os.path.isdir(os.path.join(data_root, name)) - ) - all_label_size = len(label_names) - index_by_label = dict( - (name, index) for index, name in enumerate(label_names) - ) - # Get the style label from the subdirectory name. - all_image_labels = [ - index_by_label[os.path.basename(os.path.dirname(path))] - for path in all_image_paths - ] + image_data = _preprocess_face_dataset([file_path]) + label_names = [image_name] image_ds = tf.data.Dataset.from_tensor_slices(image_data) # Load label - label_ds = tf.data.Dataset.from_tensor_slices( - tf.cast(all_image_labels, tf.int64) - ) + label_ds = tf.data.Dataset.from_tensor_slices(tf.cast([0], tf.int64)) # Create a dataset of (image, label) pairs image_label_ds = tf.data.Dataset.zip((image_ds, label_ds)) - logging.info( - 'Load images dataset with size: %d, num_label: %d, labels: %s.', - all_image_size, - all_label_size, - ', '.join(label_names), - ) + logging.info('Create dataset for style: %s.', image_name) + return Dataset( dataset=image_label_ds, label_names=label_names, - size=all_image_size, + size=1, ) diff --git a/mediapipe/model_maker/python/vision/face_stylizer/dataset_test.py b/mediapipe/model_maker/python/vision/face_stylizer/dataset_test.py index 900371de1..914f50007 100644 --- a/mediapipe/model_maker/python/vision/face_stylizer/dataset_test.py +++ b/mediapipe/model_maker/python/vision/face_stylizer/dataset_test.py @@ -25,24 +25,17 @@ class DatasetTest(tf.test.TestCase): def setUp(self): super().setUp() - def test_from_folder(self): - test_data_dirname = 'input/style' - input_data_dir = test_utils.get_test_data_path(test_data_dirname) - data = dataset.Dataset.from_folder(dirname=input_data_dir) - self.assertEqual(data.num_classes, 2) - self.assertEqual(data.label_names, ['cartoon', 'sketch']) - self.assertLen(data, 2) + def test_from_image(self): + test_image_file = 'input/style/cartoon/cartoon.jpg' + input_data_dir = test_utils.get_test_data_path(test_image_file) + data = dataset.Dataset.from_image(filename=input_data_dir) + self.assertEqual(data.num_classes, 1) + self.assertEqual(data.label_names, ['cartoon']) + self.assertLen(data, 1) - def test_from_folder_raise_value_error_for_invalid_path(self): - with self.assertRaisesRegex(ValueError, 'Invalid input data directory'): - dataset.Dataset.from_folder(dirname='invalid') - - def test_from_folder_raise_value_error_for_valid_no_data_path(self): - input_data_dir = test_utils.get_test_data_path('face_stylizer') - with self.assertRaisesRegex( - ValueError, 'No images found under given directory' - ): - dataset.Dataset.from_folder(dirname=input_data_dir) + def test_from_image_raise_value_error_for_invalid_path(self): + with self.assertRaisesRegex(ValueError, 'Unsupported image formats: .zip'): + dataset.Dataset.from_image(filename='input/style/cartoon/cartoon.zip') if __name__ == '__main__': diff --git a/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_test.py b/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_test.py index c97c2199d..a815817ea 100644 --- a/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_test.py +++ b/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_test.py @@ -24,11 +24,13 @@ from mediapipe.tasks.python.test import test_utils class FaceStylizerTest(tf.test.TestCase): - def _load_data(self): - """Loads training dataset.""" - input_data_dir = test_utils.get_test_data_path('input/style') + def _create_training_dataset(self): + """Creates training dataset.""" + input_style_image_file = test_utils.get_test_data_path( + 'input/style/cartoon/cartoon.jpg' + ) - data = face_stylizer.Dataset.from_folder(dirname=input_data_dir) + data = face_stylizer.Dataset.from_image(filename=input_style_image_file) return data def _evaluate_saved_model(self, model: face_stylizer.FaceStylizer): @@ -41,7 +43,7 @@ class FaceStylizerTest(tf.test.TestCase): def setUp(self): super().setUp() - self._train_data = self._load_data() + self._train_data = self._create_training_dataset() def test_finetuning_face_stylizer_with_single_input_style_image(self): with self.test_session(use_gpu=True):