Add a face alignment preprocessor to face stylizer.
PiperOrigin-RevId: 542559764
This commit is contained in:
		
							parent
							
								
									825e3a8af0
								
							
						
					
					
						commit
						ba7e0e0e50
					
				| 
						 | 
				
			
			@ -20,13 +20,6 @@ licenses(["notice"])
 | 
			
		|||
 | 
			
		||||
package(default_visibility = ["//mediapipe:__subpackages__"])
 | 
			
		||||
 | 
			
		||||
filegroup(
 | 
			
		||||
    name = "testdata",
 | 
			
		||||
    srcs = glob([
 | 
			
		||||
        "testdata/**",
 | 
			
		||||
    ]),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "constants",
 | 
			
		||||
    srcs = ["constants.py"],
 | 
			
		||||
| 
						 | 
				
			
			@ -72,18 +65,11 @@ py_library(
 | 
			
		|||
    name = "dataset",
 | 
			
		||||
    srcs = ["dataset.py"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":constants",
 | 
			
		||||
        "//mediapipe/model_maker/python/core/data:classification_dataset",
 | 
			
		||||
        "//mediapipe/model_maker/python/vision/core:image_utils",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
    name = "dataset_test",
 | 
			
		||||
    srcs = ["dataset_test.py"],
 | 
			
		||||
    data = [":testdata"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":dataset",
 | 
			
		||||
        "//mediapipe/tasks/python/test:test_utils",
 | 
			
		||||
        "//mediapipe/python:_framework_bindings",
 | 
			
		||||
        "//mediapipe/tasks/python/core:base_options",
 | 
			
		||||
        "//mediapipe/tasks/python/vision:face_aligner",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -41,5 +41,11 @@ FACE_STYLIZER_W_FILES = file_util.DownloadedFiles(
 | 
			
		|||
    'https://storage.googleapis.com/mediapipe-assets/face_stylizer_w_avg.npy',
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
FACE_ALIGNER_TASK_FILES = file_util.DownloadedFiles(
 | 
			
		||||
    'face_stylizer/face_landmarker_v2.task',
 | 
			
		||||
    'https://storage.googleapis.com/mediapipe-assets/face_landmarker_v2.task',
 | 
			
		||||
    is_folder=False,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# Dimension of the input style vector to the decoder
 | 
			
		||||
STYLE_DIM = 512
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,13 +13,37 @@
 | 
			
		|||
# limitations under the License.
 | 
			
		||||
"""Face stylizer dataset library."""
 | 
			
		||||
 | 
			
		||||
from typing import Sequence
 | 
			
		||||
import logging
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.core.data import classification_dataset
 | 
			
		||||
from mediapipe.model_maker.python.vision.core import image_utils
 | 
			
		||||
from mediapipe.model_maker.python.vision.face_stylizer import constants
 | 
			
		||||
from mediapipe.python._framework_bindings import image as image_module
 | 
			
		||||
from mediapipe.tasks.python.core import base_options as base_options_module
 | 
			
		||||
from mediapipe.tasks.python.vision import face_aligner
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _preprocess_face_dataset(
 | 
			
		||||
    all_image_paths: Sequence[str],
 | 
			
		||||
) -> Sequence[tf.Tensor]:
 | 
			
		||||
  """Preprocess face image dataset by aligning the face."""
 | 
			
		||||
  path = constants.FACE_ALIGNER_TASK_FILES.get_path()
 | 
			
		||||
  base_options = base_options_module.BaseOptions(model_asset_path=path)
 | 
			
		||||
  options = face_aligner.FaceAlignerOptions(base_options=base_options)
 | 
			
		||||
  aligner = face_aligner.FaceAligner.create_from_options(options)
 | 
			
		||||
 | 
			
		||||
  preprocessed_images = []
 | 
			
		||||
  for path in all_image_paths:
 | 
			
		||||
    tf.compat.v1.logging.info('Preprocess image %s', path)
 | 
			
		||||
    image = image_module.Image.create_from_file(path)
 | 
			
		||||
    aligned_image = aligner.align(image)
 | 
			
		||||
    aligned_image_tensor = tf.convert_to_tensor(aligned_image.numpy_view())
 | 
			
		||||
    preprocessed_images.append(aligned_image_tensor)
 | 
			
		||||
 | 
			
		||||
  return preprocessed_images
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# TODO: Change to a unlabeled dataset if it makes sense.
 | 
			
		||||
| 
						 | 
				
			
			@ -58,6 +82,7 @@ class Dataset(classification_dataset.ClassificationDataset):
 | 
			
		|||
    ):
 | 
			
		||||
      raise ValueError('No images found under given directory')
 | 
			
		||||
 | 
			
		||||
    image_data = _preprocess_face_dataset(all_image_paths)
 | 
			
		||||
    label_names = sorted(
 | 
			
		||||
        name
 | 
			
		||||
        for name in os.listdir(data_root)
 | 
			
		||||
| 
						 | 
				
			
			@ -73,11 +98,7 @@ class Dataset(classification_dataset.ClassificationDataset):
 | 
			
		|||
        for path in all_image_paths
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
 | 
			
		||||
 | 
			
		||||
    image_ds = path_ds.map(
 | 
			
		||||
        image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
 | 
			
		||||
    )
 | 
			
		||||
    image_ds = tf.data.Dataset.from_tensor_slices(image_data)
 | 
			
		||||
 | 
			
		||||
    # Load label
 | 
			
		||||
    label_ds = tf.data.Dataset.from_tensor_slices(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -12,8 +12,10 @@
 | 
			
		|||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.vision.core import image_utils
 | 
			
		||||
from mediapipe.model_maker.python.vision.face_stylizer import dataset
 | 
			
		||||
from mediapipe.tasks.python.test import test_utils
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -22,10 +24,10 @@ class DatasetTest(tf.test.TestCase):
 | 
			
		|||
 | 
			
		||||
  def setUp(self):
 | 
			
		||||
    super().setUp()
 | 
			
		||||
    self._test_data_dirname = 'input/style'
 | 
			
		||||
 | 
			
		||||
  def test_from_folder(self):
 | 
			
		||||
    input_data_dir = test_utils.get_test_data_path(self._test_data_dirname)
 | 
			
		||||
    test_data_dirname = 'input/style'
 | 
			
		||||
    input_data_dir = test_utils.get_test_data_path(test_data_dirname)
 | 
			
		||||
    data = dataset.Dataset.from_folder(dirname=input_data_dir)
 | 
			
		||||
    self.assertEqual(data.num_classes, 2)
 | 
			
		||||
    self.assertEqual(data.label_names, ['cartoon', 'sketch'])
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -14,7 +14,7 @@
 | 
			
		|||
"""APIs to train face stylization model."""
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
from typing import Callable, Optional
 | 
			
		||||
from typing import Any, Callable, Optional
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
| 
						 | 
				
			
			@ -54,7 +54,6 @@ class FaceStylizer(object):
 | 
			
		|||
    self._model_spec = model_spec
 | 
			
		||||
    self._model_options = model_options
 | 
			
		||||
    self._hparams = hparams
 | 
			
		||||
    # TODO: Support face alignment in image preprocessor.
 | 
			
		||||
    self._preprocessor = image_preprocessing.Preprocessor(
 | 
			
		||||
        input_shape=self._model_spec.input_image_shape,
 | 
			
		||||
        num_classes=1,
 | 
			
		||||
| 
						 | 
				
			
			@ -128,7 +127,7 @@ class FaceStylizer(object):
 | 
			
		|||
  def _train_model(
 | 
			
		||||
      self,
 | 
			
		||||
      train_data: classification_ds.ClassificationDataset,
 | 
			
		||||
      preprocessor: Optional[Callable[..., bool]] = None,
 | 
			
		||||
      preprocessor: Optional[Callable[..., Any]] = None,
 | 
			
		||||
  ):
 | 
			
		||||
    """Trains the face stylizer model.
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -29,7 +29,7 @@ py_library(
 | 
			
		|||
    name = "base_options",
 | 
			
		||||
    srcs = ["base_options.py"],
 | 
			
		||||
    visibility = [
 | 
			
		||||
        "//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__",
 | 
			
		||||
        "//mediapipe/model_maker:__subpackages__",
 | 
			
		||||
        "//mediapipe/tasks:users",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -22,7 +22,6 @@ import six
 | 
			
		|||
from google.protobuf import descriptor
 | 
			
		||||
from google.protobuf import descriptor_pool
 | 
			
		||||
from google.protobuf import text_format
 | 
			
		||||
 | 
			
		||||
from mediapipe.python._framework_bindings import image as image_module
 | 
			
		||||
from mediapipe.python._framework_bindings import image_frame as image_frame_module
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -44,18 +43,21 @@ def test_srcdir():
 | 
			
		|||
 | 
			
		||||
def get_test_data_path(file_or_dirname_path: str) -> str:
 | 
			
		||||
  """Returns full test data path."""
 | 
			
		||||
  for (directory, subdirs, files) in os.walk(test_srcdir()):
 | 
			
		||||
  for directory, subdirs, files in os.walk(test_srcdir()):
 | 
			
		||||
    for f in subdirs + files:
 | 
			
		||||
      path = os.path.join(directory, f)
 | 
			
		||||
      if path.endswith(file_or_dirname_path):
 | 
			
		||||
        return path
 | 
			
		||||
  raise ValueError("No %s in test directory: %s." %
 | 
			
		||||
                   (file_or_dirname_path, test_srcdir()))
 | 
			
		||||
  raise ValueError(
 | 
			
		||||
      "No %s in test directory: %s." % (file_or_dirname_path, test_srcdir())
 | 
			
		||||
  )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_calibration_file(file_dir: str,
 | 
			
		||||
                            file_name: str = "score_calibration.txt",
 | 
			
		||||
                            content: str = "1.0,2.0,3.0,4.0") -> str:
 | 
			
		||||
def create_calibration_file(
 | 
			
		||||
    file_dir: str,
 | 
			
		||||
    file_name: str = "score_calibration.txt",
 | 
			
		||||
    content: str = "1.0,2.0,3.0,4.0",
 | 
			
		||||
) -> str:
 | 
			
		||||
  """Creates the calibration file."""
 | 
			
		||||
  calibration_file = os.path.join(file_dir, file_name)
 | 
			
		||||
  with open(calibration_file, mode="w") as file:
 | 
			
		||||
| 
						 | 
				
			
			@ -63,12 +65,9 @@ def create_calibration_file(file_dir: str,
 | 
			
		|||
  return calibration_file
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def assert_proto_equals(self,
 | 
			
		||||
                        a,
 | 
			
		||||
                        b,
 | 
			
		||||
                        check_initialized=True,
 | 
			
		||||
                        normalize_numbers=True,
 | 
			
		||||
                        msg=None):
 | 
			
		||||
def assert_proto_equals(
 | 
			
		||||
    self, a, b, check_initialized=True, normalize_numbers=True, msg=None
 | 
			
		||||
):
 | 
			
		||||
  """assert_proto_equals() is useful for unit tests.
 | 
			
		||||
 | 
			
		||||
  It produces much more helpful output than assertEqual() for proto2 messages.
 | 
			
		||||
| 
						 | 
				
			
			@ -113,7 +112,8 @@ def assert_proto_equals(self,
 | 
			
		|||
    self.assertMultiLineEqual(a_str, b_str, msg=msg)
 | 
			
		||||
  else:
 | 
			
		||||
    diff = "".join(
 | 
			
		||||
        difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True)))
 | 
			
		||||
        difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True))
 | 
			
		||||
    )
 | 
			
		||||
    if diff:
 | 
			
		||||
      self.fail("%s :\n%s" % (msg, diff))
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -147,14 +147,18 @@ def _normalize_number_fields(pb):
 | 
			
		|||
    # We force 32-bit values to int and 64-bit values to long to make
 | 
			
		||||
    # alternate implementations where the distinction is more significant
 | 
			
		||||
    # (e.g. the C++ implementation) simpler.
 | 
			
		||||
    if desc.type in (descriptor.FieldDescriptor.TYPE_INT64,
 | 
			
		||||
                     descriptor.FieldDescriptor.TYPE_UINT64,
 | 
			
		||||
                     descriptor.FieldDescriptor.TYPE_SINT64):
 | 
			
		||||
    if desc.type in (
 | 
			
		||||
        descriptor.FieldDescriptor.TYPE_INT64,
 | 
			
		||||
        descriptor.FieldDescriptor.TYPE_UINT64,
 | 
			
		||||
        descriptor.FieldDescriptor.TYPE_SINT64,
 | 
			
		||||
    ):
 | 
			
		||||
      normalized_values = [int(x) for x in values]
 | 
			
		||||
    elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32,
 | 
			
		||||
                       descriptor.FieldDescriptor.TYPE_UINT32,
 | 
			
		||||
                       descriptor.FieldDescriptor.TYPE_SINT32,
 | 
			
		||||
                       descriptor.FieldDescriptor.TYPE_ENUM):
 | 
			
		||||
    elif desc.type in (
 | 
			
		||||
        descriptor.FieldDescriptor.TYPE_INT32,
 | 
			
		||||
        descriptor.FieldDescriptor.TYPE_UINT32,
 | 
			
		||||
        descriptor.FieldDescriptor.TYPE_SINT32,
 | 
			
		||||
        descriptor.FieldDescriptor.TYPE_ENUM,
 | 
			
		||||
    ):
 | 
			
		||||
      normalized_values = [int(x) for x in values]
 | 
			
		||||
    elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
 | 
			
		||||
      normalized_values = [round(x, 4) for x in values]
 | 
			
		||||
| 
						 | 
				
			
			@ -168,14 +172,20 @@ def _normalize_number_fields(pb):
 | 
			
		|||
      else:
 | 
			
		||||
        setattr(pb, desc.name, normalized_values[0])
 | 
			
		||||
 | 
			
		||||
    if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or
 | 
			
		||||
        desc.type == descriptor.FieldDescriptor.TYPE_GROUP):
 | 
			
		||||
      if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
 | 
			
		||||
          desc.message_type.has_options and
 | 
			
		||||
          desc.message_type.GetOptions().map_entry):
 | 
			
		||||
    if (
 | 
			
		||||
        desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE
 | 
			
		||||
        or desc.type == descriptor.FieldDescriptor.TYPE_GROUP
 | 
			
		||||
    ):
 | 
			
		||||
      if (
 | 
			
		||||
          desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE
 | 
			
		||||
          and desc.message_type.has_options
 | 
			
		||||
          and desc.message_type.GetOptions().map_entry
 | 
			
		||||
      ):
 | 
			
		||||
        # This is a map, only recurse if the values have a message type.
 | 
			
		||||
        if (desc.message_type.fields_by_number[2].type ==
 | 
			
		||||
            descriptor.FieldDescriptor.TYPE_MESSAGE):
 | 
			
		||||
        if (
 | 
			
		||||
            desc.message_type.fields_by_number[2].type
 | 
			
		||||
            == descriptor.FieldDescriptor.TYPE_MESSAGE
 | 
			
		||||
        ):
 | 
			
		||||
          for v in six.itervalues(values):
 | 
			
		||||
            _normalize_number_fields(v)
 | 
			
		||||
      else:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user