Improve quantization support in model_maker/image_classifier

PiperOrigin-RevId: 480455944
This commit is contained in:
MediaPipe Team 2022-10-11 14:38:14 -07:00 committed by Copybara-Service
parent f4fd1063a7
commit 1b611c66bb

View File

@ -94,7 +94,8 @@ def export_tflite(
tflite_filepath: str, tflite_filepath: str,
quantization_config: Optional[quantization.QuantizationConfig] = None, quantization_config: Optional[quantization.QuantizationConfig] = None,
supported_ops: Tuple[tf.lite.OpsSet, supported_ops: Tuple[tf.lite.OpsSet,
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,)): ...] = (tf.lite.OpsSet.TFLITE_BUILTINS,),
preprocess: Optional[Callable[..., bool]] = None):
"""Converts the model to tflite format and saves it. """Converts the model to tflite format and saves it.
Args: Args:
@ -102,6 +103,9 @@ def export_tflite(
tflite_filepath: File path to save tflite model. tflite_filepath: File path to save tflite model.
quantization_config: Configuration for post-training quantization. quantization_config: Configuration for post-training quantization.
supported_ops: A list of supported ops in the converted TFLite file. supported_ops: A list of supported ops in the converted TFLite file.
preprocess: A callable to preprocess the representative dataset for
quantization. The callable takes three arguments in order: feature,
label, and is_training.
""" """
if tflite_filepath is None: if tflite_filepath is None:
raise ValueError( raise ValueError(
@ -113,7 +117,8 @@ def export_tflite(
converter = tf.lite.TFLiteConverter.from_saved_model(save_path) converter = tf.lite.TFLiteConverter.from_saved_model(save_path)
if quantization_config: if quantization_config:
converter = quantization_config.set_converter_with_quantization(converter) converter = quantization_config.set_converter_with_quantization(
converter, preprocess=preprocess)
converter.target_spec.supported_ops = supported_ops converter.target_spec.supported_ops = supported_ops
tflite_model = converter.convert() tflite_model = converter.convert()