Add allow_custom_ops to model_util.convert_to_tflite and enable custom ops for face stylizer.

PiperOrigin-RevId: 562212965
This commit is contained in:
MediaPipe Team 2023-09-02 09:54:10 -07:00 committed by Copybara-Service
parent 2b5e281c27
commit cac462c486
2 changed files with 14 additions and 5 deletions

View File

@ -119,6 +119,7 @@ def convert_to_tflite_from_file(
tf.lite.OpsSet.TFLITE_BUILTINS,
),
preprocess: Optional[Callable[..., Any]] = None,
allow_custom_ops: bool = False,
) -> bytearray:
"""Converts the input Keras model to TFLite format.
@ -129,6 +130,8 @@ def convert_to_tflite_from_file(
preprocess: A callable to preprocess the representative dataset for
quantization. The callable takes three arguments in order: feature, label,
and is_training.
allow_custom_ops: A boolean flag to enable custom ops in model convsion.
Default to False.
Returns:
bytearray of TFLite model
@ -140,6 +143,7 @@ def convert_to_tflite_from_file(
converter, preprocess=preprocess
)
converter.allow_custom_ops = allow_custom_ops
converter.target_spec.supported_ops = supported_ops
tflite_model = converter.convert()
return tflite_model
@ -152,6 +156,7 @@ def convert_to_tflite(
tf.lite.OpsSet.TFLITE_BUILTINS,
),
preprocess: Optional[Callable[..., Any]] = None,
allow_custom_ops: bool = False,
) -> bytearray:
"""Converts the input Keras model to TFLite format.
@ -162,6 +167,8 @@ def convert_to_tflite(
preprocess: A callable to preprocess the representative dataset for
quantization. The callable takes three arguments in order: feature, label,
and is_training.
allow_custom_ops: A boolean flag to enable custom ops in model conversion.
Default to False.
Returns:
bytearray of TFLite model
@ -174,7 +181,11 @@ def convert_to_tflite(
save_format='tf',
)
return convert_to_tflite_from_file(
save_path, quantization_config, supported_ops, preprocess
save_path,
quantization_config,
supported_ops,
preprocess,
allow_custom_ops,
)

View File

@ -274,11 +274,9 @@ class FaceStylizer(object):
face_stylizer_model_buffer = model_util.convert_to_tflite(
model=model,
quantization_config=None,
supported_ops=(
tf.lite.OpsSet.TFLITE_BUILTINS,
tf.lite.OpsSet.SELECT_TF_OPS,
),
supported_ops=(tf.lite.OpsSet.TFLITE_BUILTINS,),
preprocess=self._preprocessor,
allow_custom_ops=True,
)
face_aligner_task_file_path = constants.FACE_ALIGNER_TASK_FILES.get_path()