Internal Changes

PiperOrigin-RevId: 557563669
Authored by MediaPipe Team on 2023-08-16 12:14:05 -07:00, committed by Copybara-Service
parent 9e45e2b6e9
commit 13bb65db96
2 changed files with 76 additions and 17 deletions


@@ -112,6 +112,39 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
     return len(train_data) // batch_size


+def convert_to_tflite_from_file(
+    saved_model_file: str,
+    quantization_config: Optional[quantization.QuantizationConfig] = None,
+    supported_ops: Tuple[tf.lite.OpsSet, ...] = (
+        tf.lite.OpsSet.TFLITE_BUILTINS,
+    ),
+    preprocess: Optional[Callable[..., Any]] = None,
+) -> bytearray:
+  """Converts a SavedModel file to TFLite format.
+
+  Args:
+    saved_model_file: Path to the SavedModel to be converted to TFLite.
+    quantization_config: Configuration for post-training quantization.
+    supported_ops: A list of supported ops in the converted TFLite file.
+    preprocess: A callable to preprocess the representative dataset for
+      quantization. The callable takes three arguments in order: feature,
+      label, and is_training.
+
+  Returns:
+    bytearray of TFLite model.
+  """
+  converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_file)
+
+  if quantization_config:
+    converter = quantization_config.set_converter_with_quantization(
+        converter, preprocess=preprocess
+    )
+
+  converter.target_spec.supported_ops = supported_ops
+  tflite_model = converter.convert()
+  return tflite_model
+
+
 def convert_to_tflite(
     model: tf.keras.Model,
     quantization_config: Optional[quantization.QuantizationConfig] = None,
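The new helper converts a SavedModel that is already on disk, so callers no longer need the in-memory Keras model to produce TFLite bytes. A minimal usage sketch; the model_util import path and the /tmp paths are illustrative assumptions, and save_tflite is the existing helper shown further down in this file:

# Hypothetical usage sketch for the new helper; paths are placeholders.
import tensorflow as tf

from mediapipe.model_maker.python.core.utils import model_util  # assumed import path

# Export any Keras model to a SavedModel directory first.
model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])
model.save('/tmp/my_saved_model', include_optimizer=False, save_format='tf')

# Convert the on-disk SavedModel to TFLite bytes without reloading it in Keras.
tflite_bytes = model_util.convert_to_tflite_from_file('/tmp/my_saved_model')
model_util.save_tflite(tflite_bytes, '/tmp/model.tflite')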
@@ -135,16 +168,14 @@ def convert_to_tflite(
   """
   with tempfile.TemporaryDirectory() as temp_dir:
     save_path = os.path.join(temp_dir, 'saved_model')
-    model.save(save_path, include_optimizer=False, save_format='tf')
-    converter = tf.lite.TFLiteConverter.from_saved_model(save_path)
-
-    if quantization_config:
-      converter = quantization_config.set_converter_with_quantization(
-          converter, preprocess=preprocess)
-
-    converter.target_spec.supported_ops = supported_ops
-    tflite_model = converter.convert()
-
-  return tflite_model
+    model.save(
+        save_path,
+        include_optimizer=False,
+        save_format='tf',
+    )
+    return convert_to_tflite_from_file(
+        save_path, quantization_config, supported_ops, preprocess
+    )


 def save_tflite(tflite_model: bytearray, tflite_file: str) -> None:
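With this change convert_to_tflite only writes the Keras model to a temporary SavedModel and delegates to convert_to_tflite_from_file, so both entry points share one conversion path. A sketch of the quantized variant of that shared path; it assumes the quantization module referenced in this diff exposes a QuantizationConfig.for_dynamic() factory:

# Sketch: post-training dynamic-range quantization through the shared helper.
from mediapipe.model_maker.python.core.utils import model_util      # assumed path
from mediapipe.model_maker.python.core.utils import quantization    # assumed path

config = quantization.QuantizationConfig.for_dynamic()  # assumed factory method
tflite_bytes = model_util.convert_to_tflite_from_file(
    '/tmp/my_saved_model', quantization_config=config
)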


@@ -169,6 +169,25 @@ class TextClassifier(classifier.Classifier):
     with self._hparams.get_strategy().scope():
       return self._model.evaluate(dataset)

+  def save_model(
+      self,
+      model_name: str = "saved_model",
+  ):
+    """Saves the model in SavedModel format.
+
+    For more information, see https://www.tensorflow.org/guide/saved_model.
+
+    Args:
+      model_name: Name of the saved model.
+    """
+    tf.io.gfile.makedirs(self._hparams.export_dir)
+    saved_model_file = os.path.join(self._hparams.export_dir, model_name)
+    self._model.save(
+        saved_model_file,
+        include_optimizer=False,
+        save_format="tf",
+    )
+
   def export_model(
       self,
       model_name: str = "model.tflite",
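The new save_model method persists the underlying Keras model under hparams.export_dir. A short sketch; 'classifier' stands in for an already trained TextClassifier instance (creation and training are omitted):

# Sketch: persist a trained TextClassifier as a SavedModel.
classifier.save_model(model_name="saved_model")
# Writes <export_dir>/saved_model in TensorFlow SavedModel format; the
# directory can later be converted with model_util.convert_to_tflite_from_file.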
@@ -184,12 +203,16 @@ class TextClassifier(classifier.Classifier):
       path is {self._hparams.export_dir}/{model_name}.
       quantization_config: The configuration for model quantization.
     """
-    tf.io.gfile.makedirs(self._hparams.export_dir)
     tflite_file = os.path.join(self._hparams.export_dir, model_name)
+    tf.io.gfile.makedirs(os.path.dirname(tflite_file))
     metadata_file = os.path.join(self._hparams.export_dir, "metadata.json")
-    tflite_model = model_util.convert_to_tflite(
-        model=self._model, quantization_config=quantization_config)
+    self.save_model(model_name="saved_model")
+    saved_model_file = os.path.join(self._hparams.export_dir, "saved_model")
+    tflite_model = model_util.convert_to_tflite_from_file(
+        saved_model_file, quantization_config=quantization_config
+    )
     vocab_filepath = os.path.join(tempfile.mkdtemp(), "vocab.txt")
     self._save_vocab(vocab_filepath)
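export_model now goes through the on-disk SavedModel: it calls save_model first and then hands that directory to convert_to_tflite_from_file. A sketch of the resulting export layout; 'classifier' is a trained TextClassifier and the export_dir value is assumed for illustration:

# Sketch of the revised export flow; export_dir assumed to be '/tmp/exported'.
classifier.export_model(model_name="model.tflite")
# Expected layout of /tmp/exported afterwards:
#   saved_model/    intermediate SavedModel written by save_model()
#   model.tflite    TFLite model converted from saved_model/
#   metadata.json   metadata written alongside the TFLite export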
@@ -494,6 +517,9 @@ class _BertClassifier(TextClassifier):
     encoder = hub.KerasLayer(
         self._model_spec.get_path(),
         trainable=self._model_options.do_fine_tuning,
+        load_options=tf.saved_model.LoadOptions(
+            experimental_io_device="/job:localhost"
+        ),
     )
     encoder_outputs = encoder(encoder_inputs)
     pooled_output = encoder_outputs["pooled_output"]
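The added load_options argument pins SavedModel file reads to the local host, which typically matters when the encoder is built inside a remote distribution strategy (for example TPU training) where file I/O would otherwise be routed through remote workers. A standalone sketch; the handle below is a placeholder, not the spec path the model maker uses:

# Sketch: load a TF-Hub encoder with SavedModel I/O pinned to the local host.
import tensorflow as tf
import tensorflow_hub as hub

encoder = hub.KerasLayer(
    "/path/to/local/bert_encoder",  # placeholder for self._model_spec.get_path()
    trainable=True,
    load_options=tf.saved_model.LoadOptions(
        experimental_io_device="/job:localhost"
    ),
)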
@@ -512,16 +538,18 @@ class _BertClassifier(TextClassifier):
     pooled_output = encoder(renamed_inputs)

     output = tf.keras.layers.Dropout(rate=self._model_options.dropout_rate)(
-        pooled_output)
+        pooled_output
+    )
     initializer = tf.keras.initializers.TruncatedNormal(
-        stddev=self._INITIALIZER_RANGE)
+        stddev=self._INITIALIZER_RANGE
+    )
     output = tf.keras.layers.Dense(
         self._num_classes,
         kernel_initializer=initializer,
         name="output",
         activation="softmax",
-        dtype=tf.float32)(
-            output)
+        dtype=tf.float32,
+    )(output)
     self._model = tf.keras.Model(inputs=encoder_inputs, outputs=output)

   def _create_optimizer(self, train_data: text_ds.Dataset):