Add a face alignment preprocessor to face stylizer.

PiperOrigin-RevId: 542559764
This commit is contained in:
MediaPipe Team 2023-06-22 07:55:09 -07:00 committed by Copybara-Service
parent 825e3a8af0
commit ba7e0e0e50
7 changed files with 82 additions and 58 deletions

View File

@ -20,13 +20,6 @@ licenses(["notice"])
package(default_visibility = ["//mediapipe:__subpackages__"]) package(default_visibility = ["//mediapipe:__subpackages__"])
filegroup(
name = "testdata",
srcs = glob([
"testdata/**",
]),
)
py_library( py_library(
name = "constants", name = "constants",
srcs = ["constants.py"], srcs = ["constants.py"],
@ -72,18 +65,11 @@ py_library(
name = "dataset", name = "dataset",
srcs = ["dataset.py"], srcs = ["dataset.py"],
deps = [ deps = [
":constants",
"//mediapipe/model_maker/python/core/data:classification_dataset", "//mediapipe/model_maker/python/core/data:classification_dataset",
"//mediapipe/model_maker/python/vision/core:image_utils", "//mediapipe/python:_framework_bindings",
], "//mediapipe/tasks/python/core:base_options",
) "//mediapipe/tasks/python/vision:face_aligner",
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
data = [":testdata"],
deps = [
":dataset",
"//mediapipe/tasks/python/test:test_utils",
], ],
) )

View File

@ -41,5 +41,11 @@ FACE_STYLIZER_W_FILES = file_util.DownloadedFiles(
'https://storage.googleapis.com/mediapipe-assets/face_stylizer_w_avg.npy', 'https://storage.googleapis.com/mediapipe-assets/face_stylizer_w_avg.npy',
) )
FACE_ALIGNER_TASK_FILES = file_util.DownloadedFiles(
'face_stylizer/face_landmarker_v2.task',
'https://storage.googleapis.com/mediapipe-assets/face_landmarker_v2.task',
is_folder=False,
)
# Dimension of the input style vector to the decoder # Dimension of the input style vector to the decoder
STYLE_DIM = 512 STYLE_DIM = 512

View File

@ -13,13 +13,37 @@
# limitations under the License. # limitations under the License.
"""Face stylizer dataset library.""" """Face stylizer dataset library."""
from typing import Sequence
import logging import logging
import os import os
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.vision.core import image_utils from mediapipe.model_maker.python.vision.face_stylizer import constants
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import face_aligner
def _preprocess_face_dataset(
all_image_paths: Sequence[str],
) -> Sequence[tf.Tensor]:
"""Preprocess face image dataset by aligning the face."""
path = constants.FACE_ALIGNER_TASK_FILES.get_path()
base_options = base_options_module.BaseOptions(model_asset_path=path)
options = face_aligner.FaceAlignerOptions(base_options=base_options)
aligner = face_aligner.FaceAligner.create_from_options(options)
preprocessed_images = []
for path in all_image_paths:
tf.compat.v1.logging.info('Preprocess image %s', path)
image = image_module.Image.create_from_file(path)
aligned_image = aligner.align(image)
aligned_image_tensor = tf.convert_to_tensor(aligned_image.numpy_view())
preprocessed_images.append(aligned_image_tensor)
return preprocessed_images
# TODO: Change to a unlabeled dataset if it makes sense. # TODO: Change to a unlabeled dataset if it makes sense.
@ -58,6 +82,7 @@ class Dataset(classification_dataset.ClassificationDataset):
): ):
raise ValueError('No images found under given directory') raise ValueError('No images found under given directory')
image_data = _preprocess_face_dataset(all_image_paths)
label_names = sorted( label_names = sorted(
name name
for name in os.listdir(data_root) for name in os.listdir(data_root)
@ -73,11 +98,7 @@ class Dataset(classification_dataset.ClassificationDataset):
for path in all_image_paths for path in all_image_paths
] ]
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths) image_ds = tf.data.Dataset.from_tensor_slices(image_data)
image_ds = path_ds.map(
image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
)
# Load label # Load label
label_ds = tf.data.Dataset.from_tensor_slices( label_ds = tf.data.Dataset.from_tensor_slices(

View File

@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.vision.core import image_utils
from mediapipe.model_maker.python.vision.face_stylizer import dataset from mediapipe.model_maker.python.vision.face_stylizer import dataset
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
@ -22,10 +24,10 @@ class DatasetTest(tf.test.TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self._test_data_dirname = 'input/style'
def test_from_folder(self): def test_from_folder(self):
input_data_dir = test_utils.get_test_data_path(self._test_data_dirname) test_data_dirname = 'input/style'
input_data_dir = test_utils.get_test_data_path(test_data_dirname)
data = dataset.Dataset.from_folder(dirname=input_data_dir) data = dataset.Dataset.from_folder(dirname=input_data_dir)
self.assertEqual(data.num_classes, 2) self.assertEqual(data.num_classes, 2)
self.assertEqual(data.label_names, ['cartoon', 'sketch']) self.assertEqual(data.label_names, ['cartoon', 'sketch'])

View File

@ -14,7 +14,7 @@
"""APIs to train face stylization model.""" """APIs to train face stylization model."""
import os import os
from typing import Callable, Optional from typing import Any, Callable, Optional
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
@ -54,7 +54,6 @@ class FaceStylizer(object):
self._model_spec = model_spec self._model_spec = model_spec
self._model_options = model_options self._model_options = model_options
self._hparams = hparams self._hparams = hparams
# TODO: Support face alignment in image preprocessor.
self._preprocessor = image_preprocessing.Preprocessor( self._preprocessor = image_preprocessing.Preprocessor(
input_shape=self._model_spec.input_image_shape, input_shape=self._model_spec.input_image_shape,
num_classes=1, num_classes=1,
@ -128,7 +127,7 @@ class FaceStylizer(object):
def _train_model( def _train_model(
self, self,
train_data: classification_ds.ClassificationDataset, train_data: classification_ds.ClassificationDataset,
preprocessor: Optional[Callable[..., bool]] = None, preprocessor: Optional[Callable[..., Any]] = None,
): ):
"""Trains the face stylizer model. """Trains the face stylizer model.

View File

@ -29,7 +29,7 @@ py_library(
name = "base_options", name = "base_options",
srcs = ["base_options.py"], srcs = ["base_options.py"],
visibility = [ visibility = [
"//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__", "//mediapipe/model_maker:__subpackages__",
"//mediapipe/tasks:users", "//mediapipe/tasks:users",
], ],
deps = [ deps = [

View File

@ -22,7 +22,6 @@ import six
from google.protobuf import descriptor from google.protobuf import descriptor
from google.protobuf import descriptor_pool from google.protobuf import descriptor_pool
from google.protobuf import text_format from google.protobuf import text_format
from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import image_frame as image_frame_module from mediapipe.python._framework_bindings import image_frame as image_frame_module
@ -44,18 +43,21 @@ def test_srcdir():
def get_test_data_path(file_or_dirname_path: str) -> str: def get_test_data_path(file_or_dirname_path: str) -> str:
"""Returns full test data path.""" """Returns full test data path."""
for (directory, subdirs, files) in os.walk(test_srcdir()): for directory, subdirs, files in os.walk(test_srcdir()):
for f in subdirs + files: for f in subdirs + files:
path = os.path.join(directory, f) path = os.path.join(directory, f)
if path.endswith(file_or_dirname_path): if path.endswith(file_or_dirname_path):
return path return path
raise ValueError("No %s in test directory: %s." % raise ValueError(
(file_or_dirname_path, test_srcdir())) "No %s in test directory: %s." % (file_or_dirname_path, test_srcdir())
)
def create_calibration_file(file_dir: str, def create_calibration_file(
file_dir: str,
file_name: str = "score_calibration.txt", file_name: str = "score_calibration.txt",
content: str = "1.0,2.0,3.0,4.0") -> str: content: str = "1.0,2.0,3.0,4.0",
) -> str:
"""Creates the calibration file.""" """Creates the calibration file."""
calibration_file = os.path.join(file_dir, file_name) calibration_file = os.path.join(file_dir, file_name)
with open(calibration_file, mode="w") as file: with open(calibration_file, mode="w") as file:
@ -63,12 +65,9 @@ def create_calibration_file(file_dir: str,
return calibration_file return calibration_file
def assert_proto_equals(self, def assert_proto_equals(
a, self, a, b, check_initialized=True, normalize_numbers=True, msg=None
b, ):
check_initialized=True,
normalize_numbers=True,
msg=None):
"""assert_proto_equals() is useful for unit tests. """assert_proto_equals() is useful for unit tests.
It produces much more helpful output than assertEqual() for proto2 messages. It produces much more helpful output than assertEqual() for proto2 messages.
@ -113,7 +112,8 @@ def assert_proto_equals(self,
self.assertMultiLineEqual(a_str, b_str, msg=msg) self.assertMultiLineEqual(a_str, b_str, msg=msg)
else: else:
diff = "".join( diff = "".join(
difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True))) difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True))
)
if diff: if diff:
self.fail("%s :\n%s" % (msg, diff)) self.fail("%s :\n%s" % (msg, diff))
@ -147,14 +147,18 @@ def _normalize_number_fields(pb):
# We force 32-bit values to int and 64-bit values to long to make # We force 32-bit values to int and 64-bit values to long to make
# alternate implementations where the distinction is more significant # alternate implementations where the distinction is more significant
# (e.g. the C++ implementation) simpler. # (e.g. the C++ implementation) simpler.
if desc.type in (descriptor.FieldDescriptor.TYPE_INT64, if desc.type in (
descriptor.FieldDescriptor.TYPE_INT64,
descriptor.FieldDescriptor.TYPE_UINT64, descriptor.FieldDescriptor.TYPE_UINT64,
descriptor.FieldDescriptor.TYPE_SINT64): descriptor.FieldDescriptor.TYPE_SINT64,
):
normalized_values = [int(x) for x in values] normalized_values = [int(x) for x in values]
elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32, elif desc.type in (
descriptor.FieldDescriptor.TYPE_INT32,
descriptor.FieldDescriptor.TYPE_UINT32, descriptor.FieldDescriptor.TYPE_UINT32,
descriptor.FieldDescriptor.TYPE_SINT32, descriptor.FieldDescriptor.TYPE_SINT32,
descriptor.FieldDescriptor.TYPE_ENUM): descriptor.FieldDescriptor.TYPE_ENUM,
):
normalized_values = [int(x) for x in values] normalized_values = [int(x) for x in values]
elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT: elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
normalized_values = [round(x, 4) for x in values] normalized_values = [round(x, 4) for x in values]
@ -168,14 +172,20 @@ def _normalize_number_fields(pb):
else: else:
setattr(pb, desc.name, normalized_values[0]) setattr(pb, desc.name, normalized_values[0])
if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or if (
desc.type == descriptor.FieldDescriptor.TYPE_GROUP): desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE
if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and or desc.type == descriptor.FieldDescriptor.TYPE_GROUP
desc.message_type.has_options and ):
desc.message_type.GetOptions().map_entry): if (
desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE
and desc.message_type.has_options
and desc.message_type.GetOptions().map_entry
):
# This is a map, only recurse if the values have a message type. # This is a map, only recurse if the values have a message type.
if (desc.message_type.fields_by_number[2].type == if (
descriptor.FieldDescriptor.TYPE_MESSAGE): desc.message_type.fields_by_number[2].type
== descriptor.FieldDescriptor.TYPE_MESSAGE
):
for v in six.itervalues(values): for v in six.itervalues(values):
_normalize_number_fields(v) _normalize_number_fields(v)
else: else: