Add the dataset module for face stylizer in model maker.
PiperOrigin-RevId: 516628350
This commit is contained in:
		
							parent
							
								
									ade31b567b
								
							
						
					
					
						commit
						6774794d02
					
				
							
								
								
									
										48
									
								
								mediapipe/model_maker/python/vision/face_stylizer/BUILD
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								mediapipe/model_maker/python/vision/face_stylizer/BUILD
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,48 @@
 | 
				
			||||||
 | 
					# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Placeholder for internal Python strict test compatibility macro.
 | 
				
			||||||
 | 
					# Placeholder for internal Python strict library and test compatibility macro.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					licenses(["notice"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package(default_visibility = ["//mediapipe:__subpackages__"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					filegroup(
 | 
				
			||||||
 | 
					    name = "testdata",
 | 
				
			||||||
 | 
					    srcs = glob([
 | 
				
			||||||
 | 
					        "testdata/**",
 | 
				
			||||||
 | 
					    ]),
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					py_library(
 | 
				
			||||||
 | 
					    name = "dataset",
 | 
				
			||||||
 | 
					    srcs = ["dataset.py"],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        "//mediapipe/model_maker/python/core/data:classification_dataset",
 | 
				
			||||||
 | 
					        "//mediapipe/model_maker/python/vision/core:image_utils",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					py_test(
 | 
				
			||||||
 | 
					    name = "dataset_test",
 | 
				
			||||||
 | 
					    srcs = ["dataset_test.py"],
 | 
				
			||||||
 | 
					    data = [
 | 
				
			||||||
 | 
					        ":testdata",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":dataset",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/test:test_utils",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,14 @@
 | 
				
			||||||
 | 
					# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					"""MediaPipe Model Maker Python Public API For Face Stylization."""
 | 
				
			||||||
							
								
								
									
										98
									
								
								mediapipe/model_maker/python/vision/face_stylizer/dataset.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										98
									
								
								mediapipe/model_maker/python/vision/face_stylizer/dataset.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,98 @@
 | 
				
			||||||
 | 
					# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					"""Face stylizer dataset library."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import logging
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import tensorflow as tf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.core.data import classification_dataset
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.vision.core import image_utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# TODO: Change to a unlabeled dataset if it makes sense.
 | 
				
			||||||
 | 
					class Dataset(classification_dataset.ClassificationDataset):
 | 
				
			||||||
 | 
					  """Dataset library for face stylizer fine tuning."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @classmethod
 | 
				
			||||||
 | 
					  def from_folder(
 | 
				
			||||||
 | 
					      cls, dirname: str
 | 
				
			||||||
 | 
					  ) -> classification_dataset.ClassificationDataset:
 | 
				
			||||||
 | 
					    """Loads images from the given directory.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    The style image dataset directory is expected to contain one subdirectory
 | 
				
			||||||
 | 
					    whose name represents the label of the style. There can be one or multiple
 | 
				
			||||||
 | 
					    images of the same style in that subdirectory. Supported input image formats
 | 
				
			||||||
 | 
					    include 'jpg', 'jpeg', 'png'.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      dirname: Name of the directory containing the image files.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      Dataset containing images and labels and other related info.
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					      ValueError: if the input data directory is empty.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    data_root = os.path.abspath(dirname)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Assumes the image data of the same label are in the same subdirectory,
 | 
				
			||||||
 | 
					    # gets image path and label names.
 | 
				
			||||||
 | 
					    all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
 | 
				
			||||||
 | 
					    all_image_size = len(all_image_paths)
 | 
				
			||||||
 | 
					    if all_image_size == 0:
 | 
				
			||||||
 | 
					      raise ValueError('Invalid input data directory')
 | 
				
			||||||
 | 
					    if not any(
 | 
				
			||||||
 | 
					        fname.endswith(('.jpg', '.jpeg', '.png')) for fname in all_image_paths
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					      raise ValueError('No images found under given directory')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    label_names = sorted(
 | 
				
			||||||
 | 
					        name
 | 
				
			||||||
 | 
					        for name in os.listdir(data_root)
 | 
				
			||||||
 | 
					        if os.path.isdir(os.path.join(data_root, name))
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    all_label_size = len(label_names)
 | 
				
			||||||
 | 
					    index_by_label = dict(
 | 
				
			||||||
 | 
					        (name, index) for index, name in enumerate(label_names)
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    # Get the style label from the subdirectory name.
 | 
				
			||||||
 | 
					    all_image_labels = [
 | 
				
			||||||
 | 
					        index_by_label[os.path.basename(os.path.dirname(path))]
 | 
				
			||||||
 | 
					        for path in all_image_paths
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    image_ds = path_ds.map(
 | 
				
			||||||
 | 
					        image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Load label
 | 
				
			||||||
 | 
					    label_ds = tf.data.Dataset.from_tensor_slices(
 | 
				
			||||||
 | 
					        tf.cast(all_image_labels, tf.int64)
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Create a dataset of (image, label) pairs
 | 
				
			||||||
 | 
					    image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    logging.info(
 | 
				
			||||||
 | 
					        'Load images dataset with size: %d, num_label: %d, labels: %s.',
 | 
				
			||||||
 | 
					        all_image_size,
 | 
				
			||||||
 | 
					        all_label_size,
 | 
				
			||||||
 | 
					        ', '.join(label_names),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    return Dataset(
 | 
				
			||||||
 | 
					        dataset=image_label_ds, size=all_image_size, label_names=label_names
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,48 @@
 | 
				
			||||||
 | 
					# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import tensorflow as tf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.vision.face_stylizer import dataset
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.test import test_utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DatasetTest(tf.test.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def setUp(self):
 | 
				
			||||||
 | 
					    super().setUp()
 | 
				
			||||||
 | 
					    # TODO: Replace the stylize image dataset with licensed images.
 | 
				
			||||||
 | 
					    self._test_data_dirname = 'testdata'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_from_folder(self):
 | 
				
			||||||
 | 
					    input_data_dir = test_utils.get_test_data_path(self._test_data_dirname)
 | 
				
			||||||
 | 
					    data = dataset.Dataset.from_folder(dirname=input_data_dir)
 | 
				
			||||||
 | 
					    self.assertEqual(data.num_classes, 2)
 | 
				
			||||||
 | 
					    self.assertEqual(data.label_names, ['cartoon', 'sketch'])
 | 
				
			||||||
 | 
					    self.assertLen(data, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_from_folder_raise_value_error_for_invalid_path(self):
 | 
				
			||||||
 | 
					    with self.assertRaisesRegex(ValueError, 'Invalid input data directory'):
 | 
				
			||||||
 | 
					      dataset.Dataset.from_folder(dirname='invalid')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_from_folder_raise_value_error_for_valid_no_data_path(self):
 | 
				
			||||||
 | 
					    input_data_dir = test_utils.get_test_data_path('face_stylizer')
 | 
				
			||||||
 | 
					    with self.assertRaisesRegex(
 | 
				
			||||||
 | 
					        ValueError, 'No images found under given directory'
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					      dataset.Dataset.from_folder(dirname=input_data_dir)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == '__main__':
 | 
				
			||||||
 | 
					  tf.test.main()
 | 
				
			||||||
							
								
								
									
										
											BIN
										
									
								
								mediapipe/model_maker/python/vision/face_stylizer/testdata/cartoon/disney.png
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								mediapipe/model_maker/python/vision/face_stylizer/testdata/cartoon/disney.png
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 347 KiB  | 
							
								
								
									
										
											BIN
										
									
								
								mediapipe/model_maker/python/vision/face_stylizer/testdata/sketch/sketch.png
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								mediapipe/model_maker/python/vision/face_stylizer/testdata/sketch/sketch.png
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 336 KiB  | 
		Loading…
	
		Reference in New Issue
	
	Block a user