Add the dataset module for face stylizer in model maker.
PiperOrigin-RevId: 516628350
This commit is contained in:
		
							parent
							
								
									ade31b567b
								
							
						
					
					
						commit
						6774794d02
					
				
							
								
								
									
										48
									
								
								mediapipe/model_maker/python/vision/face_stylizer/BUILD
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								mediapipe/model_maker/python/vision/face_stylizer/BUILD
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,48 @@
 | 
			
		|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
# Placeholder for internal Python strict test compatibility macro.
 | 
			
		||||
# Placeholder for internal Python strict library and test compatibility macro.
 | 
			
		||||
 | 
			
		||||
licenses(["notice"])
 | 
			
		||||
 | 
			
		||||
package(default_visibility = ["//mediapipe:__subpackages__"])
 | 
			
		||||
 | 
			
		||||
filegroup(
 | 
			
		||||
    name = "testdata",
 | 
			
		||||
    srcs = glob([
 | 
			
		||||
        "testdata/**",
 | 
			
		||||
    ]),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "dataset",
 | 
			
		||||
    srcs = ["dataset.py"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/model_maker/python/core/data:classification_dataset",
 | 
			
		||||
        "//mediapipe/model_maker/python/vision/core:image_utils",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
    name = "dataset_test",
 | 
			
		||||
    srcs = ["dataset_test.py"],
 | 
			
		||||
    data = [
 | 
			
		||||
        ":testdata",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":dataset",
 | 
			
		||||
        "//mediapipe/tasks/python/test:test_utils",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,14 @@
 | 
			
		|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""MediaPipe Model Maker Python Public API For Face Stylization."""
 | 
			
		||||
							
								
								
									
										98
									
								
								mediapipe/model_maker/python/vision/face_stylizer/dataset.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										98
									
								
								mediapipe/model_maker/python/vision/face_stylizer/dataset.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,98 @@
 | 
			
		|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Face stylizer dataset library."""
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.core.data import classification_dataset
 | 
			
		||||
from mediapipe.model_maker.python.vision.core import image_utils
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# TODO: Change to a unlabeled dataset if it makes sense.
 | 
			
		||||
class Dataset(classification_dataset.ClassificationDataset):
 | 
			
		||||
  """Dataset library for face stylizer fine tuning."""
 | 
			
		||||
 | 
			
		||||
  @classmethod
 | 
			
		||||
  def from_folder(
 | 
			
		||||
      cls, dirname: str
 | 
			
		||||
  ) -> classification_dataset.ClassificationDataset:
 | 
			
		||||
    """Loads images from the given directory.
 | 
			
		||||
 | 
			
		||||
    The style image dataset directory is expected to contain one subdirectory
 | 
			
		||||
    whose name represents the label of the style. There can be one or multiple
 | 
			
		||||
    images of the same style in that subdirectory. Supported input image formats
 | 
			
		||||
    include 'jpg', 'jpeg', 'png'.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
      dirname: Name of the directory containing the image files.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
      Dataset containing images and labels and other related info.
 | 
			
		||||
    Raises:
 | 
			
		||||
      ValueError: if the input data directory is empty.
 | 
			
		||||
    """
 | 
			
		||||
    data_root = os.path.abspath(dirname)
 | 
			
		||||
 | 
			
		||||
    # Assumes the image data of the same label are in the same subdirectory,
 | 
			
		||||
    # gets image path and label names.
 | 
			
		||||
    all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
 | 
			
		||||
    all_image_size = len(all_image_paths)
 | 
			
		||||
    if all_image_size == 0:
 | 
			
		||||
      raise ValueError('Invalid input data directory')
 | 
			
		||||
    if not any(
 | 
			
		||||
        fname.endswith(('.jpg', '.jpeg', '.png')) for fname in all_image_paths
 | 
			
		||||
    ):
 | 
			
		||||
      raise ValueError('No images found under given directory')
 | 
			
		||||
 | 
			
		||||
    label_names = sorted(
 | 
			
		||||
        name
 | 
			
		||||
        for name in os.listdir(data_root)
 | 
			
		||||
        if os.path.isdir(os.path.join(data_root, name))
 | 
			
		||||
    )
 | 
			
		||||
    all_label_size = len(label_names)
 | 
			
		||||
    index_by_label = dict(
 | 
			
		||||
        (name, index) for index, name in enumerate(label_names)
 | 
			
		||||
    )
 | 
			
		||||
    # Get the style label from the subdirectory name.
 | 
			
		||||
    all_image_labels = [
 | 
			
		||||
        index_by_label[os.path.basename(os.path.dirname(path))]
 | 
			
		||||
        for path in all_image_paths
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
 | 
			
		||||
 | 
			
		||||
    image_ds = path_ds.map(
 | 
			
		||||
        image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Load label
 | 
			
		||||
    label_ds = tf.data.Dataset.from_tensor_slices(
 | 
			
		||||
        tf.cast(all_image_labels, tf.int64)
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Create a dataset of (image, label) pairs
 | 
			
		||||
    image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
 | 
			
		||||
 | 
			
		||||
    logging.info(
 | 
			
		||||
        'Load images dataset with size: %d, num_label: %d, labels: %s.',
 | 
			
		||||
        all_image_size,
 | 
			
		||||
        all_label_size,
 | 
			
		||||
        ', '.join(label_names),
 | 
			
		||||
    )
 | 
			
		||||
    return Dataset(
 | 
			
		||||
        dataset=image_label_ds, size=all_image_size, label_names=label_names
 | 
			
		||||
    )
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,48 @@
 | 
			
		|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.vision.face_stylizer import dataset
 | 
			
		||||
from mediapipe.tasks.python.test import test_utils
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DatasetTest(tf.test.TestCase):
 | 
			
		||||
 | 
			
		||||
  def setUp(self):
 | 
			
		||||
    super().setUp()
 | 
			
		||||
    # TODO: Replace the stylize image dataset with licensed images.
 | 
			
		||||
    self._test_data_dirname = 'testdata'
 | 
			
		||||
 | 
			
		||||
  def test_from_folder(self):
 | 
			
		||||
    input_data_dir = test_utils.get_test_data_path(self._test_data_dirname)
 | 
			
		||||
    data = dataset.Dataset.from_folder(dirname=input_data_dir)
 | 
			
		||||
    self.assertEqual(data.num_classes, 2)
 | 
			
		||||
    self.assertEqual(data.label_names, ['cartoon', 'sketch'])
 | 
			
		||||
    self.assertLen(data, 2)
 | 
			
		||||
 | 
			
		||||
  def test_from_folder_raise_value_error_for_invalid_path(self):
 | 
			
		||||
    with self.assertRaisesRegex(ValueError, 'Invalid input data directory'):
 | 
			
		||||
      dataset.Dataset.from_folder(dirname='invalid')
 | 
			
		||||
 | 
			
		||||
  def test_from_folder_raise_value_error_for_valid_no_data_path(self):
 | 
			
		||||
    input_data_dir = test_utils.get_test_data_path('face_stylizer')
 | 
			
		||||
    with self.assertRaisesRegex(
 | 
			
		||||
        ValueError, 'No images found under given directory'
 | 
			
		||||
    ):
 | 
			
		||||
      dataset.Dataset.from_folder(dirname=input_data_dir)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
  tf.test.main()
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								mediapipe/model_maker/python/vision/face_stylizer/testdata/cartoon/disney.png
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								mediapipe/model_maker/python/vision/face_stylizer/testdata/cartoon/disney.png
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 347 KiB  | 
							
								
								
									
										
											BIN
										
									
								
								mediapipe/model_maker/python/vision/face_stylizer/testdata/sketch/sketch.png
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								mediapipe/model_maker/python/vision/face_stylizer/testdata/sketch/sketch.png
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 336 KiB  | 
		Loading…
	
		Reference in New Issue
	
	Block a user