Add a face alignment preprocessor to face stylizer.
PiperOrigin-RevId: 542559764
parent 825e3a8af0
commit ba7e0e0e50
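
For context, the patch routes every training image through the MediaPipe Tasks face aligner before it reaches the tf.data pipeline. The snippet below is a minimal standalone sketch of that flow, mirroring the _preprocess_face_dataset helper added in dataset.py; the function name align_to_tensor and the task_path/image_path arguments are illustrative placeholders, not names from the patch.

import tensorflow as tf

from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import face_aligner


def align_to_tensor(task_path: str, image_path: str) -> tf.Tensor:
  """Sketch: align one face image and return it as a tf.Tensor."""
  # task_path is assumed to point at the downloaded face_landmarker_v2.task
  # file (the patch fetches it via constants.FACE_ALIGNER_TASK_FILES.get_path()).
  base_options = base_options_module.BaseOptions(model_asset_path=task_path)
  options = face_aligner.FaceAlignerOptions(base_options=base_options)
  aligner = face_aligner.FaceAligner.create_from_options(options)

  # Align the face, then convert the MediaPipe image to a tensor so it can be
  # fed through tf.data.Dataset.from_tensor_slices, as dataset.py now does.
  image = image_module.Image.create_from_file(image_path)
  aligned = aligner.align(image)
  return tf.convert_to_tensor(aligned.numpy_view())
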
mediapipe/model_maker/python/vision/face_stylizer/BUILD

@@ -20,13 +20,6 @@ licenses(["notice"])
 
 package(default_visibility = ["//mediapipe:__subpackages__"])
 
-filegroup(
-    name = "testdata",
-    srcs = glob([
-        "testdata/**",
-    ]),
-)
-
 py_library(
     name = "constants",
     srcs = ["constants.py"],

@@ -72,18 +65,11 @@ py_library(
     name = "dataset",
     srcs = ["dataset.py"],
     deps = [
+        ":constants",
         "//mediapipe/model_maker/python/core/data:classification_dataset",
-        "//mediapipe/model_maker/python/vision/core:image_utils",
-    ],
-)
-
-py_test(
-    name = "dataset_test",
-    srcs = ["dataset_test.py"],
-    data = [":testdata"],
-    deps = [
-        ":dataset",
-        "//mediapipe/tasks/python/test:test_utils",
+        "//mediapipe/python:_framework_bindings",
+        "//mediapipe/tasks/python/core:base_options",
+        "//mediapipe/tasks/python/vision:face_aligner",
     ],
 )
 
mediapipe/model_maker/python/vision/face_stylizer/constants.py

@@ -41,5 +41,11 @@ FACE_STYLIZER_W_FILES = file_util.DownloadedFiles(
     'https://storage.googleapis.com/mediapipe-assets/face_stylizer_w_avg.npy',
 )
 
+FACE_ALIGNER_TASK_FILES = file_util.DownloadedFiles(
+    'face_stylizer/face_landmarker_v2.task',
+    'https://storage.googleapis.com/mediapipe-assets/face_landmarker_v2.task',
+    is_folder=False,
+)
+
 # Dimension of the input style vector to the decoder
 STYLE_DIM = 512
mediapipe/model_maker/python/vision/face_stylizer/dataset.py

@@ -13,13 +13,37 @@
 # limitations under the License.
 """Face stylizer dataset library."""
 
+from typing import Sequence
 import logging
 import os
 
 import tensorflow as tf
 
 from mediapipe.model_maker.python.core.data import classification_dataset
-from mediapipe.model_maker.python.vision.core import image_utils
+from mediapipe.model_maker.python.vision.face_stylizer import constants
+from mediapipe.python._framework_bindings import image as image_module
+from mediapipe.tasks.python.core import base_options as base_options_module
+from mediapipe.tasks.python.vision import face_aligner
 
 
+def _preprocess_face_dataset(
+    all_image_paths: Sequence[str],
+) -> Sequence[tf.Tensor]:
+  """Preprocess face image dataset by aligning the face."""
+  path = constants.FACE_ALIGNER_TASK_FILES.get_path()
+  base_options = base_options_module.BaseOptions(model_asset_path=path)
+  options = face_aligner.FaceAlignerOptions(base_options=base_options)
+  aligner = face_aligner.FaceAligner.create_from_options(options)
+
+  preprocessed_images = []
+  for path in all_image_paths:
+    tf.compat.v1.logging.info('Preprocess image %s', path)
+    image = image_module.Image.create_from_file(path)
+    aligned_image = aligner.align(image)
+    aligned_image_tensor = tf.convert_to_tensor(aligned_image.numpy_view())
+    preprocessed_images.append(aligned_image_tensor)
+
+  return preprocessed_images
+
+
 # TODO: Change to a unlabeled dataset if it makes sense.

@@ -58,6 +82,7 @@ class Dataset(classification_dataset.ClassificationDataset):
     ):
       raise ValueError('No images found under given directory')
 
+    image_data = _preprocess_face_dataset(all_image_paths)
     label_names = sorted(
         name
         for name in os.listdir(data_root)

@@ -73,11 +98,7 @@ class Dataset(classification_dataset.ClassificationDataset):
         for path in all_image_paths
     ]
 
-    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
-
-    image_ds = path_ds.map(
-        image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
-    )
+    image_ds = tf.data.Dataset.from_tensor_slices(image_data)
 
     # Load label
     label_ds = tf.data.Dataset.from_tensor_slices(
mediapipe/model_maker/python/vision/face_stylizer/dataset_test.py

@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numpy as np
 import tensorflow as tf
 
+from mediapipe.model_maker.python.vision.core import image_utils
 from mediapipe.model_maker.python.vision.face_stylizer import dataset
 from mediapipe.tasks.python.test import test_utils
 

@@ -22,10 +24,10 @@ class DatasetTest(tf.test.TestCase):
 
   def setUp(self):
     super().setUp()
-    self._test_data_dirname = 'input/style'
 
   def test_from_folder(self):
-    input_data_dir = test_utils.get_test_data_path(self._test_data_dirname)
+    test_data_dirname = 'input/style'
+    input_data_dir = test_utils.get_test_data_path(test_data_dirname)
     data = dataset.Dataset.from_folder(dirname=input_data_dir)
     self.assertEqual(data.num_classes, 2)
     self.assertEqual(data.label_names, ['cartoon', 'sketch'])
mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py

@@ -14,7 +14,7 @@
 """APIs to train face stylization model."""
 
 import os
-from typing import Callable, Optional
+from typing import Any, Callable, Optional
 
 import numpy as np
 import tensorflow as tf

@@ -54,7 +54,6 @@ class FaceStylizer(object):
     self._model_spec = model_spec
     self._model_options = model_options
     self._hparams = hparams
-    # TODO: Support face alignment in image preprocessor.
     self._preprocessor = image_preprocessing.Preprocessor(
         input_shape=self._model_spec.input_image_shape,
         num_classes=1,

@@ -128,7 +127,7 @@ class FaceStylizer(object):
   def _train_model(
       self,
       train_data: classification_ds.ClassificationDataset,
-      preprocessor: Optional[Callable[..., bool]] = None,
+      preprocessor: Optional[Callable[..., Any]] = None,
   ):
     """Trains the face stylizer model.
 
mediapipe/tasks/python/core/BUILD

@@ -29,7 +29,7 @@ py_library(
     name = "base_options",
     srcs = ["base_options.py"],
     visibility = [
-        "//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__",
+        "//mediapipe/model_maker:__subpackages__",
         "//mediapipe/tasks:users",
     ],
     deps = [
mediapipe/tasks/python/test/test_utils.py

@@ -22,7 +22,6 @@ import six
 from google.protobuf import descriptor
 from google.protobuf import descriptor_pool
 from google.protobuf import text_format
-
 from mediapipe.python._framework_bindings import image as image_module
 from mediapipe.python._framework_bindings import image_frame as image_frame_module
 

@@ -44,18 +43,21 @@ def test_srcdir():
 
 def get_test_data_path(file_or_dirname_path: str) -> str:
   """Returns full test data path."""
-  for (directory, subdirs, files) in os.walk(test_srcdir()):
+  for directory, subdirs, files in os.walk(test_srcdir()):
     for f in subdirs + files:
       path = os.path.join(directory, f)
       if path.endswith(file_or_dirname_path):
         return path
-  raise ValueError("No %s in test directory: %s." %
-                   (file_or_dirname_path, test_srcdir()))
+  raise ValueError(
+      "No %s in test directory: %s." % (file_or_dirname_path, test_srcdir())
+  )
 
 
-def create_calibration_file(file_dir: str,
-                            file_name: str = "score_calibration.txt",
-                            content: str = "1.0,2.0,3.0,4.0") -> str:
+def create_calibration_file(
+    file_dir: str,
+    file_name: str = "score_calibration.txt",
+    content: str = "1.0,2.0,3.0,4.0",
+) -> str:
   """Creates the calibration file."""
   calibration_file = os.path.join(file_dir, file_name)
   with open(calibration_file, mode="w") as file:

@@ -63,12 +65,9 @@ def create_calibration_file(file_dir: str,
   return calibration_file
 
 
-def assert_proto_equals(self,
-                        a,
-                        b,
-                        check_initialized=True,
-                        normalize_numbers=True,
-                        msg=None):
+def assert_proto_equals(
+    self, a, b, check_initialized=True, normalize_numbers=True, msg=None
+):
   """assert_proto_equals() is useful for unit tests.
 
   It produces much more helpful output than assertEqual() for proto2 messages.

@@ -113,7 +112,8 @@ def assert_proto_equals(self,
     self.assertMultiLineEqual(a_str, b_str, msg=msg)
   else:
     diff = "".join(
-        difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True)))
+        difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True))
+    )
     if diff:
       self.fail("%s :\n%s" % (msg, diff))
 

@@ -147,14 +147,18 @@ def _normalize_number_fields(pb):
    # We force 32-bit values to int and 64-bit values to long to make
    # alternate implementations where the distinction is more significant
    # (e.g. the C++ implementation) simpler.
-    if desc.type in (descriptor.FieldDescriptor.TYPE_INT64,
-                     descriptor.FieldDescriptor.TYPE_UINT64,
-                     descriptor.FieldDescriptor.TYPE_SINT64):
+    if desc.type in (
+        descriptor.FieldDescriptor.TYPE_INT64,
+        descriptor.FieldDescriptor.TYPE_UINT64,
+        descriptor.FieldDescriptor.TYPE_SINT64,
+    ):
      normalized_values = [int(x) for x in values]
-    elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32,
-                       descriptor.FieldDescriptor.TYPE_UINT32,
-                       descriptor.FieldDescriptor.TYPE_SINT32,
-                       descriptor.FieldDescriptor.TYPE_ENUM):
+    elif desc.type in (
+        descriptor.FieldDescriptor.TYPE_INT32,
+        descriptor.FieldDescriptor.TYPE_UINT32,
+        descriptor.FieldDescriptor.TYPE_SINT32,
+        descriptor.FieldDescriptor.TYPE_ENUM,
+    ):
      normalized_values = [int(x) for x in values]
    elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
      normalized_values = [round(x, 4) for x in values]

@@ -168,14 +172,20 @@ def _normalize_number_fields(pb):
      else:
        setattr(pb, desc.name, normalized_values[0])
 
-    if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or
-        desc.type == descriptor.FieldDescriptor.TYPE_GROUP):
-      if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
-          desc.message_type.has_options and
-          desc.message_type.GetOptions().map_entry):
+    if (
+        desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE
+        or desc.type == descriptor.FieldDescriptor.TYPE_GROUP
+    ):
+      if (
+          desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE
+          and desc.message_type.has_options
+          and desc.message_type.GetOptions().map_entry
+      ):
        # This is a map, only recurse if the values have a message type.
-        if (desc.message_type.fields_by_number[2].type ==
-            descriptor.FieldDescriptor.TYPE_MESSAGE):
+        if (
+            desc.message_type.fields_by_number[2].type
+            == descriptor.FieldDescriptor.TYPE_MESSAGE
+        ):
          for v in six.itervalues(values):
            _normalize_number_fields(v)
        else: