From 573fdad1732a3801075b6292e35cda5858e1d978 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 22 Sep 2023 16:28:36 -0700 Subject: [PATCH] Add export_model_with_tokenizer to Text Classifier API. PiperOrigin-RevId: 567744604 --- .../python/text/text_classifier/BUILD | 21 +++- .../text/text_classifier/bert_tokenizer.py | 41 ++++++- .../text_classifier/model_with_tokenizer.py | 35 ++++++ .../model_with_tokenizer_test.py | 105 ++++++++++++++++++ .../text/text_classifier/preprocessor.py | 4 + .../text/text_classifier/text_classifier.py | 55 +++++++++ .../text_classifier/text_classifier_test.py | 6 + 7 files changed, 262 insertions(+), 5 deletions(-) create mode 100644 mediapipe/model_maker/python/text/text_classifier/model_with_tokenizer.py create mode 100644 mediapipe/model_maker/python/text/text_classifier/model_with_tokenizer_test.py diff --git a/mediapipe/model_maker/python/text/text_classifier/BUILD b/mediapipe/model_maker/python/text/text_classifier/BUILD index 2c239e4b0..8b5721590 100644 --- a/mediapipe/model_maker/python/text/text_classifier/BUILD +++ b/mediapipe/model_maker/python/text/text_classifier/BUILD @@ -93,6 +93,23 @@ py_test( deps = [":dataset"], ) +py_library( + name = "model_with_tokenizer", + srcs = ["model_with_tokenizer.py"], +) + +py_test( + name = "model_with_tokenizer_test", + srcs = ["model_with_tokenizer_test.py"], + tags = ["requires-net:external"], + deps = [ + ":bert_tokenizer", + ":model_spec", + ":model_with_tokenizer", + "//mediapipe/model_maker/python/core/utils:hub_loader", + ], +) + py_library( name = "bert_tokenizer", srcs = ["bert_tokenizer.py"], @@ -145,10 +162,12 @@ py_library( name = "text_classifier", srcs = ["text_classifier.py"], deps = [ + ":bert_tokenizer", ":dataset", ":hyperparameters", ":model_options", ":model_spec", + ":model_with_tokenizer", ":preprocessor", ":text_classifier_options", "//mediapipe/model_maker/python/core/data:dataset", @@ -165,7 +184,7 @@ py_library( py_test( name = "text_classifier_test", - size = "large", + size = "enormous", srcs = ["text_classifier_test.py"], data = [ "//mediapipe/model_maker/python/text/text_classifier/testdata", diff --git a/mediapipe/model_maker/python/text/text_classifier/bert_tokenizer.py b/mediapipe/model_maker/python/text/text_classifier/bert_tokenizer.py index 8e92bc29c..ce4b47d4c 100644 --- a/mediapipe/model_maker/python/text/text_classifier/bert_tokenizer.py +++ b/mediapipe/model_maker/python/text/text_classifier/bert_tokenizer.py @@ -56,6 +56,15 @@ class BertFullTokenizer(BertTokenizer): self._seq_len = seq_len def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]: + """Processes one input_tensor example. + + Args: + input_tensor: A tensor with shape (1, None) of a utf-8 encoded string. + + Returns: + A dictionary of lists all with shape (1, self._seq_len) containing the + keys "input_word_ids", "input_type_ids", and "input_mask". + """ tokens = self._tokenizer.tokenize(input_tensor.numpy()[0].decode("utf-8")) tokens = tokens[0 : (self._seq_len - 2)] # account for [CLS] and [SEP] tokens.insert(0, "[CLS]") @@ -96,7 +105,18 @@ class BertFastTokenizer(BertTokenizer): self._sep_id = vocab.index("[SEP]") self._pad_id = vocab.index("[PAD]") - def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]: + def process_fn(self, input_tensor: tf.Tensor) -> Mapping[str, tf.Tensor]: + """Tensor implementation of the process function. + + This implementation can be used within a model graph directly since it + takes in tensors and outputs tensors. + + Args: + input_tensor: Input string tensor + + Returns: + Dictionary of tf.Tensors. + """ input_ids = self._tokenizer.tokenize(input_tensor).flat_values input_ids = input_ids[: (self._seq_len - 2)] input_ids = tf.concat( @@ -112,7 +132,20 @@ class BertFastTokenizer(BertTokenizer): input_type_ids = tf.zeros(self._seq_len, dtype=tf.int32) input_mask = tf.cast(input_ids != self._pad_id, dtype=tf.int32) return { - "input_word_ids": input_ids.numpy().tolist(), - "input_type_ids": input_type_ids.numpy().tolist(), - "input_mask": input_mask.numpy().tolist(), + "input_word_ids": input_ids, + "input_type_ids": input_type_ids, + "input_mask": input_mask, } + + def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]: + """Processes one input_tensor example. + + Args: + input_tensor: A tensor with shape (1, None) of a utf-8 encoded string. + + Returns: + A dictionary of lists all with shape (1, self._seq_len) containing the + keys "input_word_ids", "input_type_ids", and "input_mask". + """ + result = self.process_fn(input_tensor) + return {k: v.numpy().tolist() for k, v in result.items()} diff --git a/mediapipe/model_maker/python/text/text_classifier/model_with_tokenizer.py b/mediapipe/model_maker/python/text/text_classifier/model_with_tokenizer.py new file mode 100644 index 000000000..95328fb43 --- /dev/null +++ b/mediapipe/model_maker/python/text/text_classifier/model_with_tokenizer.py @@ -0,0 +1,35 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Text classifier export module library.""" +import tensorflow as tf + + +class ModelWithTokenizer(tf.keras.Model): + """A model with the tokenizer included in graph for exporting to TFLite.""" + + def __init__(self, tokenizer, model): + super().__init__() + self._tokenizer = tokenizer + self._model = model + + @tf.function( + input_signature=[ + tf.TensorSpec(shape=[None], dtype=tf.string, name="input") + ] + ) + def call(self, input_tensor): + x = self._tokenizer.process_fn(input_tensor) + x = {k: tf.expand_dims(v, axis=0) for k, v in x.items()} + x = self._model(x) + return x diff --git a/mediapipe/model_maker/python/text/text_classifier/model_with_tokenizer_test.py b/mediapipe/model_maker/python/text/text_classifier/model_with_tokenizer_test.py new file mode 100644 index 000000000..f6c5d2477 --- /dev/null +++ b/mediapipe/model_maker/python/text/text_classifier/model_with_tokenizer_test.py @@ -0,0 +1,105 @@ +# Copyright 2022 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +from unittest import mock as unittest_mock + +import tensorflow as tf +import tensorflow_hub + +from mediapipe.model_maker.python.core.utils import hub_loader +from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer +from mediapipe.model_maker.python.text.text_classifier import model_spec +from mediapipe.model_maker.python.text.text_classifier import model_with_tokenizer + + +class BertTokenizerTest(tf.test.TestCase): + _SEQ_LEN = 128 + + def setUp(self): + super().setUp() + # Mock tempfile.gettempdir() to be unique for each test to avoid race + # condition when downloading model since these tests may run in parallel. + mock_gettempdir = unittest_mock.patch.object( + tempfile, + "gettempdir", + return_value=self.create_tempdir(), + autospec=True, + ) + self.mock_gettempdir = mock_gettempdir.start() + self.addCleanup(mock_gettempdir.stop) + self._ms = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value() + self._tokenizer = self._create_tokenizer() + self._model = self._create_model() + + def _create_tokenizer(self): + vocab_file = os.path.join( + tensorflow_hub.resolve(self._ms.get_path()), "assets", "vocab.txt" + ) + return bert_tokenizer.BertFastTokenizer(vocab_file, True, self._SEQ_LEN) + + def _create_model(self): + encoder_inputs = dict( + input_word_ids=tf.keras.layers.Input( + shape=(self._SEQ_LEN,), + dtype=tf.int32, + name="input_word_ids", + ), + input_mask=tf.keras.layers.Input( + shape=(self._SEQ_LEN,), + dtype=tf.int32, + name="input_mask", + ), + input_type_ids=tf.keras.layers.Input( + shape=(self._SEQ_LEN,), + dtype=tf.int32, + name="input_type_ids", + ), + ) + renamed_inputs = dict( + input_ids=encoder_inputs["input_word_ids"], + input_mask=encoder_inputs["input_mask"], + segment_ids=encoder_inputs["input_type_ids"], + ) + encoder = hub_loader.HubKerasLayerV1V2( + self._ms.get_path(), + signature="tokens", + output_key="pooled_output", + trainable=True, + ) + pooled_output = encoder(renamed_inputs) + + output = tf.keras.layers.Dropout(rate=0.1)(pooled_output) + initializer = tf.keras.initializers.TruncatedNormal(stddev=0.02) + output = tf.keras.layers.Dense( + 2, + kernel_initializer=initializer, + name="output", + activation="softmax", + dtype=tf.float32, + )(output) + return tf.keras.Model(inputs=encoder_inputs, outputs=output) + + def test_model_with_tokenizer(self): + model = model_with_tokenizer.ModelWithTokenizer( + self._tokenizer, self._model + ) + output = model(tf.constant(["Example input".encode("utf-8")])) + self.assertAllEqual(output.shape, (1, 2)) + self.assertEqual(tf.reduce_sum(output), 1) + + +if __name__ == "__main__": + tf.test.main() diff --git a/mediapipe/model_maker/python/text/text_classifier/preprocessor.py b/mediapipe/model_maker/python/text/text_classifier/preprocessor.py index 5954f4ca3..24130f6f8 100644 --- a/mediapipe/model_maker/python/text/text_classifier/preprocessor.py +++ b/mediapipe/model_maker/python/text/text_classifier/preprocessor.py @@ -368,6 +368,10 @@ class BertClassifierPreprocessor: tfrecord_cache_files=tfrecord_cache_files, ) + @property + def tokenizer(self) -> bert_tokenizer.BertTokenizer: + return self._tokenizer + TextClassifierPreprocessor = Union[ BertClassifierPreprocessor, AverageWordEmbeddingClassifierPreprocessor diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py index c067a4ed6..623edbc38 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py @@ -29,10 +29,12 @@ from mediapipe.model_maker.python.core.utils import loss_functions from mediapipe.model_maker.python.core.utils import metrics from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.core.utils import quantization +from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp from mediapipe.model_maker.python.text.text_classifier import model_options as mo from mediapipe.model_maker.python.text.text_classifier import model_spec as ms +from mediapipe.model_maker.python.text.text_classifier import model_with_tokenizer from mediapipe.model_maker.python.text.text_classifier import preprocessor from mediapipe.model_maker.python.text.text_classifier import text_classifier_options from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer @@ -620,3 +622,56 @@ class _BertClassifier(TextClassifier): ids_name=self._model_spec.tflite_input_name["ids"], mask_name=self._model_spec.tflite_input_name["mask"], segment_name=self._model_spec.tflite_input_name["segment_ids"]) + + def export_model_with_tokenizer( + self, + model_name: str = "model_with_tokenizer.tflite", + quantization_config: Optional[quantization.QuantizationConfig] = None, + ): + """Converts and saves the model to a TFLite file with the tokenizer. + + Note that unlike the export_model method, this export method will include + a FastBertTokenizer in the TFLite graph. The resulting TFLite will not have + metadata information to use with MediaPipe Tasks, but can be run directly + using TFLite Inference: https://www.tensorflow.org/lite/guide/inference + + For more information on the tokenizer, see: + https://www.tensorflow.org/text/api_docs/python/text/FastBertTokenizer + + Args: + model_name: File name to save TFLite model with tokenizer. The full export + path is {self._hparams.export_dir}/{model_name}. + quantization_config: The configuration for model quantization. + """ + tf.io.gfile.makedirs(self._hparams.export_dir) + tflite_file = os.path.join(self._hparams.export_dir, model_name) + if ( + self._hparams.tokenizer + != bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER + ): + print( + f"WARNING: This model was trained with {self._hparams.tokenizer} " + "tokenizer, but the exported model with tokenizer will have a " + f"{bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER} " + "tokenizer." + ) + tokenizer = bert_tokenizer.BertFastTokenizer( + vocab_file=self._text_preprocessor.get_vocab_file(), + do_lower_case=self._model_spec.do_lower_case, + seq_len=self._model_options.seq_len, + ) + else: + tokenizer = self._text_preprocessor.tokenizer + + model = model_with_tokenizer.ModelWithTokenizer(tokenizer, self._model) + model(tf.constant(["Example input data".encode("utf-8")])) # build model + saved_model_file = os.path.join( + self._hparams.export_dir, "saved_model_with_tokenizer" + ) + model.save(saved_model_file) + tflite_model = model_util.convert_to_tflite_from_file( + saved_model_file, + quantization_config=quantization_config, + allow_custom_ops=True, + ) + model_util.save_tflite(tflite_model, tflite_file) diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py index 122182ddd..fdc2613a9 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py @@ -149,6 +149,12 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase): output_metadata_file, self._BERT_CLASSIFIER_JSON_FILE, shallow=False ) ) + bert_classifier.export_model_with_tokenizer() + output_tflite_with_tokenizer_file = os.path.join( + options.hparams.export_dir, 'model_with_tokenizer.tflite' + ) + self.assertTrue(os.path.exists(output_tflite_with_tokenizer_file)) + self.assertGreater(os.path.getsize(output_tflite_with_tokenizer_file), 0) def test_label_mismatch(self): options = text_classifier.TextClassifierOptions(