Add an extra op to rescale face stylizer generation output from [-1, 1] to [0, 1].

This conversion is to support running the model on both GPU and CPU.

PiperOrigin-RevId: 528400297
This commit is contained in:
MediaPipe Team 2023-04-30 23:11:14 -07:00 committed by Copybara-Service
parent 80b19fff4b
commit ad4ae6559b
3 changed files with 37 additions and 5 deletions

View File

@ -16,7 +16,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from typing import List, Union from typing import Sequence
from typing import Dict, List, Union
# Dependency imports # Dependency imports
@ -94,6 +95,17 @@ def is_same_output(tflite_model: bytearray,
return np.allclose(lite_output, keras_output, atol=atol) return np.allclose(lite_output, keras_output, atol=atol)
def run_tflite(
tflite_filename: str,
input_tensors: Union[List[tf.Tensor], Dict[str, tf.Tensor]],
) -> Union[Sequence[tf.Tensor], tf.Tensor]:
"""Runs TFLite model inference."""
with tf.io.gfile.GFile(tflite_filename, "rb") as f:
tflite_model = f.read()
lite_runner = model_util.get_lite_runner(tflite_model)
return lite_runner.run(input_tensors)
def test_tflite(keras_model: tf.keras.Model, def test_tflite(keras_model: tf.keras.Model,
tflite_model: bytearray, tflite_model: bytearray,
size: Union[int, List[int]], size: Union[int, List[int]],

View File

@ -221,7 +221,10 @@ class FaceStylizer(object):
inputs = tf.keras.Input(shape=(256, 256, 3)) inputs = tf.keras.Input(shape=(256, 256, 3))
x = self._encoder(inputs) x = self._encoder(inputs)
x = self._decoder({'inputs': x + self.w_avg}) x = self._decoder({'inputs': x + self.w_avg})
outputs = x['image'][-1] x = x['image'][-1]
# Scale the data range from [-1, 1] to [0, 1] to support running inference
# on both CPU and GPU.
outputs = (x + 1.0) / 2.0
model = tf.keras.Model(inputs=inputs, outputs=outputs) model = tf.keras.Model(inputs=inputs, outputs=outputs)
tflite_model = model_util.convert_to_tflite( tflite_model = model_util.convert_to_tflite(

View File

@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.core.utils import test_util as mm_test_util
from mediapipe.model_maker.python.vision import face_stylizer from mediapipe.model_maker.python.vision import face_stylizer
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
@ -52,14 +55,28 @@ class FaceStylizerTest(tf.test.TestCase):
def test_export_face_stylizer_tflite_model(self): def test_export_face_stylizer_tflite_model(self):
with self.test_session(use_gpu=True): with self.test_session(use_gpu=True):
model_enum = face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256
face_stylizer_options = face_stylizer.FaceStylizerOptions( face_stylizer_options = face_stylizer.FaceStylizerOptions(
model=face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256, model=model_enum,
hparams=face_stylizer.HParams(epochs=0), hparams=face_stylizer.HParams(
epochs=0, export_dir=self.get_temp_dir()
),
) )
model = face_stylizer.FaceStylizer.create( model = face_stylizer.FaceStylizer.create(
train_data=self._train_data, options=face_stylizer_options train_data=self._train_data, options=face_stylizer_options
) )
model.export_model() tflite_model_name = 'custom_face_stylizer.tflite'
model.export_model(model_name=tflite_model_name)
face_stylizer_tflite_file = os.path.join(
self.get_temp_dir(), tflite_model_name
)
spec = face_stylizer.SupportedModels.get(model_enum)
input_image_shape = spec.input_image_shape
input_tensor_shape = [1] + list(input_image_shape) + [3]
input_tensor = mm_test_util.create_random_sample(size=input_tensor_shape)
output = mm_test_util.run_tflite(face_stylizer_tflite_file, input_tensor)
self.assertTrue((output >= 0.0).all())
self.assertTrue((output <= 1.0).all())
if __name__ == '__main__': if __name__ == '__main__':