Add a face alignment preprocessor to face stylizer.

PiperOrigin-RevId: 542559764
This commit is contained in:
MediaPipe Team 2023-06-22 07:55:09 -07:00 committed by Copybara-Service
parent 825e3a8af0
commit ba7e0e0e50
7 changed files with 82 additions and 58 deletions

View File

@ -20,13 +20,6 @@ licenses(["notice"])
package(default_visibility = ["//mediapipe:__subpackages__"])
filegroup(
name = "testdata",
srcs = glob([
"testdata/**",
]),
)
py_library(
name = "constants",
srcs = ["constants.py"],
@ -72,18 +65,11 @@ py_library(
name = "dataset",
srcs = ["dataset.py"],
deps = [
":constants",
"//mediapipe/model_maker/python/core/data:classification_dataset",
"//mediapipe/model_maker/python/vision/core:image_utils",
],
)
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
data = [":testdata"],
deps = [
":dataset",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/vision:face_aligner",
],
)

View File

@ -41,5 +41,11 @@ FACE_STYLIZER_W_FILES = file_util.DownloadedFiles(
'https://storage.googleapis.com/mediapipe-assets/face_stylizer_w_avg.npy',
)
FACE_ALIGNER_TASK_FILES = file_util.DownloadedFiles(
'face_stylizer/face_landmarker_v2.task',
'https://storage.googleapis.com/mediapipe-assets/face_landmarker_v2.task',
is_folder=False,
)
# Dimension of the input style vector to the decoder
STYLE_DIM = 512

View File

@ -13,13 +13,37 @@
# limitations under the License.
"""Face stylizer dataset library."""
from typing import Sequence
import logging
import os
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.vision.core import image_utils
from mediapipe.model_maker.python.vision.face_stylizer import constants
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import face_aligner
def _preprocess_face_dataset(
all_image_paths: Sequence[str],
) -> Sequence[tf.Tensor]:
"""Preprocess face image dataset by aligning the face."""
path = constants.FACE_ALIGNER_TASK_FILES.get_path()
base_options = base_options_module.BaseOptions(model_asset_path=path)
options = face_aligner.FaceAlignerOptions(base_options=base_options)
aligner = face_aligner.FaceAligner.create_from_options(options)
preprocessed_images = []
for path in all_image_paths:
tf.compat.v1.logging.info('Preprocess image %s', path)
image = image_module.Image.create_from_file(path)
aligned_image = aligner.align(image)
aligned_image_tensor = tf.convert_to_tensor(aligned_image.numpy_view())
preprocessed_images.append(aligned_image_tensor)
return preprocessed_images
# TODO: Change to a unlabeled dataset if it makes sense.
@ -58,6 +82,7 @@ class Dataset(classification_dataset.ClassificationDataset):
):
raise ValueError('No images found under given directory')
image_data = _preprocess_face_dataset(all_image_paths)
label_names = sorted(
name
for name in os.listdir(data_root)
@ -73,11 +98,7 @@ class Dataset(classification_dataset.ClassificationDataset):
for path in all_image_paths
]
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
image_ds = path_ds.map(
image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
)
image_ds = tf.data.Dataset.from_tensor_slices(image_data)
# Load label
label_ds = tf.data.Dataset.from_tensor_slices(

View File

@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.vision.core import image_utils
from mediapipe.model_maker.python.vision.face_stylizer import dataset
from mediapipe.tasks.python.test import test_utils
@ -22,10 +24,10 @@ class DatasetTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self._test_data_dirname = 'input/style'
def test_from_folder(self):
input_data_dir = test_utils.get_test_data_path(self._test_data_dirname)
test_data_dirname = 'input/style'
input_data_dir = test_utils.get_test_data_path(test_data_dirname)
data = dataset.Dataset.from_folder(dirname=input_data_dir)
self.assertEqual(data.num_classes, 2)
self.assertEqual(data.label_names, ['cartoon', 'sketch'])

View File

@ -14,7 +14,7 @@
"""APIs to train face stylization model."""
import os
from typing import Callable, Optional
from typing import Any, Callable, Optional
import numpy as np
import tensorflow as tf
@ -54,7 +54,6 @@ class FaceStylizer(object):
self._model_spec = model_spec
self._model_options = model_options
self._hparams = hparams
# TODO: Support face alignment in image preprocessor.
self._preprocessor = image_preprocessing.Preprocessor(
input_shape=self._model_spec.input_image_shape,
num_classes=1,
@ -128,7 +127,7 @@ class FaceStylizer(object):
def _train_model(
self,
train_data: classification_ds.ClassificationDataset,
preprocessor: Optional[Callable[..., bool]] = None,
preprocessor: Optional[Callable[..., Any]] = None,
):
"""Trains the face stylizer model.

View File

@ -29,7 +29,7 @@ py_library(
name = "base_options",
srcs = ["base_options.py"],
visibility = [
"//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__",
"//mediapipe/model_maker:__subpackages__",
"//mediapipe/tasks:users",
],
deps = [

View File

@ -22,7 +22,6 @@ import six
from google.protobuf import descriptor
from google.protobuf import descriptor_pool
from google.protobuf import text_format
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import image_frame as image_frame_module
@ -44,18 +43,21 @@ def test_srcdir():
def get_test_data_path(file_or_dirname_path: str) -> str:
"""Returns full test data path."""
for (directory, subdirs, files) in os.walk(test_srcdir()):
for directory, subdirs, files in os.walk(test_srcdir()):
for f in subdirs + files:
path = os.path.join(directory, f)
if path.endswith(file_or_dirname_path):
return path
raise ValueError("No %s in test directory: %s." %
(file_or_dirname_path, test_srcdir()))
raise ValueError(
"No %s in test directory: %s." % (file_or_dirname_path, test_srcdir())
)
def create_calibration_file(file_dir: str,
def create_calibration_file(
file_dir: str,
file_name: str = "score_calibration.txt",
content: str = "1.0,2.0,3.0,4.0") -> str:
content: str = "1.0,2.0,3.0,4.0",
) -> str:
"""Creates the calibration file."""
calibration_file = os.path.join(file_dir, file_name)
with open(calibration_file, mode="w") as file:
@ -63,12 +65,9 @@ def create_calibration_file(file_dir: str,
return calibration_file
def assert_proto_equals(self,
a,
b,
check_initialized=True,
normalize_numbers=True,
msg=None):
def assert_proto_equals(
self, a, b, check_initialized=True, normalize_numbers=True, msg=None
):
"""assert_proto_equals() is useful for unit tests.
It produces much more helpful output than assertEqual() for proto2 messages.
@ -113,7 +112,8 @@ def assert_proto_equals(self,
self.assertMultiLineEqual(a_str, b_str, msg=msg)
else:
diff = "".join(
difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True)))
difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True))
)
if diff:
self.fail("%s :\n%s" % (msg, diff))
@ -147,14 +147,18 @@ def _normalize_number_fields(pb):
# We force 32-bit values to int and 64-bit values to long to make
# alternate implementations where the distinction is more significant
# (e.g. the C++ implementation) simpler.
if desc.type in (descriptor.FieldDescriptor.TYPE_INT64,
if desc.type in (
descriptor.FieldDescriptor.TYPE_INT64,
descriptor.FieldDescriptor.TYPE_UINT64,
descriptor.FieldDescriptor.TYPE_SINT64):
descriptor.FieldDescriptor.TYPE_SINT64,
):
normalized_values = [int(x) for x in values]
elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32,
elif desc.type in (
descriptor.FieldDescriptor.TYPE_INT32,
descriptor.FieldDescriptor.TYPE_UINT32,
descriptor.FieldDescriptor.TYPE_SINT32,
descriptor.FieldDescriptor.TYPE_ENUM):
descriptor.FieldDescriptor.TYPE_ENUM,
):
normalized_values = [int(x) for x in values]
elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
normalized_values = [round(x, 4) for x in values]
@ -168,14 +172,20 @@ def _normalize_number_fields(pb):
else:
setattr(pb, desc.name, normalized_values[0])
if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or
desc.type == descriptor.FieldDescriptor.TYPE_GROUP):
if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
desc.message_type.has_options and
desc.message_type.GetOptions().map_entry):
if (
desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE
or desc.type == descriptor.FieldDescriptor.TYPE_GROUP
):
if (
desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE
and desc.message_type.has_options
and desc.message_type.GetOptions().map_entry
):
# This is a map, only recurse if the values have a message type.
if (desc.message_type.fields_by_number[2].type ==
descriptor.FieldDescriptor.TYPE_MESSAGE):
if (
desc.message_type.fields_by_number[2].type
== descriptor.FieldDescriptor.TYPE_MESSAGE
):
for v in six.itervalues(values):
_normalize_number_fields(v)
else: