Add an extra op to rescale face stylizer generation output from [-1, 1] to [0, 1].
This conversion is to support running the model on both GPU and CPU. PiperOrigin-RevId: 528400297
This commit is contained in:
parent
80b19fff4b
commit
ad4ae6559b
|
@ -16,7 +16,8 @@ from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from typing import List, Union
|
from typing import Sequence
|
||||||
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
# Dependency imports
|
# Dependency imports
|
||||||
|
|
||||||
|
@ -94,6 +95,17 @@ def is_same_output(tflite_model: bytearray,
|
||||||
return np.allclose(lite_output, keras_output, atol=atol)
|
return np.allclose(lite_output, keras_output, atol=atol)
|
||||||
|
|
||||||
|
|
||||||
|
def run_tflite(
|
||||||
|
tflite_filename: str,
|
||||||
|
input_tensors: Union[List[tf.Tensor], Dict[str, tf.Tensor]],
|
||||||
|
) -> Union[Sequence[tf.Tensor], tf.Tensor]:
|
||||||
|
"""Runs TFLite model inference."""
|
||||||
|
with tf.io.gfile.GFile(tflite_filename, "rb") as f:
|
||||||
|
tflite_model = f.read()
|
||||||
|
lite_runner = model_util.get_lite_runner(tflite_model)
|
||||||
|
return lite_runner.run(input_tensors)
|
||||||
|
|
||||||
|
|
||||||
def test_tflite(keras_model: tf.keras.Model,
|
def test_tflite(keras_model: tf.keras.Model,
|
||||||
tflite_model: bytearray,
|
tflite_model: bytearray,
|
||||||
size: Union[int, List[int]],
|
size: Union[int, List[int]],
|
||||||
|
|
|
@ -221,7 +221,10 @@ class FaceStylizer(object):
|
||||||
inputs = tf.keras.Input(shape=(256, 256, 3))
|
inputs = tf.keras.Input(shape=(256, 256, 3))
|
||||||
x = self._encoder(inputs)
|
x = self._encoder(inputs)
|
||||||
x = self._decoder({'inputs': x + self.w_avg})
|
x = self._decoder({'inputs': x + self.w_avg})
|
||||||
outputs = x['image'][-1]
|
x = x['image'][-1]
|
||||||
|
# Scale the data range from [-1, 1] to [0, 1] to support running inference
|
||||||
|
# on both CPU and GPU.
|
||||||
|
outputs = (x + 1.0) / 2.0
|
||||||
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||||
|
|
||||||
tflite_model = model_util.convert_to_tflite(
|
tflite_model = model_util.convert_to_tflite(
|
||||||
|
|
|
@ -12,8 +12,11 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.utils import test_util as mm_test_util
|
||||||
from mediapipe.model_maker.python.vision import face_stylizer
|
from mediapipe.model_maker.python.vision import face_stylizer
|
||||||
from mediapipe.tasks.python.test import test_utils
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
|
||||||
|
@ -52,14 +55,28 @@ class FaceStylizerTest(tf.test.TestCase):
|
||||||
|
|
||||||
def test_export_face_stylizer_tflite_model(self):
|
def test_export_face_stylizer_tflite_model(self):
|
||||||
with self.test_session(use_gpu=True):
|
with self.test_session(use_gpu=True):
|
||||||
|
model_enum = face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256
|
||||||
face_stylizer_options = face_stylizer.FaceStylizerOptions(
|
face_stylizer_options = face_stylizer.FaceStylizerOptions(
|
||||||
model=face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256,
|
model=model_enum,
|
||||||
hparams=face_stylizer.HParams(epochs=0),
|
hparams=face_stylizer.HParams(
|
||||||
|
epochs=0, export_dir=self.get_temp_dir()
|
||||||
|
),
|
||||||
)
|
)
|
||||||
model = face_stylizer.FaceStylizer.create(
|
model = face_stylizer.FaceStylizer.create(
|
||||||
train_data=self._train_data, options=face_stylizer_options
|
train_data=self._train_data, options=face_stylizer_options
|
||||||
)
|
)
|
||||||
model.export_model()
|
tflite_model_name = 'custom_face_stylizer.tflite'
|
||||||
|
model.export_model(model_name=tflite_model_name)
|
||||||
|
face_stylizer_tflite_file = os.path.join(
|
||||||
|
self.get_temp_dir(), tflite_model_name
|
||||||
|
)
|
||||||
|
spec = face_stylizer.SupportedModels.get(model_enum)
|
||||||
|
input_image_shape = spec.input_image_shape
|
||||||
|
input_tensor_shape = [1] + list(input_image_shape) + [3]
|
||||||
|
input_tensor = mm_test_util.create_random_sample(size=input_tensor_shape)
|
||||||
|
output = mm_test_util.run_tflite(face_stylizer_tflite_file, input_tensor)
|
||||||
|
self.assertTrue((output >= 0.0).all())
|
||||||
|
self.assertTrue((output <= 1.0).all())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Reference in New Issue
Block a user