Add a face alignment preprocessor to face stylizer.
PiperOrigin-RevId: 542559764
This commit is contained in:
parent
825e3a8af0
commit
ba7e0e0e50
|
@ -20,13 +20,6 @@ licenses(["notice"])
|
|||
|
||||
package(default_visibility = ["//mediapipe:__subpackages__"])
|
||||
|
||||
filegroup(
|
||||
name = "testdata",
|
||||
srcs = glob([
|
||||
"testdata/**",
|
||||
]),
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "constants",
|
||||
srcs = ["constants.py"],
|
||||
|
@ -72,18 +65,11 @@ py_library(
|
|||
name = "dataset",
|
||||
srcs = ["dataset.py"],
|
||||
deps = [
|
||||
":constants",
|
||||
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||
"//mediapipe/model_maker/python/vision/core:image_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "dataset_test",
|
||||
srcs = ["dataset_test.py"],
|
||||
data = [":testdata"],
|
||||
deps = [
|
||||
":dataset",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
"//mediapipe/python:_framework_bindings",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/vision:face_aligner",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -41,5 +41,11 @@ FACE_STYLIZER_W_FILES = file_util.DownloadedFiles(
|
|||
'https://storage.googleapis.com/mediapipe-assets/face_stylizer_w_avg.npy',
|
||||
)
|
||||
|
||||
FACE_ALIGNER_TASK_FILES = file_util.DownloadedFiles(
|
||||
'face_stylizer/face_landmarker_v2.task',
|
||||
'https://storage.googleapis.com/mediapipe-assets/face_landmarker_v2.task',
|
||||
is_folder=False,
|
||||
)
|
||||
|
||||
# Dimension of the input style vector to the decoder
|
||||
STYLE_DIM = 512
|
||||
|
|
|
@ -13,13 +13,37 @@
|
|||
# limitations under the License.
|
||||
"""Face stylizer dataset library."""
|
||||
|
||||
from typing import Sequence
|
||||
import logging
|
||||
import os
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||
from mediapipe.model_maker.python.vision.core import image_utils
|
||||
from mediapipe.model_maker.python.vision.face_stylizer import constants
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.vision import face_aligner
|
||||
|
||||
|
||||
def _preprocess_face_dataset(
|
||||
all_image_paths: Sequence[str],
|
||||
) -> Sequence[tf.Tensor]:
|
||||
"""Preprocess face image dataset by aligning the face."""
|
||||
path = constants.FACE_ALIGNER_TASK_FILES.get_path()
|
||||
base_options = base_options_module.BaseOptions(model_asset_path=path)
|
||||
options = face_aligner.FaceAlignerOptions(base_options=base_options)
|
||||
aligner = face_aligner.FaceAligner.create_from_options(options)
|
||||
|
||||
preprocessed_images = []
|
||||
for path in all_image_paths:
|
||||
tf.compat.v1.logging.info('Preprocess image %s', path)
|
||||
image = image_module.Image.create_from_file(path)
|
||||
aligned_image = aligner.align(image)
|
||||
aligned_image_tensor = tf.convert_to_tensor(aligned_image.numpy_view())
|
||||
preprocessed_images.append(aligned_image_tensor)
|
||||
|
||||
return preprocessed_images
|
||||
|
||||
|
||||
# TODO: Change to a unlabeled dataset if it makes sense.
|
||||
|
@ -58,6 +82,7 @@ class Dataset(classification_dataset.ClassificationDataset):
|
|||
):
|
||||
raise ValueError('No images found under given directory')
|
||||
|
||||
image_data = _preprocess_face_dataset(all_image_paths)
|
||||
label_names = sorted(
|
||||
name
|
||||
for name in os.listdir(data_root)
|
||||
|
@ -73,11 +98,7 @@ class Dataset(classification_dataset.ClassificationDataset):
|
|||
for path in all_image_paths
|
||||
]
|
||||
|
||||
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
|
||||
|
||||
image_ds = path_ds.map(
|
||||
image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
|
||||
)
|
||||
image_ds = tf.data.Dataset.from_tensor_slices(image_data)
|
||||
|
||||
# Load label
|
||||
label_ds = tf.data.Dataset.from_tensor_slices(
|
||||
|
|
|
@ -12,8 +12,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.vision.core import image_utils
|
||||
from mediapipe.model_maker.python.vision.face_stylizer import dataset
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
|
@ -22,10 +24,10 @@ class DatasetTest(tf.test.TestCase):
|
|||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._test_data_dirname = 'input/style'
|
||||
|
||||
def test_from_folder(self):
|
||||
input_data_dir = test_utils.get_test_data_path(self._test_data_dirname)
|
||||
test_data_dirname = 'input/style'
|
||||
input_data_dir = test_utils.get_test_data_path(test_data_dirname)
|
||||
data = dataset.Dataset.from_folder(dirname=input_data_dir)
|
||||
self.assertEqual(data.num_classes, 2)
|
||||
self.assertEqual(data.label_names, ['cartoon', 'sketch'])
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
"""APIs to train face stylization model."""
|
||||
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
@ -54,7 +54,6 @@ class FaceStylizer(object):
|
|||
self._model_spec = model_spec
|
||||
self._model_options = model_options
|
||||
self._hparams = hparams
|
||||
# TODO: Support face alignment in image preprocessor.
|
||||
self._preprocessor = image_preprocessing.Preprocessor(
|
||||
input_shape=self._model_spec.input_image_shape,
|
||||
num_classes=1,
|
||||
|
@ -128,7 +127,7 @@ class FaceStylizer(object):
|
|||
def _train_model(
|
||||
self,
|
||||
train_data: classification_ds.ClassificationDataset,
|
||||
preprocessor: Optional[Callable[..., bool]] = None,
|
||||
preprocessor: Optional[Callable[..., Any]] = None,
|
||||
):
|
||||
"""Trains the face stylizer model.
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ py_library(
|
|||
name = "base_options",
|
||||
srcs = ["base_options.py"],
|
||||
visibility = [
|
||||
"//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__",
|
||||
"//mediapipe/model_maker:__subpackages__",
|
||||
"//mediapipe/tasks:users",
|
||||
],
|
||||
deps = [
|
||||
|
|
|
@ -22,7 +22,6 @@ import six
|
|||
from google.protobuf import descriptor
|
||||
from google.protobuf import descriptor_pool
|
||||
from google.protobuf import text_format
|
||||
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.python._framework_bindings import image_frame as image_frame_module
|
||||
|
||||
|
@ -44,18 +43,21 @@ def test_srcdir():
|
|||
|
||||
def get_test_data_path(file_or_dirname_path: str) -> str:
|
||||
"""Returns full test data path."""
|
||||
for (directory, subdirs, files) in os.walk(test_srcdir()):
|
||||
for directory, subdirs, files in os.walk(test_srcdir()):
|
||||
for f in subdirs + files:
|
||||
path = os.path.join(directory, f)
|
||||
if path.endswith(file_or_dirname_path):
|
||||
return path
|
||||
raise ValueError("No %s in test directory: %s." %
|
||||
(file_or_dirname_path, test_srcdir()))
|
||||
raise ValueError(
|
||||
"No %s in test directory: %s." % (file_or_dirname_path, test_srcdir())
|
||||
)
|
||||
|
||||
|
||||
def create_calibration_file(file_dir: str,
|
||||
def create_calibration_file(
|
||||
file_dir: str,
|
||||
file_name: str = "score_calibration.txt",
|
||||
content: str = "1.0,2.0,3.0,4.0") -> str:
|
||||
content: str = "1.0,2.0,3.0,4.0",
|
||||
) -> str:
|
||||
"""Creates the calibration file."""
|
||||
calibration_file = os.path.join(file_dir, file_name)
|
||||
with open(calibration_file, mode="w") as file:
|
||||
|
@ -63,12 +65,9 @@ def create_calibration_file(file_dir: str,
|
|||
return calibration_file
|
||||
|
||||
|
||||
def assert_proto_equals(self,
|
||||
a,
|
||||
b,
|
||||
check_initialized=True,
|
||||
normalize_numbers=True,
|
||||
msg=None):
|
||||
def assert_proto_equals(
|
||||
self, a, b, check_initialized=True, normalize_numbers=True, msg=None
|
||||
):
|
||||
"""assert_proto_equals() is useful for unit tests.
|
||||
|
||||
It produces much more helpful output than assertEqual() for proto2 messages.
|
||||
|
@ -113,7 +112,8 @@ def assert_proto_equals(self,
|
|||
self.assertMultiLineEqual(a_str, b_str, msg=msg)
|
||||
else:
|
||||
diff = "".join(
|
||||
difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True)))
|
||||
difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True))
|
||||
)
|
||||
if diff:
|
||||
self.fail("%s :\n%s" % (msg, diff))
|
||||
|
||||
|
@ -147,14 +147,18 @@ def _normalize_number_fields(pb):
|
|||
# We force 32-bit values to int and 64-bit values to long to make
|
||||
# alternate implementations where the distinction is more significant
|
||||
# (e.g. the C++ implementation) simpler.
|
||||
if desc.type in (descriptor.FieldDescriptor.TYPE_INT64,
|
||||
if desc.type in (
|
||||
descriptor.FieldDescriptor.TYPE_INT64,
|
||||
descriptor.FieldDescriptor.TYPE_UINT64,
|
||||
descriptor.FieldDescriptor.TYPE_SINT64):
|
||||
descriptor.FieldDescriptor.TYPE_SINT64,
|
||||
):
|
||||
normalized_values = [int(x) for x in values]
|
||||
elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32,
|
||||
elif desc.type in (
|
||||
descriptor.FieldDescriptor.TYPE_INT32,
|
||||
descriptor.FieldDescriptor.TYPE_UINT32,
|
||||
descriptor.FieldDescriptor.TYPE_SINT32,
|
||||
descriptor.FieldDescriptor.TYPE_ENUM):
|
||||
descriptor.FieldDescriptor.TYPE_ENUM,
|
||||
):
|
||||
normalized_values = [int(x) for x in values]
|
||||
elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
|
||||
normalized_values = [round(x, 4) for x in values]
|
||||
|
@ -168,14 +172,20 @@ def _normalize_number_fields(pb):
|
|||
else:
|
||||
setattr(pb, desc.name, normalized_values[0])
|
||||
|
||||
if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or
|
||||
desc.type == descriptor.FieldDescriptor.TYPE_GROUP):
|
||||
if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
|
||||
desc.message_type.has_options and
|
||||
desc.message_type.GetOptions().map_entry):
|
||||
if (
|
||||
desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE
|
||||
or desc.type == descriptor.FieldDescriptor.TYPE_GROUP
|
||||
):
|
||||
if (
|
||||
desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE
|
||||
and desc.message_type.has_options
|
||||
and desc.message_type.GetOptions().map_entry
|
||||
):
|
||||
# This is a map, only recurse if the values have a message type.
|
||||
if (desc.message_type.fields_by_number[2].type ==
|
||||
descriptor.FieldDescriptor.TYPE_MESSAGE):
|
||||
if (
|
||||
desc.message_type.fields_by_number[2].type
|
||||
== descriptor.FieldDescriptor.TYPE_MESSAGE
|
||||
):
|
||||
for v in six.itervalues(values):
|
||||
_normalize_number_fields(v)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue
Block a user