Add an API to run inference with face stylizer TF model.
PiperOrigin-RevId: 558926645
This commit is contained in:
parent
bbf168ddda
commit
7f8150776a
|
@ -13,6 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""APIs to train face stylization model."""
|
"""APIs to train face stylization model."""
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
import zipfile
|
import zipfile
|
||||||
|
@ -103,6 +104,37 @@ class FaceStylizer(object):
|
||||||
face_stylizer._create_and_train_model(train_data)
|
face_stylizer._create_and_train_model(train_data)
|
||||||
return face_stylizer
|
return face_stylizer
|
||||||
|
|
||||||
|
def stylize(
|
||||||
|
self, data: classification_ds.ClassificationDataset
|
||||||
|
) -> classification_ds.ClassificationDataset:
|
||||||
|
"""Stylizes the images represented by the input dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Dataset of input images, can contain multiple images.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dataset contains the stylized images
|
||||||
|
"""
|
||||||
|
input_dataset = data.gen_tf_dataset(preprocess=self._preprocessor)
|
||||||
|
output_img_list = []
|
||||||
|
for sample in input_dataset:
|
||||||
|
image = sample[0]
|
||||||
|
w = self._encoder(image, training=True)
|
||||||
|
x = self._decoder({'inputs': w + self.w_avg}, training=True)
|
||||||
|
output_batch = x['image'][-1]
|
||||||
|
output_img_tensor = (tf.squeeze(output_batch).numpy() + 1.0) * 127.5
|
||||||
|
output_img_list.append(output_img_tensor)
|
||||||
|
|
||||||
|
image_ds = tf.data.Dataset.from_tensor_slices(output_img_list)
|
||||||
|
|
||||||
|
logging.info('Stylized %s images.', len(output_img_list))
|
||||||
|
|
||||||
|
return classification_ds.ClassificationDataset(
|
||||||
|
dataset=image_ds,
|
||||||
|
label_names=['stylized'],
|
||||||
|
size=len(output_img_list),
|
||||||
|
)
|
||||||
|
|
||||||
def _create_and_train_model(
|
def _create_and_train_model(
|
||||||
self, train_data: classification_ds.ClassificationDataset
|
self, train_data: classification_ds.ClassificationDataset
|
||||||
):
|
):
|
||||||
|
|
|
@ -33,6 +33,15 @@ class FaceStylizerTest(tf.test.TestCase):
|
||||||
data = face_stylizer.Dataset.from_image(filename=input_style_image_file)
|
data = face_stylizer.Dataset.from_image(filename=input_style_image_file)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
def _create_eval_dataset(self):
|
||||||
|
"""Create evaluation dataset."""
|
||||||
|
input_test_image_file = test_utils.get_test_data_path(
|
||||||
|
'input/raw/face/portrait.jpg'
|
||||||
|
)
|
||||||
|
|
||||||
|
data = face_stylizer.Dataset.from_image(filename=input_test_image_file)
|
||||||
|
return data
|
||||||
|
|
||||||
def _evaluate_saved_model(self, model: face_stylizer.FaceStylizer):
|
def _evaluate_saved_model(self, model: face_stylizer.FaceStylizer):
|
||||||
"""Evaluates the fine-tuned face stylizer model."""
|
"""Evaluates the fine-tuned face stylizer model."""
|
||||||
test_image = tf.ones(shape=(256, 256, 3), dtype=tf.float32)
|
test_image = tf.ones(shape=(256, 256, 3), dtype=tf.float32)
|
||||||
|
@ -44,6 +53,7 @@ class FaceStylizerTest(tf.test.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self._train_data = self._create_training_dataset()
|
self._train_data = self._create_training_dataset()
|
||||||
|
self._eval_data = self._create_eval_dataset()
|
||||||
|
|
||||||
def test_finetuning_face_stylizer_with_single_input_style_image(self):
|
def test_finetuning_face_stylizer_with_single_input_style_image(self):
|
||||||
with self.test_session(use_gpu=True):
|
with self.test_session(use_gpu=True):
|
||||||
|
@ -56,6 +66,21 @@ class FaceStylizerTest(tf.test.TestCase):
|
||||||
)
|
)
|
||||||
self._evaluate_saved_model(model)
|
self._evaluate_saved_model(model)
|
||||||
|
|
||||||
|
def test_evaluate_face_stylizer(self):
|
||||||
|
with self.test_session(use_gpu=True):
|
||||||
|
face_stylizer_options = face_stylizer.FaceStylizerOptions(
|
||||||
|
model=face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256,
|
||||||
|
hparams=face_stylizer.HParams(epochs=1),
|
||||||
|
)
|
||||||
|
model = face_stylizer.FaceStylizer.create(
|
||||||
|
train_data=self._train_data, options=face_stylizer_options
|
||||||
|
)
|
||||||
|
eval_output = model.stylize(self._eval_data)
|
||||||
|
self.assertLen(eval_output, 1)
|
||||||
|
eval_output_data = eval_output.gen_tf_dataset()
|
||||||
|
iterator = iter(eval_output_data)
|
||||||
|
self.assertEqual(iterator.get_next().shape, (1, 256, 256, 3))
|
||||||
|
|
||||||
def test_export_face_stylizer_tflite_model(self):
|
def test_export_face_stylizer_tflite_model(self):
|
||||||
with self.test_session(use_gpu=True):
|
with self.test_session(use_gpu=True):
|
||||||
model_enum = face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256
|
model_enum = face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256
|
||||||
|
|
Loading…
Reference in New Issue
Block a user