Internal Changes
PiperOrigin-RevId: 557563669
This commit is contained in:
parent
9e45e2b6e9
commit
13bb65db96
|
@ -112,6 +112,39 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
|
|||
return len(train_data) // batch_size
|
||||
|
||||
|
||||
def convert_to_tflite_from_file(
|
||||
saved_model_file: str,
|
||||
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
||||
supported_ops: Tuple[tf.lite.OpsSet, ...] = (
|
||||
tf.lite.OpsSet.TFLITE_BUILTINS,
|
||||
),
|
||||
preprocess: Optional[Callable[..., Any]] = None,
|
||||
) -> bytearray:
|
||||
"""Converts the input Keras model to TFLite format.
|
||||
|
||||
Args:
|
||||
saved_model_file: Keras model to be converted to TFLite.
|
||||
quantization_config: Configuration for post-training quantization.
|
||||
supported_ops: A list of supported ops in the converted TFLite file.
|
||||
preprocess: A callable to preprocess the representative dataset for
|
||||
quantization. The callable takes three arguments in order: feature, label,
|
||||
and is_training.
|
||||
|
||||
Returns:
|
||||
bytearray of TFLite model
|
||||
"""
|
||||
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_file)
|
||||
|
||||
if quantization_config:
|
||||
converter = quantization_config.set_converter_with_quantization(
|
||||
converter, preprocess=preprocess
|
||||
)
|
||||
|
||||
converter.target_spec.supported_ops = supported_ops
|
||||
tflite_model = converter.convert()
|
||||
return tflite_model
|
||||
|
||||
|
||||
def convert_to_tflite(
|
||||
model: tf.keras.Model,
|
||||
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
||||
|
@ -135,16 +168,14 @@ def convert_to_tflite(
|
|||
"""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
save_path = os.path.join(temp_dir, 'saved_model')
|
||||
model.save(save_path, include_optimizer=False, save_format='tf')
|
||||
converter = tf.lite.TFLiteConverter.from_saved_model(save_path)
|
||||
|
||||
if quantization_config:
|
||||
converter = quantization_config.set_converter_with_quantization(
|
||||
converter, preprocess=preprocess)
|
||||
|
||||
converter.target_spec.supported_ops = supported_ops
|
||||
tflite_model = converter.convert()
|
||||
return tflite_model
|
||||
model.save(
|
||||
save_path,
|
||||
include_optimizer=False,
|
||||
save_format='tf',
|
||||
)
|
||||
return convert_to_tflite_from_file(
|
||||
save_path, quantization_config, supported_ops, preprocess
|
||||
)
|
||||
|
||||
|
||||
def save_tflite(tflite_model: bytearray, tflite_file: str) -> None:
|
||||
|
|
|
@ -169,6 +169,25 @@ class TextClassifier(classifier.Classifier):
|
|||
with self._hparams.get_strategy().scope():
|
||||
return self._model.evaluate(dataset)
|
||||
|
||||
def save_model(
|
||||
self,
|
||||
model_name: str = "saved_model",
|
||||
):
|
||||
"""Saves the model in SavedModel format.
|
||||
|
||||
For more information, see https://www.tensorflow.org/guide/saved_model.
|
||||
|
||||
Args:
|
||||
model_name: Name of the saved model.
|
||||
"""
|
||||
tf.io.gfile.makedirs(self._hparams.export_dir)
|
||||
saved_model_file = os.path.join(self._hparams.export_dir, model_name)
|
||||
self._model.save(
|
||||
saved_model_file,
|
||||
include_optimizer=False,
|
||||
save_format="tf",
|
||||
)
|
||||
|
||||
def export_model(
|
||||
self,
|
||||
model_name: str = "model.tflite",
|
||||
|
@ -184,12 +203,16 @@ class TextClassifier(classifier.Classifier):
|
|||
path is {self._hparams.export_dir}/{model_name}.
|
||||
quantization_config: The configuration for model quantization.
|
||||
"""
|
||||
tf.io.gfile.makedirs(self._hparams.export_dir)
|
||||
tflite_file = os.path.join(self._hparams.export_dir, model_name)
|
||||
tf.io.gfile.makedirs(os.path.dirname(tflite_file))
|
||||
metadata_file = os.path.join(self._hparams.export_dir, "metadata.json")
|
||||
|
||||
tflite_model = model_util.convert_to_tflite(
|
||||
model=self._model, quantization_config=quantization_config)
|
||||
self.save_model(model_name="saved_model")
|
||||
saved_model_file = os.path.join(self._hparams.export_dir, "saved_model")
|
||||
|
||||
tflite_model = model_util.convert_to_tflite_from_file(
|
||||
saved_model_file, quantization_config=quantization_config
|
||||
)
|
||||
vocab_filepath = os.path.join(tempfile.mkdtemp(), "vocab.txt")
|
||||
self._save_vocab(vocab_filepath)
|
||||
|
||||
|
@ -494,6 +517,9 @@ class _BertClassifier(TextClassifier):
|
|||
encoder = hub.KerasLayer(
|
||||
self._model_spec.get_path(),
|
||||
trainable=self._model_options.do_fine_tuning,
|
||||
load_options=tf.saved_model.LoadOptions(
|
||||
experimental_io_device="/job:localhost"
|
||||
),
|
||||
)
|
||||
encoder_outputs = encoder(encoder_inputs)
|
||||
pooled_output = encoder_outputs["pooled_output"]
|
||||
|
@ -512,16 +538,18 @@ class _BertClassifier(TextClassifier):
|
|||
pooled_output = encoder(renamed_inputs)
|
||||
|
||||
output = tf.keras.layers.Dropout(rate=self._model_options.dropout_rate)(
|
||||
pooled_output)
|
||||
pooled_output
|
||||
)
|
||||
initializer = tf.keras.initializers.TruncatedNormal(
|
||||
stddev=self._INITIALIZER_RANGE)
|
||||
stddev=self._INITIALIZER_RANGE
|
||||
)
|
||||
output = tf.keras.layers.Dense(
|
||||
self._num_classes,
|
||||
kernel_initializer=initializer,
|
||||
name="output",
|
||||
activation="softmax",
|
||||
dtype=tf.float32)(
|
||||
output)
|
||||
dtype=tf.float32,
|
||||
)(output)
|
||||
self._model = tf.keras.Model(inputs=encoder_inputs, outputs=output)
|
||||
|
||||
def _create_optimizer(self, train_data: text_ds.Dataset):
|
||||
|
|
Loading…
Reference in New Issue
Block a user