Internal Changes
PiperOrigin-RevId: 557563669
This commit is contained in:
parent
9e45e2b6e9
commit
13bb65db96
|
@ -112,6 +112,39 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
|
||||||
return len(train_data) // batch_size
|
return len(train_data) // batch_size
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_tflite_from_file(
|
||||||
|
saved_model_file: str,
|
||||||
|
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
||||||
|
supported_ops: Tuple[tf.lite.OpsSet, ...] = (
|
||||||
|
tf.lite.OpsSet.TFLITE_BUILTINS,
|
||||||
|
),
|
||||||
|
preprocess: Optional[Callable[..., Any]] = None,
|
||||||
|
) -> bytearray:
|
||||||
|
"""Converts the input Keras model to TFLite format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
saved_model_file: Keras model to be converted to TFLite.
|
||||||
|
quantization_config: Configuration for post-training quantization.
|
||||||
|
supported_ops: A list of supported ops in the converted TFLite file.
|
||||||
|
preprocess: A callable to preprocess the representative dataset for
|
||||||
|
quantization. The callable takes three arguments in order: feature, label,
|
||||||
|
and is_training.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bytearray of TFLite model
|
||||||
|
"""
|
||||||
|
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_file)
|
||||||
|
|
||||||
|
if quantization_config:
|
||||||
|
converter = quantization_config.set_converter_with_quantization(
|
||||||
|
converter, preprocess=preprocess
|
||||||
|
)
|
||||||
|
|
||||||
|
converter.target_spec.supported_ops = supported_ops
|
||||||
|
tflite_model = converter.convert()
|
||||||
|
return tflite_model
|
||||||
|
|
||||||
|
|
||||||
def convert_to_tflite(
|
def convert_to_tflite(
|
||||||
model: tf.keras.Model,
|
model: tf.keras.Model,
|
||||||
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
||||||
|
@ -135,16 +168,14 @@ def convert_to_tflite(
|
||||||
"""
|
"""
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
save_path = os.path.join(temp_dir, 'saved_model')
|
save_path = os.path.join(temp_dir, 'saved_model')
|
||||||
model.save(save_path, include_optimizer=False, save_format='tf')
|
model.save(
|
||||||
converter = tf.lite.TFLiteConverter.from_saved_model(save_path)
|
save_path,
|
||||||
|
include_optimizer=False,
|
||||||
if quantization_config:
|
save_format='tf',
|
||||||
converter = quantization_config.set_converter_with_quantization(
|
)
|
||||||
converter, preprocess=preprocess)
|
return convert_to_tflite_from_file(
|
||||||
|
save_path, quantization_config, supported_ops, preprocess
|
||||||
converter.target_spec.supported_ops = supported_ops
|
)
|
||||||
tflite_model = converter.convert()
|
|
||||||
return tflite_model
|
|
||||||
|
|
||||||
|
|
||||||
def save_tflite(tflite_model: bytearray, tflite_file: str) -> None:
|
def save_tflite(tflite_model: bytearray, tflite_file: str) -> None:
|
||||||
|
|
|
@ -169,6 +169,25 @@ class TextClassifier(classifier.Classifier):
|
||||||
with self._hparams.get_strategy().scope():
|
with self._hparams.get_strategy().scope():
|
||||||
return self._model.evaluate(dataset)
|
return self._model.evaluate(dataset)
|
||||||
|
|
||||||
|
def save_model(
|
||||||
|
self,
|
||||||
|
model_name: str = "saved_model",
|
||||||
|
):
|
||||||
|
"""Saves the model in SavedModel format.
|
||||||
|
|
||||||
|
For more information, see https://www.tensorflow.org/guide/saved_model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the saved model.
|
||||||
|
"""
|
||||||
|
tf.io.gfile.makedirs(self._hparams.export_dir)
|
||||||
|
saved_model_file = os.path.join(self._hparams.export_dir, model_name)
|
||||||
|
self._model.save(
|
||||||
|
saved_model_file,
|
||||||
|
include_optimizer=False,
|
||||||
|
save_format="tf",
|
||||||
|
)
|
||||||
|
|
||||||
def export_model(
|
def export_model(
|
||||||
self,
|
self,
|
||||||
model_name: str = "model.tflite",
|
model_name: str = "model.tflite",
|
||||||
|
@ -184,12 +203,16 @@ class TextClassifier(classifier.Classifier):
|
||||||
path is {self._hparams.export_dir}/{model_name}.
|
path is {self._hparams.export_dir}/{model_name}.
|
||||||
quantization_config: The configuration for model quantization.
|
quantization_config: The configuration for model quantization.
|
||||||
"""
|
"""
|
||||||
|
tf.io.gfile.makedirs(self._hparams.export_dir)
|
||||||
tflite_file = os.path.join(self._hparams.export_dir, model_name)
|
tflite_file = os.path.join(self._hparams.export_dir, model_name)
|
||||||
tf.io.gfile.makedirs(os.path.dirname(tflite_file))
|
|
||||||
metadata_file = os.path.join(self._hparams.export_dir, "metadata.json")
|
metadata_file = os.path.join(self._hparams.export_dir, "metadata.json")
|
||||||
|
|
||||||
tflite_model = model_util.convert_to_tflite(
|
self.save_model(model_name="saved_model")
|
||||||
model=self._model, quantization_config=quantization_config)
|
saved_model_file = os.path.join(self._hparams.export_dir, "saved_model")
|
||||||
|
|
||||||
|
tflite_model = model_util.convert_to_tflite_from_file(
|
||||||
|
saved_model_file, quantization_config=quantization_config
|
||||||
|
)
|
||||||
vocab_filepath = os.path.join(tempfile.mkdtemp(), "vocab.txt")
|
vocab_filepath = os.path.join(tempfile.mkdtemp(), "vocab.txt")
|
||||||
self._save_vocab(vocab_filepath)
|
self._save_vocab(vocab_filepath)
|
||||||
|
|
||||||
|
@ -494,6 +517,9 @@ class _BertClassifier(TextClassifier):
|
||||||
encoder = hub.KerasLayer(
|
encoder = hub.KerasLayer(
|
||||||
self._model_spec.get_path(),
|
self._model_spec.get_path(),
|
||||||
trainable=self._model_options.do_fine_tuning,
|
trainable=self._model_options.do_fine_tuning,
|
||||||
|
load_options=tf.saved_model.LoadOptions(
|
||||||
|
experimental_io_device="/job:localhost"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
encoder_outputs = encoder(encoder_inputs)
|
encoder_outputs = encoder(encoder_inputs)
|
||||||
pooled_output = encoder_outputs["pooled_output"]
|
pooled_output = encoder_outputs["pooled_output"]
|
||||||
|
@ -512,16 +538,18 @@ class _BertClassifier(TextClassifier):
|
||||||
pooled_output = encoder(renamed_inputs)
|
pooled_output = encoder(renamed_inputs)
|
||||||
|
|
||||||
output = tf.keras.layers.Dropout(rate=self._model_options.dropout_rate)(
|
output = tf.keras.layers.Dropout(rate=self._model_options.dropout_rate)(
|
||||||
pooled_output)
|
pooled_output
|
||||||
|
)
|
||||||
initializer = tf.keras.initializers.TruncatedNormal(
|
initializer = tf.keras.initializers.TruncatedNormal(
|
||||||
stddev=self._INITIALIZER_RANGE)
|
stddev=self._INITIALIZER_RANGE
|
||||||
|
)
|
||||||
output = tf.keras.layers.Dense(
|
output = tf.keras.layers.Dense(
|
||||||
self._num_classes,
|
self._num_classes,
|
||||||
kernel_initializer=initializer,
|
kernel_initializer=initializer,
|
||||||
name="output",
|
name="output",
|
||||||
activation="softmax",
|
activation="softmax",
|
||||||
dtype=tf.float32)(
|
dtype=tf.float32,
|
||||||
output)
|
)(output)
|
||||||
self._model = tf.keras.Model(inputs=encoder_inputs, outputs=output)
|
self._model = tf.keras.Model(inputs=encoder_inputs, outputs=output)
|
||||||
|
|
||||||
def _create_optimizer(self, train_data: text_ds.Dataset):
|
def _create_optimizer(self, train_data: text_ds.Dataset):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user