Add an API to run inference with face stylizer TF model.
PiperOrigin-RevId: 558926645
This commit is contained in:
parent
bbf168ddda
commit
7f8150776a
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
"""APIs to train face stylization model."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Optional
|
||||
import zipfile
|
||||
|
@ -103,6 +104,37 @@ class FaceStylizer(object):
|
|||
face_stylizer._create_and_train_model(train_data)
|
||||
return face_stylizer
|
||||
|
||||
def stylize(
|
||||
self, data: classification_ds.ClassificationDataset
|
||||
) -> classification_ds.ClassificationDataset:
|
||||
"""Stylizes the images represented by the input dataset.
|
||||
|
||||
Args:
|
||||
data: Dataset of input images, can contain multiple images.
|
||||
|
||||
Returns:
|
||||
A dataset contains the stylized images
|
||||
"""
|
||||
input_dataset = data.gen_tf_dataset(preprocess=self._preprocessor)
|
||||
output_img_list = []
|
||||
for sample in input_dataset:
|
||||
image = sample[0]
|
||||
w = self._encoder(image, training=True)
|
||||
x = self._decoder({'inputs': w + self.w_avg}, training=True)
|
||||
output_batch = x['image'][-1]
|
||||
output_img_tensor = (tf.squeeze(output_batch).numpy() + 1.0) * 127.5
|
||||
output_img_list.append(output_img_tensor)
|
||||
|
||||
image_ds = tf.data.Dataset.from_tensor_slices(output_img_list)
|
||||
|
||||
logging.info('Stylized %s images.', len(output_img_list))
|
||||
|
||||
return classification_ds.ClassificationDataset(
|
||||
dataset=image_ds,
|
||||
label_names=['stylized'],
|
||||
size=len(output_img_list),
|
||||
)
|
||||
|
||||
def _create_and_train_model(
|
||||
self, train_data: classification_ds.ClassificationDataset
|
||||
):
|
||||
|
|
|
@ -33,6 +33,15 @@ class FaceStylizerTest(tf.test.TestCase):
|
|||
data = face_stylizer.Dataset.from_image(filename=input_style_image_file)
|
||||
return data
|
||||
|
||||
def _create_eval_dataset(self):
|
||||
"""Create evaluation dataset."""
|
||||
input_test_image_file = test_utils.get_test_data_path(
|
||||
'input/raw/face/portrait.jpg'
|
||||
)
|
||||
|
||||
data = face_stylizer.Dataset.from_image(filename=input_test_image_file)
|
||||
return data
|
||||
|
||||
def _evaluate_saved_model(self, model: face_stylizer.FaceStylizer):
|
||||
"""Evaluates the fine-tuned face stylizer model."""
|
||||
test_image = tf.ones(shape=(256, 256, 3), dtype=tf.float32)
|
||||
|
@ -44,6 +53,7 @@ class FaceStylizerTest(tf.test.TestCase):
|
|||
def setUp(self):
|
||||
super().setUp()
|
||||
self._train_data = self._create_training_dataset()
|
||||
self._eval_data = self._create_eval_dataset()
|
||||
|
||||
def test_finetuning_face_stylizer_with_single_input_style_image(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
|
@ -56,6 +66,21 @@ class FaceStylizerTest(tf.test.TestCase):
|
|||
)
|
||||
self._evaluate_saved_model(model)
|
||||
|
||||
def test_evaluate_face_stylizer(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
face_stylizer_options = face_stylizer.FaceStylizerOptions(
|
||||
model=face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256,
|
||||
hparams=face_stylizer.HParams(epochs=1),
|
||||
)
|
||||
model = face_stylizer.FaceStylizer.create(
|
||||
train_data=self._train_data, options=face_stylizer_options
|
||||
)
|
||||
eval_output = model.stylize(self._eval_data)
|
||||
self.assertLen(eval_output, 1)
|
||||
eval_output_data = eval_output.gen_tf_dataset()
|
||||
iterator = iter(eval_output_data)
|
||||
self.assertEqual(iterator.get_next().shape, (1, 256, 256, 3))
|
||||
|
||||
def test_export_face_stylizer_tflite_model(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
model_enum = face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256
|
||||
|
|
Loading…
Reference in New Issue
Block a user