From 13bb65db960ff1934c18a1178a225b07148473ac Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 16 Aug 2023 12:14:05 -0700
Subject: [PATCH] Internal Changes

PiperOrigin-RevId: 557563669
---
 .../python/core/utils/model_util.py          | 51 +++++++++++++++----
 .../text/text_classifier/text_classifier.py  | 42 ++++++++++++---
 2 files changed, 76 insertions(+), 17 deletions(-)

diff --git a/mediapipe/model_maker/python/core/utils/model_util.py b/mediapipe/model_maker/python/core/utils/model_util.py
index 5ca2c2b7b..fd11c60b2 100644
--- a/mediapipe/model_maker/python/core/utils/model_util.py
+++ b/mediapipe/model_maker/python/core/utils/model_util.py
@@ -112,6 +112,39 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
   return len(train_data) // batch_size
 
 
+def convert_to_tflite_from_file(
+    saved_model_file: str,
+    quantization_config: Optional[quantization.QuantizationConfig] = None,
+    supported_ops: Tuple[tf.lite.OpsSet, ...] = (
+        tf.lite.OpsSet.TFLITE_BUILTINS,
+    ),
+    preprocess: Optional[Callable[..., Any]] = None,
+) -> bytearray:
+  """Converts the input saved model file to TFLite format.
+
+  Args:
+    saved_model_file: Path to the saved model to be converted to TFLite.
+    quantization_config: Configuration for post-training quantization.
+    supported_ops: A list of supported ops in the converted TFLite file.
+    preprocess: A callable to preprocess the representative dataset for
+      quantization. The callable takes three arguments in order: feature,
+      label, and is_training.
+
+  Returns:
+    A bytearray of the TFLite model.
+  """
+  converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_file)
+
+  if quantization_config:
+    converter = quantization_config.set_converter_with_quantization(
+        converter, preprocess=preprocess
+    )
+
+  converter.target_spec.supported_ops = supported_ops
+  tflite_model = converter.convert()
+  return tflite_model
+
+
 def convert_to_tflite(
     model: tf.keras.Model,
     quantization_config: Optional[quantization.QuantizationConfig] = None,
@@ -135,16 +168,14 @@ def convert_to_tflite(
   """
   with tempfile.TemporaryDirectory() as temp_dir:
     save_path = os.path.join(temp_dir, 'saved_model')
-    model.save(save_path, include_optimizer=False, save_format='tf')
-    converter = tf.lite.TFLiteConverter.from_saved_model(save_path)
-
-    if quantization_config:
-      converter = quantization_config.set_converter_with_quantization(
-          converter, preprocess=preprocess)
-
-    converter.target_spec.supported_ops = supported_ops
-    tflite_model = converter.convert()
-  return tflite_model
+    model.save(
+        save_path,
+        include_optimizer=False,
+        save_format='tf',
+    )
+    return convert_to_tflite_from_file(
+        save_path, quantization_config, supported_ops, preprocess
+    )
 
 
 def save_tflite(tflite_model: bytearray, tflite_file: str) -> None:
diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py
index 76043aa72..752752230 100644
--- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py
+++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py
@@ -169,6 +169,25 @@ class TextClassifier(classifier.Classifier):
     with self._hparams.get_strategy().scope():
       return self._model.evaluate(dataset)
 
+  def save_model(
+      self,
+      model_name: str = "saved_model",
+  ):
+    """Saves the model in SavedModel format.
+
+    For more information, see https://www.tensorflow.org/guide/saved_model.
+
+    Args:
+      model_name: Name of the saved model.
+    """
+    tf.io.gfile.makedirs(self._hparams.export_dir)
+    saved_model_file = os.path.join(self._hparams.export_dir, model_name)
+    self._model.save(
+        saved_model_file,
+        include_optimizer=False,
+        save_format="tf",
+    )
+
   def export_model(
       self,
       model_name: str = "model.tflite",
@@ -184,12 +203,16 @@ class TextClassifier(classifier.Classifier):
         path is {self._hparams.export_dir}/{model_name}.
       quantization_config: The configuration for model quantization.
     """
+    tf.io.gfile.makedirs(self._hparams.export_dir)
     tflite_file = os.path.join(self._hparams.export_dir, model_name)
-    tf.io.gfile.makedirs(os.path.dirname(tflite_file))
     metadata_file = os.path.join(self._hparams.export_dir, "metadata.json")
 
-    tflite_model = model_util.convert_to_tflite(
-        model=self._model, quantization_config=quantization_config)
+    self.save_model(model_name="saved_model")
+    saved_model_file = os.path.join(self._hparams.export_dir, "saved_model")
+
+    tflite_model = model_util.convert_to_tflite_from_file(
+        saved_model_file, quantization_config=quantization_config
+    )
 
     vocab_filepath = os.path.join(tempfile.mkdtemp(), "vocab.txt")
     self._save_vocab(vocab_filepath)
@@ -494,6 +517,9 @@ class _BertClassifier(TextClassifier):
     encoder = hub.KerasLayer(
         self._model_spec.get_path(),
         trainable=self._model_options.do_fine_tuning,
+        load_options=tf.saved_model.LoadOptions(
+            experimental_io_device="/job:localhost"
+        ),
     )
     encoder_outputs = encoder(encoder_inputs)
     pooled_output = encoder_outputs["pooled_output"]
@@ -512,16 +538,18 @@
       pooled_output = encoder(renamed_inputs)
 
     output = tf.keras.layers.Dropout(rate=self._model_options.dropout_rate)(
-        pooled_output)
+        pooled_output
+    )
     initializer = tf.keras.initializers.TruncatedNormal(
-        stddev=self._INITIALIZER_RANGE)
+        stddev=self._INITIALIZER_RANGE
+    )
     output = tf.keras.layers.Dense(
         self._num_classes,
         kernel_initializer=initializer,
         name="output",
         activation="softmax",
-        dtype=tf.float32)(
-            output)
+        dtype=tf.float32,
+    )(output)
     self._model = tf.keras.Model(inputs=encoder_inputs, outputs=output)
 
   def _create_optimizer(self, train_data: text_ds.Dataset):
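
Usage sketch (illustrative, not part of the change): with this patch applied, exporting a TFLite model from an on-disk SavedModel is expected to look roughly like the snippet below. The paths and the choice of dynamic-range quantization are assumptions for illustration only.

    # Illustrative sketch; the paths and quantization choice are assumptions.
    from mediapipe.model_maker.python.core.utils import model_util
    from mediapipe.model_maker.python.core.utils import quantization

    # A SavedModel directory, e.g. one written by TextClassifier.save_model().
    saved_model_dir = "/tmp/exported_models/saved_model"  # hypothetical path

    # Optional post-training quantization; dynamic range needs no dataset.
    config = quantization.QuantizationConfig.for_dynamic()

    tflite_bytes = model_util.convert_to_tflite_from_file(
        saved_model_dir, quantization_config=config
    )
    model_util.save_tflite(tflite_bytes, "/tmp/exported_models/model.tflite")

Converting from a SavedModel on disk, rather than from an in-memory Keras model, is what lets export_model reuse the artifact written by the new TextClassifier.save_model helper instead of re-saving the model into a temporary directory.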