Add a new from_image API to create face stylizer dataset from a single image. Also deprecate the from_folder API since we only support one-shot use case now.
PiperOrigin-RevId: 558912896
This commit is contained in:
		
							parent
							
								
									ae9e945e0c
								
							
						
					
					
						commit
						bbf168ddda
					
				|  | @ -51,71 +51,41 @@ class Dataset(classification_dataset.ClassificationDataset): | |||
|   """Dataset library for face stylizer fine tuning.""" | ||||
| 
 | ||||
|   @classmethod | ||||
|   def from_folder( | ||||
|       cls, dirname: str | ||||
|   def from_image( | ||||
|       cls, filename: str | ||||
|   ) -> classification_dataset.ClassificationDataset: | ||||
|     """Loads images from the given directory. | ||||
|     """Creates a dataset from single image. | ||||
| 
 | ||||
|     The style image dataset directory is expected to contain one subdirectory | ||||
|     whose name represents the label of the style. There can be one or multiple | ||||
|     images of the same style in that subdirectory. Supported input image formats | ||||
|     include 'jpg', 'jpeg', 'png'. | ||||
|     Supported input image formats include 'jpg', 'jpeg', 'png'. | ||||
| 
 | ||||
|     Args: | ||||
|       dirname: Name of the directory containing the image files. | ||||
|       filename: Name of the image file. | ||||
| 
 | ||||
|     Returns: | ||||
|       Dataset containing images and labels and other related info. | ||||
|     Raises: | ||||
|       ValueError: if the input data directory is empty. | ||||
|       Dataset containing image and label and other related info. | ||||
|     """ | ||||
|     data_root = os.path.abspath(dirname) | ||||
|     file_path = os.path.abspath(filename) | ||||
|     image_filename = os.path.basename(filename) | ||||
|     image_name, ext_name = os.path.splitext(image_filename) | ||||
| 
 | ||||
|     # Assumes the image data of the same label are in the same subdirectory, | ||||
|     # gets image path and label names. | ||||
|     all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*')) | ||||
|     all_image_size = len(all_image_paths) | ||||
|     if all_image_size == 0: | ||||
|       raise ValueError('Invalid input data directory') | ||||
|     if not any( | ||||
|         fname.endswith(('.jpg', '.jpeg', '.png')) for fname in all_image_paths | ||||
|     ): | ||||
|       raise ValueError('No images found under given directory') | ||||
|     if not ext_name.endswith(('.jpg', '.jpeg', '.png')): | ||||
|       raise ValueError('Unsupported image formats: %s' % ext_name) | ||||
| 
 | ||||
|     image_data = _preprocess_face_dataset(all_image_paths) | ||||
|     label_names = sorted( | ||||
|         name | ||||
|         for name in os.listdir(data_root) | ||||
|         if os.path.isdir(os.path.join(data_root, name)) | ||||
|     ) | ||||
|     all_label_size = len(label_names) | ||||
|     index_by_label = dict( | ||||
|         (name, index) for index, name in enumerate(label_names) | ||||
|     ) | ||||
|     # Get the style label from the subdirectory name. | ||||
|     all_image_labels = [ | ||||
|         index_by_label[os.path.basename(os.path.dirname(path))] | ||||
|         for path in all_image_paths | ||||
|     ] | ||||
|     image_data = _preprocess_face_dataset([file_path]) | ||||
|     label_names = [image_name] | ||||
| 
 | ||||
|     image_ds = tf.data.Dataset.from_tensor_slices(image_data) | ||||
| 
 | ||||
|     # Load label | ||||
|     label_ds = tf.data.Dataset.from_tensor_slices( | ||||
|         tf.cast(all_image_labels, tf.int64) | ||||
|     ) | ||||
|     label_ds = tf.data.Dataset.from_tensor_slices(tf.cast([0], tf.int64)) | ||||
| 
 | ||||
|     # Create a dataset of (image, label) pairs | ||||
|     image_label_ds = tf.data.Dataset.zip((image_ds, label_ds)) | ||||
| 
 | ||||
|     logging.info( | ||||
|         'Load images dataset with size: %d, num_label: %d, labels: %s.', | ||||
|         all_image_size, | ||||
|         all_label_size, | ||||
|         ', '.join(label_names), | ||||
|     ) | ||||
|     logging.info('Create dataset for style: %s.', image_name) | ||||
| 
 | ||||
|     return Dataset( | ||||
|         dataset=image_label_ds, | ||||
|         label_names=label_names, | ||||
|         size=all_image_size, | ||||
|         size=1, | ||||
|     ) | ||||
|  |  | |||
|  | @ -25,24 +25,17 @@ class DatasetTest(tf.test.TestCase): | |||
|   def setUp(self): | ||||
|     super().setUp() | ||||
| 
 | ||||
|   def test_from_folder(self): | ||||
|     test_data_dirname = 'input/style' | ||||
|     input_data_dir = test_utils.get_test_data_path(test_data_dirname) | ||||
|     data = dataset.Dataset.from_folder(dirname=input_data_dir) | ||||
|     self.assertEqual(data.num_classes, 2) | ||||
|     self.assertEqual(data.label_names, ['cartoon', 'sketch']) | ||||
|     self.assertLen(data, 2) | ||||
|   def test_from_image(self): | ||||
|     test_image_file = 'input/style/cartoon/cartoon.jpg' | ||||
|     input_data_dir = test_utils.get_test_data_path(test_image_file) | ||||
|     data = dataset.Dataset.from_image(filename=input_data_dir) | ||||
|     self.assertEqual(data.num_classes, 1) | ||||
|     self.assertEqual(data.label_names, ['cartoon']) | ||||
|     self.assertLen(data, 1) | ||||
| 
 | ||||
|   def test_from_folder_raise_value_error_for_invalid_path(self): | ||||
|     with self.assertRaisesRegex(ValueError, 'Invalid input data directory'): | ||||
|       dataset.Dataset.from_folder(dirname='invalid') | ||||
| 
 | ||||
|   def test_from_folder_raise_value_error_for_valid_no_data_path(self): | ||||
|     input_data_dir = test_utils.get_test_data_path('face_stylizer') | ||||
|     with self.assertRaisesRegex( | ||||
|         ValueError, 'No images found under given directory' | ||||
|     ): | ||||
|       dataset.Dataset.from_folder(dirname=input_data_dir) | ||||
|   def test_from_image_raise_value_error_for_invalid_path(self): | ||||
|     with self.assertRaisesRegex(ValueError, 'Unsupported image formats: .zip'): | ||||
|       dataset.Dataset.from_image(filename='input/style/cartoon/cartoon.zip') | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|  |  | |||
|  | @ -24,11 +24,13 @@ from mediapipe.tasks.python.test import test_utils | |||
| 
 | ||||
| class FaceStylizerTest(tf.test.TestCase): | ||||
| 
 | ||||
|   def _load_data(self): | ||||
|     """Loads training dataset.""" | ||||
|     input_data_dir = test_utils.get_test_data_path('input/style') | ||||
|   def _create_training_dataset(self): | ||||
|     """Creates training dataset.""" | ||||
|     input_style_image_file = test_utils.get_test_data_path( | ||||
|         'input/style/cartoon/cartoon.jpg' | ||||
|     ) | ||||
| 
 | ||||
|     data = face_stylizer.Dataset.from_folder(dirname=input_data_dir) | ||||
|     data = face_stylizer.Dataset.from_image(filename=input_style_image_file) | ||||
|     return data | ||||
| 
 | ||||
|   def _evaluate_saved_model(self, model: face_stylizer.FaceStylizer): | ||||
|  | @ -41,7 +43,7 @@ class FaceStylizerTest(tf.test.TestCase): | |||
| 
 | ||||
|   def setUp(self): | ||||
|     super().setUp() | ||||
|     self._train_data = self._load_data() | ||||
|     self._train_data = self._create_training_dataset() | ||||
| 
 | ||||
|   def test_finetuning_face_stylizer_with_single_input_style_image(self): | ||||
|     with self.test_session(use_gpu=True): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user