Add a face alignment preprocessor to face stylizer.
PiperOrigin-RevId: 542559764
parent 825e3a8af0
commit ba7e0e0e50
mediapipe/model_maker/python/vision/face_stylizer/BUILD

@@ -20,13 +20,6 @@ licenses(["notice"])
 
 package(default_visibility = ["//mediapipe:__subpackages__"])
 
-filegroup(
-    name = "testdata",
-    srcs = glob([
-        "testdata/**",
-    ]),
-)
-
 py_library(
     name = "constants",
     srcs = ["constants.py"],
@@ -72,18 +65,11 @@ py_library(
     name = "dataset",
     srcs = ["dataset.py"],
     deps = [
+        ":constants",
         "//mediapipe/model_maker/python/core/data:classification_dataset",
-        "//mediapipe/model_maker/python/vision/core:image_utils",
-    ],
-)
-
-py_test(
-    name = "dataset_test",
-    srcs = ["dataset_test.py"],
-    data = [":testdata"],
-    deps = [
-        ":dataset",
-        "//mediapipe/tasks/python/test:test_utils",
+        "//mediapipe/python:_framework_bindings",
+        "//mediapipe/tasks/python/core:base_options",
+        "//mediapipe/tasks/python/vision:face_aligner",
     ],
 )
 
mediapipe/model_maker/python/vision/face_stylizer/constants.py

@@ -41,5 +41,11 @@ FACE_STYLIZER_W_FILES = file_util.DownloadedFiles(
     'https://storage.googleapis.com/mediapipe-assets/face_stylizer_w_avg.npy',
 )
 
+FACE_ALIGNER_TASK_FILES = file_util.DownloadedFiles(
+    'face_stylizer/face_landmarker_v2.task',
+    'https://storage.googleapis.com/mediapipe-assets/face_landmarker_v2.task',
+    is_folder=False,
+)
+
 # Dimension of the input style vector to the decoder
 STYLE_DIM = 512
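For context, a minimal sketch of how the new constant is consumed, mirroring the dataset.py change below; only names that appear in this diff are used, and the local variable name is illustrative.

# Sketch only: get_path() lazily downloads the face landmarker task asset
# and returns its local path, which then backs the face aligner.
from mediapipe.model_maker.python.vision.face_stylizer import constants

aligner_model_path = constants.FACE_ALIGNER_TASK_FILES.get_path()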
mediapipe/model_maker/python/vision/face_stylizer/dataset.py

@@ -13,13 +13,37 @@
 # limitations under the License.
 """Face stylizer dataset library."""
 
+from typing import Sequence
 import logging
 import os
 
 import tensorflow as tf
 
 from mediapipe.model_maker.python.core.data import classification_dataset
-from mediapipe.model_maker.python.vision.core import image_utils
+from mediapipe.model_maker.python.vision.face_stylizer import constants
+from mediapipe.python._framework_bindings import image as image_module
+from mediapipe.tasks.python.core import base_options as base_options_module
+from mediapipe.tasks.python.vision import face_aligner
+
+
+def _preprocess_face_dataset(
+    all_image_paths: Sequence[str],
+) -> Sequence[tf.Tensor]:
+  """Preprocess face image dataset by aligning the face."""
+  path = constants.FACE_ALIGNER_TASK_FILES.get_path()
+  base_options = base_options_module.BaseOptions(model_asset_path=path)
+  options = face_aligner.FaceAlignerOptions(base_options=base_options)
+  aligner = face_aligner.FaceAligner.create_from_options(options)
+
+  preprocessed_images = []
+  for path in all_image_paths:
+    tf.compat.v1.logging.info('Preprocess image %s', path)
+    image = image_module.Image.create_from_file(path)
+    aligned_image = aligner.align(image)
+    aligned_image_tensor = tf.convert_to_tensor(aligned_image.numpy_view())
+    preprocessed_images.append(aligned_image_tensor)
+
+  return preprocessed_images
 
 
 # TODO: Change to a unlabeled dataset if it makes sense.
@@ -58,6 +82,7 @@ class Dataset(classification_dataset.ClassificationDataset):
     ):
       raise ValueError('No images found under given directory')
 
+    image_data = _preprocess_face_dataset(all_image_paths)
     label_names = sorted(
         name
         for name in os.listdir(data_root)
@@ -73,11 +98,7 @@ class Dataset(classification_dataset.ClassificationDataset):
         for path in all_image_paths
     ]
 
-    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
-
-    image_ds = path_ds.map(
-        image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
-    )
+    image_ds = tf.data.Dataset.from_tensor_slices(image_data)
 
     # Load label
     label_ds = tf.data.Dataset.from_tensor_slices(
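A hedged usage sketch of the updated dataset path: from_folder() now aligns every face with the MediaPipe FaceAligner before the tensors enter the tf.data pipeline. The directory layout (one subfolder per style label) follows the test below; the path is illustrative.

# Sketch only: '/tmp/styles' is an illustrative directory with one
# subfolder per style label (e.g. cartoon/, sketch/).
from mediapipe.model_maker.python.vision.face_stylizer import dataset

data = dataset.Dataset.from_folder(dirname='/tmp/styles')
print(data.num_classes, data.label_names)  # images are already face-aligned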
mediapipe/model_maker/python/vision/face_stylizer/dataset_test.py

@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numpy as np
 import tensorflow as tf
 
+from mediapipe.model_maker.python.vision.core import image_utils
 from mediapipe.model_maker.python.vision.face_stylizer import dataset
 from mediapipe.tasks.python.test import test_utils
 
@@ -22,10 +24,10 @@ class DatasetTest(tf.test.TestCase):
 
   def setUp(self):
     super().setUp()
-    self._test_data_dirname = 'input/style'
 
   def test_from_folder(self):
-    input_data_dir = test_utils.get_test_data_path(self._test_data_dirname)
+    test_data_dirname = 'input/style'
+    input_data_dir = test_utils.get_test_data_path(test_data_dirname)
     data = dataset.Dataset.from_folder(dirname=input_data_dir)
     self.assertEqual(data.num_classes, 2)
     self.assertEqual(data.label_names, ['cartoon', 'sketch'])
mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py

@@ -14,7 +14,7 @@
 """APIs to train face stylization model."""
 
 import os
-from typing import Callable, Optional
+from typing import Any, Callable, Optional
 
 import numpy as np
 import tensorflow as tf
@@ -54,7 +54,6 @@ class FaceStylizer(object):
     self._model_spec = model_spec
     self._model_options = model_options
     self._hparams = hparams
-    # TODO: Support face alignment in image preprocessor.
     self._preprocessor = image_preprocessing.Preprocessor(
         input_shape=self._model_spec.input_image_shape,
         num_classes=1,
@@ -128,7 +127,7 @@ class FaceStylizer(object):
   def _train_model(
       self,
       train_data: classification_ds.ClassificationDataset,
-      preprocessor: Optional[Callable[..., bool]] = None,
+      preprocessor: Optional[Callable[..., Any]] = None,
   ):
     """Trains the face stylizer model.
 
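The Callable[..., bool] to Callable[..., Any] widening above reflects that a preprocessor returns processed data rather than a bool. A hedged sketch of a compatible callable follows; the (image, label, is_training) parameter list is an assumption modeled on image_preprocessing.Preprocessor and is not spelled out in this diff.

# Sketch only: the parameter names below are assumed, not taken from this diff.
def passthrough_preprocessor(image, label, is_training=False):
  # A real preprocessor would crop/resize/normalize here; this one is a no-op.
  return image, label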
mediapipe/tasks/python/core/BUILD

@@ -29,7 +29,7 @@ py_library(
    name = "base_options",
    srcs = ["base_options.py"],
    visibility = [
-        "//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__",
+        "//mediapipe/model_maker:__subpackages__",
        "//mediapipe/tasks:users",
    ],
    deps = [
mediapipe/tasks/python/test/test_utils.py

@@ -22,7 +22,6 @@ import six
 from google.protobuf import descriptor
 from google.protobuf import descriptor_pool
 from google.protobuf import text_format
-
 from mediapipe.python._framework_bindings import image as image_module
 from mediapipe.python._framework_bindings import image_frame as image_frame_module
 
@@ -44,18 +43,21 @@ def test_srcdir():
 
 def get_test_data_path(file_or_dirname_path: str) -> str:
   """Returns full test data path."""
-  for (directory, subdirs, files) in os.walk(test_srcdir()):
+  for directory, subdirs, files in os.walk(test_srcdir()):
     for f in subdirs + files:
       path = os.path.join(directory, f)
       if path.endswith(file_or_dirname_path):
         return path
-  raise ValueError("No %s in test directory: %s." %
-                   (file_or_dirname_path, test_srcdir()))
+  raise ValueError(
+      "No %s in test directory: %s." % (file_or_dirname_path, test_srcdir())
+  )
 
 
-def create_calibration_file(file_dir: str,
-                            file_name: str = "score_calibration.txt",
-                            content: str = "1.0,2.0,3.0,4.0") -> str:
+def create_calibration_file(
+    file_dir: str,
+    file_name: str = "score_calibration.txt",
+    content: str = "1.0,2.0,3.0,4.0",
+) -> str:
   """Creates the calibration file."""
   calibration_file = os.path.join(file_dir, file_name)
   with open(calibration_file, mode="w") as file:
@@ -63,12 +65,9 @@ def create_calibration_file(file_dir: str,
   return calibration_file
 
 
-def assert_proto_equals(self,
-                        a,
-                        b,
-                        check_initialized=True,
-                        normalize_numbers=True,
-                        msg=None):
+def assert_proto_equals(
+    self, a, b, check_initialized=True, normalize_numbers=True, msg=None
+):
   """assert_proto_equals() is useful for unit tests.
 
   It produces much more helpful output than assertEqual() for proto2 messages.
@@ -113,7 +112,8 @@ def assert_proto_equals(self,
     self.assertMultiLineEqual(a_str, b_str, msg=msg)
   else:
     diff = "".join(
-        difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True)))
+        difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True))
+    )
     if diff:
       self.fail("%s :\n%s" % (msg, diff))
 
@@ -147,14 +147,18 @@ def _normalize_number_fields(pb):
     # We force 32-bit values to int and 64-bit values to long to make
     # alternate implementations where the distinction is more significant
     # (e.g. the C++ implementation) simpler.
-    if desc.type in (descriptor.FieldDescriptor.TYPE_INT64,
-                     descriptor.FieldDescriptor.TYPE_UINT64,
-                     descriptor.FieldDescriptor.TYPE_SINT64):
+    if desc.type in (
+        descriptor.FieldDescriptor.TYPE_INT64,
+        descriptor.FieldDescriptor.TYPE_UINT64,
+        descriptor.FieldDescriptor.TYPE_SINT64,
+    ):
       normalized_values = [int(x) for x in values]
-    elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32,
-                       descriptor.FieldDescriptor.TYPE_UINT32,
-                       descriptor.FieldDescriptor.TYPE_SINT32,
-                       descriptor.FieldDescriptor.TYPE_ENUM):
+    elif desc.type in (
+        descriptor.FieldDescriptor.TYPE_INT32,
+        descriptor.FieldDescriptor.TYPE_UINT32,
+        descriptor.FieldDescriptor.TYPE_SINT32,
+        descriptor.FieldDescriptor.TYPE_ENUM,
+    ):
       normalized_values = [int(x) for x in values]
     elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
       normalized_values = [round(x, 4) for x in values]
@@ -168,14 +172,20 @@
       else:
         setattr(pb, desc.name, normalized_values[0])
 
-    if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or
-        desc.type == descriptor.FieldDescriptor.TYPE_GROUP):
-      if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
-          desc.message_type.has_options and
-          desc.message_type.GetOptions().map_entry):
+    if (
+        desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE
+        or desc.type == descriptor.FieldDescriptor.TYPE_GROUP
+    ):
+      if (
+          desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE
+          and desc.message_type.has_options
+          and desc.message_type.GetOptions().map_entry
+      ):
         # This is a map, only recurse if the values have a message type.
-        if (desc.message_type.fields_by_number[2].type ==
-            descriptor.FieldDescriptor.TYPE_MESSAGE):
+        if (
+            desc.message_type.fields_by_number[2].type
+            == descriptor.FieldDescriptor.TYPE_MESSAGE
+        ):
           for v in six.itervalues(values):
             _normalize_number_fields(v)
       else:
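The test_utils.py edits above are formatting only; behavior is unchanged. A small hedged sketch of the two public helpers touched, with an illustrative asset name:

# Sketch only: typical use inside a test; 'portrait.jpg' is illustrative.
import tempfile

from mediapipe.tasks.python.test import test_utils

image_path = test_utils.get_test_data_path('portrait.jpg')  # resolves by path suffix
calibration = test_utils.create_calibration_file(tempfile.mkdtemp())  # default name/content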