Add an API to run inference with face stylizer TF model.

PiperOrigin-RevId: 558926645
This commit is contained in:
MediaPipe Team 2023-08-21 16:06:12 -07:00 committed by Copybara-Service
parent bbf168ddda
commit 7f8150776a
2 changed files with 57 additions and 0 deletions

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""APIs to train face stylization model.""" """APIs to train face stylization model."""
import logging
import os import os
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
import zipfile import zipfile
@ -103,6 +104,37 @@ class FaceStylizer(object):
face_stylizer._create_and_train_model(train_data) face_stylizer._create_and_train_model(train_data)
return face_stylizer return face_stylizer
def stylize(
self, data: classification_ds.ClassificationDataset
) -> classification_ds.ClassificationDataset:
"""Stylizes the images represented by the input dataset.
Args:
data: Dataset of input images, can contain multiple images.
Returns:
A dataset contains the stylized images
"""
input_dataset = data.gen_tf_dataset(preprocess=self._preprocessor)
output_img_list = []
for sample in input_dataset:
image = sample[0]
w = self._encoder(image, training=True)
x = self._decoder({'inputs': w + self.w_avg}, training=True)
output_batch = x['image'][-1]
output_img_tensor = (tf.squeeze(output_batch).numpy() + 1.0) * 127.5
output_img_list.append(output_img_tensor)
image_ds = tf.data.Dataset.from_tensor_slices(output_img_list)
logging.info('Stylized %s images.', len(output_img_list))
return classification_ds.ClassificationDataset(
dataset=image_ds,
label_names=['stylized'],
size=len(output_img_list),
)
def _create_and_train_model( def _create_and_train_model(
self, train_data: classification_ds.ClassificationDataset self, train_data: classification_ds.ClassificationDataset
): ):

View File

@ -33,6 +33,15 @@ class FaceStylizerTest(tf.test.TestCase):
data = face_stylizer.Dataset.from_image(filename=input_style_image_file) data = face_stylizer.Dataset.from_image(filename=input_style_image_file)
return data return data
def _create_eval_dataset(self):
"""Create evaluation dataset."""
input_test_image_file = test_utils.get_test_data_path(
'input/raw/face/portrait.jpg'
)
data = face_stylizer.Dataset.from_image(filename=input_test_image_file)
return data
def _evaluate_saved_model(self, model: face_stylizer.FaceStylizer): def _evaluate_saved_model(self, model: face_stylizer.FaceStylizer):
"""Evaluates the fine-tuned face stylizer model.""" """Evaluates the fine-tuned face stylizer model."""
test_image = tf.ones(shape=(256, 256, 3), dtype=tf.float32) test_image = tf.ones(shape=(256, 256, 3), dtype=tf.float32)
@ -44,6 +53,7 @@ class FaceStylizerTest(tf.test.TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self._train_data = self._create_training_dataset() self._train_data = self._create_training_dataset()
self._eval_data = self._create_eval_dataset()
def test_finetuning_face_stylizer_with_single_input_style_image(self): def test_finetuning_face_stylizer_with_single_input_style_image(self):
with self.test_session(use_gpu=True): with self.test_session(use_gpu=True):
@ -56,6 +66,21 @@ class FaceStylizerTest(tf.test.TestCase):
) )
self._evaluate_saved_model(model) self._evaluate_saved_model(model)
def test_evaluate_face_stylizer(self):
with self.test_session(use_gpu=True):
face_stylizer_options = face_stylizer.FaceStylizerOptions(
model=face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256,
hparams=face_stylizer.HParams(epochs=1),
)
model = face_stylizer.FaceStylizer.create(
train_data=self._train_data, options=face_stylizer_options
)
eval_output = model.stylize(self._eval_data)
self.assertLen(eval_output, 1)
eval_output_data = eval_output.gen_tf_dataset()
iterator = iter(eval_output_data)
self.assertEqual(iterator.get_next().shape, (1, 256, 256, 3))
def test_export_face_stylizer_tflite_model(self): def test_export_face_stylizer_tflite_model(self):
with self.test_session(use_gpu=True): with self.test_session(use_gpu=True):
model_enum = face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256 model_enum = face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256