Add an API to run inference with face stylizer TF model.

PiperOrigin-RevId: 558926645
This commit is contained in:
MediaPipe Team 2023-08-21 16:06:12 -07:00 committed by Copybara-Service
parent bbf168ddda
commit 7f8150776a
2 changed files with 57 additions and 0 deletions

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""APIs to train face stylization model."""
import logging
import os
from typing import Any, Callable, Optional
import zipfile
@ -103,6 +104,37 @@ class FaceStylizer(object):
face_stylizer._create_and_train_model(train_data)
return face_stylizer
def stylize(
self, data: classification_ds.ClassificationDataset
) -> classification_ds.ClassificationDataset:
"""Stylizes the images represented by the input dataset.
Args:
data: Dataset of input images, can contain multiple images.
Returns:
A dataset contains the stylized images
"""
input_dataset = data.gen_tf_dataset(preprocess=self._preprocessor)
output_img_list = []
for sample in input_dataset:
image = sample[0]
w = self._encoder(image, training=True)
x = self._decoder({'inputs': w + self.w_avg}, training=True)
output_batch = x['image'][-1]
output_img_tensor = (tf.squeeze(output_batch).numpy() + 1.0) * 127.5
output_img_list.append(output_img_tensor)
image_ds = tf.data.Dataset.from_tensor_slices(output_img_list)
logging.info('Stylized %s images.', len(output_img_list))
return classification_ds.ClassificationDataset(
dataset=image_ds,
label_names=['stylized'],
size=len(output_img_list),
)
def _create_and_train_model(
self, train_data: classification_ds.ClassificationDataset
):

View File

@ -33,6 +33,15 @@ class FaceStylizerTest(tf.test.TestCase):
data = face_stylizer.Dataset.from_image(filename=input_style_image_file)
return data
def _create_eval_dataset(self):
"""Create evaluation dataset."""
input_test_image_file = test_utils.get_test_data_path(
'input/raw/face/portrait.jpg'
)
data = face_stylizer.Dataset.from_image(filename=input_test_image_file)
return data
def _evaluate_saved_model(self, model: face_stylizer.FaceStylizer):
"""Evaluates the fine-tuned face stylizer model."""
test_image = tf.ones(shape=(256, 256, 3), dtype=tf.float32)
@ -44,6 +53,7 @@ class FaceStylizerTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self._train_data = self._create_training_dataset()
self._eval_data = self._create_eval_dataset()
def test_finetuning_face_stylizer_with_single_input_style_image(self):
with self.test_session(use_gpu=True):
@ -56,6 +66,21 @@ class FaceStylizerTest(tf.test.TestCase):
)
self._evaluate_saved_model(model)
def test_evaluate_face_stylizer(self):
with self.test_session(use_gpu=True):
face_stylizer_options = face_stylizer.FaceStylizerOptions(
model=face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256,
hparams=face_stylizer.HParams(epochs=1),
)
model = face_stylizer.FaceStylizer.create(
train_data=self._train_data, options=face_stylizer_options
)
eval_output = model.stylize(self._eval_data)
self.assertLen(eval_output, 1)
eval_output_data = eval_output.gen_tf_dataset()
iterator = iter(eval_output_data)
self.assertEqual(iterator.get_next().shape, (1, 256, 256, 3))
def test_export_face_stylizer_tflite_model(self):
with self.test_session(use_gpu=True):
model_enum = face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256