Add a face alignment preprocessor to face stylizer.
PiperOrigin-RevId: 542559764

Parent: 825e3a8af0
Commit: ba7e0e0e50
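
Note: for orientation, here is a minimal sketch of how the new preprocessing is
exercised from user code. The folder path is an illustrative assumption;
Dataset.from_folder() is the entry point used by the test below, and face
alignment now runs inside it via the _preprocess_face_dataset() helper added in
this commit.

# Sketch only: the dataset directory is a placeholder path.
from mediapipe.model_maker.python.vision.face_stylizer import dataset

# from_folder() now aligns every face image with the MediaPipe face aligner
# task before building the tf.data pipeline.
data = dataset.Dataset.from_folder(dirname='/tmp/face_styles')
print(data.num_classes, data.label_names)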

mediapipe/model_maker/python/vision/face_stylizer/BUILD

@@ -20,13 +20,6 @@ licenses(["notice"])
 
 package(default_visibility = ["//mediapipe:__subpackages__"])
 
-filegroup(
-    name = "testdata",
-    srcs = glob([
-        "testdata/**",
-    ]),
-)
-
 py_library(
     name = "constants",
     srcs = ["constants.py"],

@@ -72,18 +65,11 @@ py_library(
     name = "dataset",
     srcs = ["dataset.py"],
     deps = [
+        ":constants",
         "//mediapipe/model_maker/python/core/data:classification_dataset",
-        "//mediapipe/model_maker/python/vision/core:image_utils",
-    ],
-)
-
-py_test(
-    name = "dataset_test",
-    srcs = ["dataset_test.py"],
-    data = [":testdata"],
-    deps = [
-        ":dataset",
-        "//mediapipe/tasks/python/test:test_utils",
+        "//mediapipe/python:_framework_bindings",
+        "//mediapipe/tasks/python/core:base_options",
+        "//mediapipe/tasks/python/vision:face_aligner",
     ],
 )
 

mediapipe/model_maker/python/vision/face_stylizer/constants.py

@@ -41,5 +41,11 @@ FACE_STYLIZER_W_FILES = file_util.DownloadedFiles(
     'https://storage.googleapis.com/mediapipe-assets/face_stylizer_w_avg.npy',
 )
 
+FACE_ALIGNER_TASK_FILES = file_util.DownloadedFiles(
+    'face_stylizer/face_landmarker_v2.task',
+    'https://storage.googleapis.com/mediapipe-assets/face_landmarker_v2.task',
+    is_folder=False,
+)
+
 # Dimension of the input style vector to the decoder
 STYLE_DIM = 512
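
Note: a rough sketch of how the new constant is consumed, mirroring the
dataset.py change below. The assumption here is that DownloadedFiles.get_path()
fetches the .task asset from the URL above on first use and returns a local
path, which then feeds the aligner's BaseOptions.

from mediapipe.model_maker.python.vision.face_stylizer import constants
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import face_aligner

# Resolve the face landmarker task file to a local path (downloaded if needed).
task_path = constants.FACE_ALIGNER_TASK_FILES.get_path()
options = face_aligner.FaceAlignerOptions(
    base_options=base_options_module.BaseOptions(model_asset_path=task_path)
)
aligner = face_aligner.FaceAligner.create_from_options(options)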

mediapipe/model_maker/python/vision/face_stylizer/dataset.py

@@ -13,13 +13,37 @@
 # limitations under the License.
 """Face stylizer dataset library."""
 
+from typing import Sequence
 import logging
 import os
 
 import tensorflow as tf
 
 from mediapipe.model_maker.python.core.data import classification_dataset
-from mediapipe.model_maker.python.vision.core import image_utils
+from mediapipe.model_maker.python.vision.face_stylizer import constants
+from mediapipe.python._framework_bindings import image as image_module
+from mediapipe.tasks.python.core import base_options as base_options_module
+from mediapipe.tasks.python.vision import face_aligner
+
+
+def _preprocess_face_dataset(
+    all_image_paths: Sequence[str],
+) -> Sequence[tf.Tensor]:
+  """Preprocess face image dataset by aligning the face."""
+  path = constants.FACE_ALIGNER_TASK_FILES.get_path()
+  base_options = base_options_module.BaseOptions(model_asset_path=path)
+  options = face_aligner.FaceAlignerOptions(base_options=base_options)
+  aligner = face_aligner.FaceAligner.create_from_options(options)
+
+  preprocessed_images = []
+  for path in all_image_paths:
+    tf.compat.v1.logging.info('Preprocess image %s', path)
+    image = image_module.Image.create_from_file(path)
+    aligned_image = aligner.align(image)
+    aligned_image_tensor = tf.convert_to_tensor(aligned_image.numpy_view())
+    preprocessed_images.append(aligned_image_tensor)
+
+  return preprocessed_images
+
 
 # TODO: Change to a unlabeled dataset if it makes sense.

@@ -58,6 +82,7 @@ class Dataset(classification_dataset.ClassificationDataset):
     ):
       raise ValueError('No images found under given directory')
 
+    image_data = _preprocess_face_dataset(all_image_paths)
     label_names = sorted(
         name
         for name in os.listdir(data_root)

@@ -73,11 +98,7 @@ class Dataset(classification_dataset.ClassificationDataset):
         for path in all_image_paths
     ]
 
-    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
-
-    image_ds = path_ds.map(
-        image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
-    )
+    image_ds = tf.data.Dataset.from_tensor_slices(image_data)
 
     # Load label
     label_ds = tf.data.Dataset.from_tensor_slices(
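
Note: with the change above, images are aligned eagerly into in-memory tensors,
so the dataset is built with from_tensor_slices() over tensors instead of
mapping a file loader over paths. A minimal sketch of the resulting pipeline
shape; the placeholder tensors and the final zip with labels are assumptions
based on how the surrounding Model Maker datasets are assembled.

import tensorflow as tf

# Stand-ins for the aligned images returned by _preprocess_face_dataset() and
# the label indices computed from all_image_paths.
image_data = [tf.zeros([256, 256, 3]) for _ in range(4)]
all_image_labels = [0, 0, 1, 1]

image_ds = tf.data.Dataset.from_tensor_slices(image_data)
label_ds = tf.data.Dataset.from_tensor_slices(
    tf.cast(all_image_labels, tf.int64)
)
# Assumed pairing step: each aligned image tensor with its style label.
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))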

mediapipe/model_maker/python/vision/face_stylizer/dataset_test.py

@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numpy as np
 import tensorflow as tf
 
+from mediapipe.model_maker.python.vision.core import image_utils
 from mediapipe.model_maker.python.vision.face_stylizer import dataset
 from mediapipe.tasks.python.test import test_utils
 

@@ -22,10 +24,10 @@ class DatasetTest(tf.test.TestCase):
 
   def setUp(self):
     super().setUp()
-    self._test_data_dirname = 'input/style'
 
   def test_from_folder(self):
-    input_data_dir = test_utils.get_test_data_path(self._test_data_dirname)
+    test_data_dirname = 'input/style'
+    input_data_dir = test_utils.get_test_data_path(test_data_dirname)
     data = dataset.Dataset.from_folder(dirname=input_data_dir)
     self.assertEqual(data.num_classes, 2)
     self.assertEqual(data.label_names, ['cartoon', 'sketch'])

mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py

@@ -14,7 +14,7 @@
 """APIs to train face stylization model."""
 
 import os
-from typing import Callable, Optional
+from typing import Any, Callable, Optional
 
 import numpy as np
 import tensorflow as tf

@@ -54,7 +54,6 @@ class FaceStylizer(object):
     self._model_spec = model_spec
     self._model_options = model_options
     self._hparams = hparams
-    # TODO: Support face alignment in image preprocessor.
     self._preprocessor = image_preprocessing.Preprocessor(
         input_shape=self._model_spec.input_image_shape,
         num_classes=1,

@@ -128,7 +127,7 @@ class FaceStylizer(object):
   def _train_model(
       self,
       train_data: classification_ds.ClassificationDataset,
-      preprocessor: Optional[Callable[..., bool]] = None,
+      preprocessor: Optional[Callable[..., Any]] = None,
   ):
     """Trains the face stylizer model.
 
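
Note: the widened hint reflects that a preprocessor callable returns
preprocessed data (typically an image tensor), not a bool. A hypothetical
callable that satisfies Optional[Callable[..., Any]], purely for illustration:

from typing import Any, Callable, Optional

import tensorflow as tf


def resize_preprocessor(image: tf.Tensor, is_training: bool = False) -> tf.Tensor:
  """Hypothetical preprocessor: returns a tensor rather than a bool."""
  del is_training  # Unused in this sketch.
  return tf.image.resize(image, [256, 256])


preprocessor: Optional[Callable[..., Any]] = resize_preprocessor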

mediapipe/tasks/python/core/BUILD

@@ -29,7 +29,7 @@ py_library(
     name = "base_options",
     srcs = ["base_options.py"],
     visibility = [
-        "//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__",
+        "//mediapipe/model_maker:__subpackages__",
         "//mediapipe/tasks:users",
     ],
     deps = [

mediapipe/tasks/python/test/test_utils.py

@@ -22,7 +22,6 @@ import six
 from google.protobuf import descriptor
 from google.protobuf import descriptor_pool
 from google.protobuf import text_format
-
 from mediapipe.python._framework_bindings import image as image_module
 from mediapipe.python._framework_bindings import image_frame as image_frame_module
 

@@ -44,18 +43,21 @@ def test_srcdir():
 
 def get_test_data_path(file_or_dirname_path: str) -> str:
   """Returns full test data path."""
-  for (directory, subdirs, files) in os.walk(test_srcdir()):
+  for directory, subdirs, files in os.walk(test_srcdir()):
     for f in subdirs + files:
       path = os.path.join(directory, f)
       if path.endswith(file_or_dirname_path):
         return path
-  raise ValueError("No %s in test directory: %s." %
-                   (file_or_dirname_path, test_srcdir()))
+  raise ValueError(
+      "No %s in test directory: %s." % (file_or_dirname_path, test_srcdir())
+  )
 
 
-def create_calibration_file(file_dir: str,
-                            file_name: str = "score_calibration.txt",
-                            content: str = "1.0,2.0,3.0,4.0") -> str:
+def create_calibration_file(
+    file_dir: str,
+    file_name: str = "score_calibration.txt",
+    content: str = "1.0,2.0,3.0,4.0",
+) -> str:
   """Creates the calibration file."""
   calibration_file = os.path.join(file_dir, file_name)
   with open(calibration_file, mode="w") as file:
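
Note: get_test_data_path() walks the test source directory and matches on the
path suffix, so callers pass a relative name. The dataset test above uses it
roughly like this (the 'input/style' directory name comes from that test):

from mediapipe.tasks.python.test import test_utils

# Resolves the checked-in test asset directory by suffix match.
input_data_dir = test_utils.get_test_data_path('input/style')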

@@ -63,12 +65,9 @@ def create_calibration_file(file_dir: str,
   return calibration_file
 
 
-def assert_proto_equals(self,
-                        a,
-                        b,
-                        check_initialized=True,
-                        normalize_numbers=True,
-                        msg=None):
+def assert_proto_equals(
+    self, a, b, check_initialized=True, normalize_numbers=True, msg=None
+):
   """assert_proto_equals() is useful for unit tests.
 
   It produces much more helpful output than assertEqual() for proto2 messages.

@@ -113,7 +112,8 @@ def assert_proto_equals(self,
     self.assertMultiLineEqual(a_str, b_str, msg=msg)
   else:
     diff = "".join(
-        difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True)))
+        difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True))
+    )
     if diff:
       self.fail("%s :\n%s" % (msg, diff))
 

@@ -147,14 +147,18 @@ def _normalize_number_fields(pb):
     # We force 32-bit values to int and 64-bit values to long to make
     # alternate implementations where the distinction is more significant
     # (e.g. the C++ implementation) simpler.
-    if desc.type in (descriptor.FieldDescriptor.TYPE_INT64,
-                     descriptor.FieldDescriptor.TYPE_UINT64,
-                     descriptor.FieldDescriptor.TYPE_SINT64):
+    if desc.type in (
+        descriptor.FieldDescriptor.TYPE_INT64,
+        descriptor.FieldDescriptor.TYPE_UINT64,
+        descriptor.FieldDescriptor.TYPE_SINT64,
+    ):
       normalized_values = [int(x) for x in values]
-    elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32,
-                       descriptor.FieldDescriptor.TYPE_UINT32,
-                       descriptor.FieldDescriptor.TYPE_SINT32,
-                       descriptor.FieldDescriptor.TYPE_ENUM):
+    elif desc.type in (
+        descriptor.FieldDescriptor.TYPE_INT32,
+        descriptor.FieldDescriptor.TYPE_UINT32,
+        descriptor.FieldDescriptor.TYPE_SINT32,
+        descriptor.FieldDescriptor.TYPE_ENUM,
+    ):
       normalized_values = [int(x) for x in values]
     elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
       normalized_values = [round(x, 4) for x in values]

@@ -168,14 +172,20 @@ def _normalize_number_fields(pb):
       else:
         setattr(pb, desc.name, normalized_values[0])
 
-    if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or
-        desc.type == descriptor.FieldDescriptor.TYPE_GROUP):
-      if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
-          desc.message_type.has_options and
-          desc.message_type.GetOptions().map_entry):
+    if (
+        desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE
+        or desc.type == descriptor.FieldDescriptor.TYPE_GROUP
+    ):
+      if (
+          desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE
+          and desc.message_type.has_options
+          and desc.message_type.GetOptions().map_entry
+      ):
         # This is a map, only recurse if the values have a message type.
-        if (desc.message_type.fields_by_number[2].type ==
-            descriptor.FieldDescriptor.TYPE_MESSAGE):
+        if (
+            desc.message_type.fields_by_number[2].type
+            == descriptor.FieldDescriptor.TYPE_MESSAGE
+        ):
           for v in six.itervalues(values):
             _normalize_number_fields(v)
       else:
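
Note: the normalization above lets proto comparisons ignore insignificant
numeric representation differences; float fields are rounded to four decimal
places before the text-format diff. A small illustration of the effect, with
made-up values:

# Two float lists that differ only past the fourth decimal place compare equal
# after normalization, so assert_proto_equals() will not flag them.
values_a = [0.30000001, 1.99999999]
values_b = [0.3, 2.0]
assert [round(x, 4) for x in values_a] == [round(x, 4) for x in values_b]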