diff --git a/mediapipe/model_maker/python/vision/face_stylizer/BUILD b/mediapipe/model_maker/python/vision/face_stylizer/BUILD index a2e30a112..29c30c873 100644 --- a/mediapipe/model_maker/python/vision/face_stylizer/BUILD +++ b/mediapipe/model_maker/python/vision/face_stylizer/BUILD @@ -20,13 +20,6 @@ licenses(["notice"]) package(default_visibility = ["//mediapipe:__subpackages__"]) -filegroup( - name = "testdata", - srcs = glob([ - "testdata/**", - ]), -) - py_library( name = "constants", srcs = ["constants.py"], @@ -72,18 +65,11 @@ py_library( name = "dataset", srcs = ["dataset.py"], deps = [ + ":constants", "//mediapipe/model_maker/python/core/data:classification_dataset", - "//mediapipe/model_maker/python/vision/core:image_utils", - ], -) - -py_test( - name = "dataset_test", - srcs = ["dataset_test.py"], - data = [":testdata"], - deps = [ - ":dataset", - "//mediapipe/tasks/python/test:test_utils", + "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/vision:face_aligner", ], ) diff --git a/mediapipe/model_maker/python/vision/face_stylizer/constants.py b/mediapipe/model_maker/python/vision/face_stylizer/constants.py index e7a03aebd..ac7675232 100644 --- a/mediapipe/model_maker/python/vision/face_stylizer/constants.py +++ b/mediapipe/model_maker/python/vision/face_stylizer/constants.py @@ -41,5 +41,11 @@ FACE_STYLIZER_W_FILES = file_util.DownloadedFiles( 'https://storage.googleapis.com/mediapipe-assets/face_stylizer_w_avg.npy', ) +FACE_ALIGNER_TASK_FILES = file_util.DownloadedFiles( + 'face_stylizer/face_landmarker_v2.task', + 'https://storage.googleapis.com/mediapipe-assets/face_landmarker_v2.task', + is_folder=False, +) + # Dimension of the input style vector to the decoder STYLE_DIM = 512 diff --git a/mediapipe/model_maker/python/vision/face_stylizer/dataset.py b/mediapipe/model_maker/python/vision/face_stylizer/dataset.py index d517fd9c1..93478de1b 100644 --- a/mediapipe/model_maker/python/vision/face_stylizer/dataset.py +++ b/mediapipe/model_maker/python/vision/face_stylizer/dataset.py @@ -13,13 +13,37 @@ # limitations under the License. """Face stylizer dataset library.""" +from typing import Sequence import logging import os import tensorflow as tf from mediapipe.model_maker.python.core.data import classification_dataset -from mediapipe.model_maker.python.vision.core import image_utils +from mediapipe.model_maker.python.vision.face_stylizer import constants +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.vision import face_aligner + + +def _preprocess_face_dataset( + all_image_paths: Sequence[str], +) -> Sequence[tf.Tensor]: + """Preprocess face image dataset by aligning the face.""" + path = constants.FACE_ALIGNER_TASK_FILES.get_path() + base_options = base_options_module.BaseOptions(model_asset_path=path) + options = face_aligner.FaceAlignerOptions(base_options=base_options) + aligner = face_aligner.FaceAligner.create_from_options(options) + + preprocessed_images = [] + for path in all_image_paths: + tf.compat.v1.logging.info('Preprocess image %s', path) + image = image_module.Image.create_from_file(path) + aligned_image = aligner.align(image) + aligned_image_tensor = tf.convert_to_tensor(aligned_image.numpy_view()) + preprocessed_images.append(aligned_image_tensor) + + return preprocessed_images # TODO: Change to a unlabeled dataset if it makes sense. @@ -58,6 +82,7 @@ class Dataset(classification_dataset.ClassificationDataset): ): raise ValueError('No images found under given directory') + image_data = _preprocess_face_dataset(all_image_paths) label_names = sorted( name for name in os.listdir(data_root) @@ -73,11 +98,7 @@ class Dataset(classification_dataset.ClassificationDataset): for path in all_image_paths ] - path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths) - - image_ds = path_ds.map( - image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE - ) + image_ds = tf.data.Dataset.from_tensor_slices(image_data) # Load label label_ds = tf.data.Dataset.from_tensor_slices( diff --git a/mediapipe/model_maker/python/vision/face_stylizer/dataset_test.py b/mediapipe/model_maker/python/vision/face_stylizer/dataset_test.py index 73140f30e..900371de1 100644 --- a/mediapipe/model_maker/python/vision/face_stylizer/dataset_test.py +++ b/mediapipe/model_maker/python/vision/face_stylizer/dataset_test.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np import tensorflow as tf +from mediapipe.model_maker.python.vision.core import image_utils from mediapipe.model_maker.python.vision.face_stylizer import dataset from mediapipe.tasks.python.test import test_utils @@ -22,10 +24,10 @@ class DatasetTest(tf.test.TestCase): def setUp(self): super().setUp() - self._test_data_dirname = 'input/style' def test_from_folder(self): - input_data_dir = test_utils.get_test_data_path(self._test_data_dirname) + test_data_dirname = 'input/style' + input_data_dir = test_utils.get_test_data_path(test_data_dirname) data = dataset.Dataset.from_folder(dirname=input_data_dir) self.assertEqual(data.num_classes, 2) self.assertEqual(data.label_names, ['cartoon', 'sketch']) diff --git a/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py b/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py index 5758ac7b5..dfa8a04b4 100644 --- a/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py +++ b/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py @@ -14,7 +14,7 @@ """APIs to train face stylization model.""" import os -from typing import Callable, Optional +from typing import Any, Callable, Optional import numpy as np import tensorflow as tf @@ -54,7 +54,6 @@ class FaceStylizer(object): self._model_spec = model_spec self._model_options = model_options self._hparams = hparams - # TODO: Support face alignment in image preprocessor. self._preprocessor = image_preprocessing.Preprocessor( input_shape=self._model_spec.input_image_shape, num_classes=1, @@ -128,7 +127,7 @@ class FaceStylizer(object): def _train_model( self, train_data: classification_ds.ClassificationDataset, - preprocessor: Optional[Callable[..., bool]] = None, + preprocessor: Optional[Callable[..., Any]] = None, ): """Trains the face stylizer model. diff --git a/mediapipe/tasks/python/core/BUILD b/mediapipe/tasks/python/core/BUILD index 76791c232..9d2dc3f0b 100644 --- a/mediapipe/tasks/python/core/BUILD +++ b/mediapipe/tasks/python/core/BUILD @@ -29,7 +29,7 @@ py_library( name = "base_options", srcs = ["base_options.py"], visibility = [ - "//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__", + "//mediapipe/model_maker:__subpackages__", "//mediapipe/tasks:users", ], deps = [ diff --git a/mediapipe/tasks/python/test/test_utils.py b/mediapipe/tasks/python/test/test_utils.py index 2dfc5a8c4..e790b9156 100644 --- a/mediapipe/tasks/python/test/test_utils.py +++ b/mediapipe/tasks/python/test/test_utils.py @@ -22,7 +22,6 @@ import six from google.protobuf import descriptor from google.protobuf import descriptor_pool from google.protobuf import text_format - from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image_frame as image_frame_module @@ -44,18 +43,21 @@ def test_srcdir(): def get_test_data_path(file_or_dirname_path: str) -> str: """Returns full test data path.""" - for (directory, subdirs, files) in os.walk(test_srcdir()): + for directory, subdirs, files in os.walk(test_srcdir()): for f in subdirs + files: path = os.path.join(directory, f) if path.endswith(file_or_dirname_path): return path - raise ValueError("No %s in test directory: %s." % - (file_or_dirname_path, test_srcdir())) + raise ValueError( + "No %s in test directory: %s." % (file_or_dirname_path, test_srcdir()) + ) -def create_calibration_file(file_dir: str, - file_name: str = "score_calibration.txt", - content: str = "1.0,2.0,3.0,4.0") -> str: +def create_calibration_file( + file_dir: str, + file_name: str = "score_calibration.txt", + content: str = "1.0,2.0,3.0,4.0", +) -> str: """Creates the calibration file.""" calibration_file = os.path.join(file_dir, file_name) with open(calibration_file, mode="w") as file: @@ -63,12 +65,9 @@ def create_calibration_file(file_dir: str, return calibration_file -def assert_proto_equals(self, - a, - b, - check_initialized=True, - normalize_numbers=True, - msg=None): +def assert_proto_equals( + self, a, b, check_initialized=True, normalize_numbers=True, msg=None +): """assert_proto_equals() is useful for unit tests. It produces much more helpful output than assertEqual() for proto2 messages. @@ -113,7 +112,8 @@ def assert_proto_equals(self, self.assertMultiLineEqual(a_str, b_str, msg=msg) else: diff = "".join( - difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True))) + difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True)) + ) if diff: self.fail("%s :\n%s" % (msg, diff)) @@ -147,14 +147,18 @@ def _normalize_number_fields(pb): # We force 32-bit values to int and 64-bit values to long to make # alternate implementations where the distinction is more significant # (e.g. the C++ implementation) simpler. - if desc.type in (descriptor.FieldDescriptor.TYPE_INT64, - descriptor.FieldDescriptor.TYPE_UINT64, - descriptor.FieldDescriptor.TYPE_SINT64): + if desc.type in ( + descriptor.FieldDescriptor.TYPE_INT64, + descriptor.FieldDescriptor.TYPE_UINT64, + descriptor.FieldDescriptor.TYPE_SINT64, + ): normalized_values = [int(x) for x in values] - elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32, - descriptor.FieldDescriptor.TYPE_UINT32, - descriptor.FieldDescriptor.TYPE_SINT32, - descriptor.FieldDescriptor.TYPE_ENUM): + elif desc.type in ( + descriptor.FieldDescriptor.TYPE_INT32, + descriptor.FieldDescriptor.TYPE_UINT32, + descriptor.FieldDescriptor.TYPE_SINT32, + descriptor.FieldDescriptor.TYPE_ENUM, + ): normalized_values = [int(x) for x in values] elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT: normalized_values = [round(x, 4) for x in values] @@ -168,14 +172,20 @@ def _normalize_number_fields(pb): else: setattr(pb, desc.name, normalized_values[0]) - if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or - desc.type == descriptor.FieldDescriptor.TYPE_GROUP): - if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and - desc.message_type.has_options and - desc.message_type.GetOptions().map_entry): + if ( + desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE + or desc.type == descriptor.FieldDescriptor.TYPE_GROUP + ): + if ( + desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE + and desc.message_type.has_options + and desc.message_type.GetOptions().map_entry + ): # This is a map, only recurse if the values have a message type. - if (desc.message_type.fields_by_number[2].type == - descriptor.FieldDescriptor.TYPE_MESSAGE): + if ( + desc.message_type.fields_by_number[2].type + == descriptor.FieldDescriptor.TYPE_MESSAGE + ): for v in six.itervalues(values): _normalize_number_fields(v) else: