diff --git a/mediapipe/model_maker/python/core/utils/test_util.py b/mediapipe/model_maker/python/core/utils/test_util.py
index eda8facc2..72fb229c7 100644
--- a/mediapipe/model_maker/python/core/utils/test_util.py
+++ b/mediapipe/model_maker/python/core/utils/test_util.py
@@ -16,7 +16,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from typing import List, Union
+from typing import Sequence
+from typing import Dict, List, Union
 
 # Dependency imports
 
@@ -94,6 +95,17 @@ def is_same_output(tflite_model: bytearray,
   return np.allclose(lite_output, keras_output, atol=atol)
 
 
+def run_tflite(
+    tflite_filename: str,
+    input_tensors: Union[List[tf.Tensor], Dict[str, tf.Tensor]],
+) -> Union[Sequence[tf.Tensor], tf.Tensor]:
+  """Runs TFLite model inference."""
+  with tf.io.gfile.GFile(tflite_filename, "rb") as f:
+    tflite_model = f.read()
+  lite_runner = model_util.get_lite_runner(tflite_model)
+  return lite_runner.run(input_tensors)
+
+
 def test_tflite(keras_model: tf.keras.Model,
                 tflite_model: bytearray,
                 size: Union[int, List[int]],
diff --git a/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py b/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py
index 2c0c1057c..85b567ca3 100644
--- a/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py
+++ b/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py
@@ -221,7 +221,10 @@ class FaceStylizer(object):
     inputs = tf.keras.Input(shape=(256, 256, 3))
     x = self._encoder(inputs)
     x = self._decoder({'inputs': x + self.w_avg})
-    outputs = x['image'][-1]
+    x = x['image'][-1]
+    # Scale the data range from [-1, 1] to [0, 1] to support running inference
+    # on both CPU and GPU.
+    outputs = (x + 1.0) / 2.0
     model = tf.keras.Model(inputs=inputs, outputs=outputs)
 
     tflite_model = model_util.convert_to_tflite(
diff --git a/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_test.py b/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_test.py
index 8a3023269..16b314c8e 100644
--- a/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_test.py
+++ b/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_test.py
@@ -12,8 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+
 import tensorflow as tf
 
+from mediapipe.model_maker.python.core.utils import test_util as mm_test_util
 from mediapipe.model_maker.python.vision import face_stylizer
 from mediapipe.tasks.python.test import test_utils
 
@@ -52,14 +55,28 @@ class FaceStylizerTest(tf.test.TestCase):
 
   def test_export_face_stylizer_tflite_model(self):
     with self.test_session(use_gpu=True):
+      model_enum = face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256
       face_stylizer_options = face_stylizer.FaceStylizerOptions(
-          model=face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256,
-          hparams=face_stylizer.HParams(epochs=0),
+          model=model_enum,
+          hparams=face_stylizer.HParams(
+              epochs=0, export_dir=self.get_temp_dir()
+          ),
       )
       model = face_stylizer.FaceStylizer.create(
          train_data=self._train_data, options=face_stylizer_options
       )
-      model.export_model()
+      tflite_model_name = 'custom_face_stylizer.tflite'
+      model.export_model(model_name=tflite_model_name)
+      face_stylizer_tflite_file = os.path.join(
+          self.get_temp_dir(), tflite_model_name
+      )
+      spec = face_stylizer.SupportedModels.get(model_enum)
+      input_image_shape = spec.input_image_shape
+      input_tensor_shape = [1] + list(input_image_shape) + [3]
+      input_tensor = mm_test_util.create_random_sample(size=input_tensor_shape)
+      output = mm_test_util.run_tflite(face_stylizer_tflite_file, input_tensor)
+      self.assertTrue((output >= 0.0).all())
+      self.assertTrue((output <= 1.0).all())
 
 
 if __name__ == '__main__':
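
Reviewer note: below is a minimal usage sketch, not part of the diff, showing how the new test_util.run_tflite helper pairs with the rescaled face stylizer output. The model file name and the 1x256x256x3 input shape are illustrative assumptions taken from the test above.

  import numpy as np

  from mediapipe.model_maker.python.core.utils import test_util

  # Assumed path to a model written by FaceStylizer.export_model(); adjust as needed.
  tflite_file = 'custom_face_stylizer.tflite'
  # Random input in the expected 1x256x256x3 layout.
  input_tensor = test_util.create_random_sample(size=[1, 256, 256, 3])
  output = test_util.run_tflite(tflite_file, input_tensor)
  # With the [-1, 1] -> [0, 1] rescaling in face_stylizer.py, every output value
  # should land in [0, 1].
  assert (np.asarray(output) >= 0.0).all() and (np.asarray(output) <= 1.0).all()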