Improve quantization support in model_maker/image_classifier
PiperOrigin-RevId: 480455944
This commit is contained in:
parent
f4fd1063a7
commit
1b611c66bb
|
@ -94,7 +94,8 @@ def export_tflite(
|
||||||
tflite_filepath: str,
|
tflite_filepath: str,
|
||||||
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
||||||
supported_ops: Tuple[tf.lite.OpsSet,
|
supported_ops: Tuple[tf.lite.OpsSet,
|
||||||
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,)):
|
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,),
|
||||||
|
preprocess: Optional[Callable[..., bool]] = None):
|
||||||
"""Converts the model to tflite format and saves it.
|
"""Converts the model to tflite format and saves it.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -102,6 +103,9 @@ def export_tflite(
|
||||||
tflite_filepath: File path to save tflite model.
|
tflite_filepath: File path to save tflite model.
|
||||||
quantization_config: Configuration for post-training quantization.
|
quantization_config: Configuration for post-training quantization.
|
||||||
supported_ops: A list of supported ops in the converted TFLite file.
|
supported_ops: A list of supported ops in the converted TFLite file.
|
||||||
|
preprocess: A callable to preprocess the representative dataset for
|
||||||
|
quantization. The callable takes three arguments in order: feature,
|
||||||
|
label, and is_training.
|
||||||
"""
|
"""
|
||||||
if tflite_filepath is None:
|
if tflite_filepath is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -113,7 +117,8 @@ def export_tflite(
|
||||||
converter = tf.lite.TFLiteConverter.from_saved_model(save_path)
|
converter = tf.lite.TFLiteConverter.from_saved_model(save_path)
|
||||||
|
|
||||||
if quantization_config:
|
if quantization_config:
|
||||||
converter = quantization_config.set_converter_with_quantization(converter)
|
converter = quantization_config.set_converter_with_quantization(
|
||||||
|
converter, preprocess=preprocess)
|
||||||
|
|
||||||
converter.target_spec.supported_ops = supported_ops
|
converter.target_spec.supported_ops = supported_ops
|
||||||
tflite_model = converter.convert()
|
tflite_model = converter.convert()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user