Add allow_custom_ops to model_util.convert_to_tflite and enable custom ops for face stylizer.

PiperOrigin-RevId: 562212965
This commit is contained in:
MediaPipe Team 2023-09-02 09:54:10 -07:00 committed by Copybara-Service
parent 2b5e281c27
commit cac462c486
2 changed files with 14 additions and 5 deletions

View File

@ -119,6 +119,7 @@ def convert_to_tflite_from_file(
tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.TFLITE_BUILTINS,
), ),
preprocess: Optional[Callable[..., Any]] = None, preprocess: Optional[Callable[..., Any]] = None,
allow_custom_ops: bool = False,
) -> bytearray: ) -> bytearray:
"""Converts the input Keras model to TFLite format. """Converts the input Keras model to TFLite format.
@ -129,6 +130,8 @@ def convert_to_tflite_from_file(
preprocess: A callable to preprocess the representative dataset for preprocess: A callable to preprocess the representative dataset for
quantization. The callable takes three arguments in order: feature, label, quantization. The callable takes three arguments in order: feature, label,
and is_training. and is_training.
allow_custom_ops: A boolean flag to enable custom ops in model convsion.
Default to False.
Returns: Returns:
bytearray of TFLite model bytearray of TFLite model
@ -140,6 +143,7 @@ def convert_to_tflite_from_file(
converter, preprocess=preprocess converter, preprocess=preprocess
) )
converter.allow_custom_ops = allow_custom_ops
converter.target_spec.supported_ops = supported_ops converter.target_spec.supported_ops = supported_ops
tflite_model = converter.convert() tflite_model = converter.convert()
return tflite_model return tflite_model
@ -152,6 +156,7 @@ def convert_to_tflite(
tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.TFLITE_BUILTINS,
), ),
preprocess: Optional[Callable[..., Any]] = None, preprocess: Optional[Callable[..., Any]] = None,
allow_custom_ops: bool = False,
) -> bytearray: ) -> bytearray:
"""Converts the input Keras model to TFLite format. """Converts the input Keras model to TFLite format.
@ -162,6 +167,8 @@ def convert_to_tflite(
preprocess: A callable to preprocess the representative dataset for preprocess: A callable to preprocess the representative dataset for
quantization. The callable takes three arguments in order: feature, label, quantization. The callable takes three arguments in order: feature, label,
and is_training. and is_training.
allow_custom_ops: A boolean flag to enable custom ops in model conversion.
Default to False.
Returns: Returns:
bytearray of TFLite model bytearray of TFLite model
@ -174,7 +181,11 @@ def convert_to_tflite(
save_format='tf', save_format='tf',
) )
return convert_to_tflite_from_file( return convert_to_tflite_from_file(
save_path, quantization_config, supported_ops, preprocess save_path,
quantization_config,
supported_ops,
preprocess,
allow_custom_ops,
) )

View File

@ -274,11 +274,9 @@ class FaceStylizer(object):
face_stylizer_model_buffer = model_util.convert_to_tflite( face_stylizer_model_buffer = model_util.convert_to_tflite(
model=model, model=model,
quantization_config=None, quantization_config=None,
supported_ops=( supported_ops=(tf.lite.OpsSet.TFLITE_BUILTINS,),
tf.lite.OpsSet.TFLITE_BUILTINS,
tf.lite.OpsSet.SELECT_TF_OPS,
),
preprocess=self._preprocessor, preprocess=self._preprocessor,
allow_custom_ops=True,
) )
face_aligner_task_file_path = constants.FACE_ALIGNER_TASK_FILES.get_path() face_aligner_task_file_path = constants.FACE_ALIGNER_TASK_FILES.get_path()