Add export_model_with_tokenizer to Text Classifier API.
PiperOrigin-RevId: 567744604
This commit is contained in:
		
							parent
							
								
									9d85141227
								
							
						
					
					
						commit
						573fdad173
					
				| 
						 | 
					@ -93,6 +93,23 @@ py_test(
 | 
				
			||||||
    deps = [":dataset"],
 | 
					    deps = [":dataset"],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					py_library(
 | 
				
			||||||
 | 
					    name = "model_with_tokenizer",
 | 
				
			||||||
 | 
					    srcs = ["model_with_tokenizer.py"],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					py_test(
 | 
				
			||||||
 | 
					    name = "model_with_tokenizer_test",
 | 
				
			||||||
 | 
					    srcs = ["model_with_tokenizer_test.py"],
 | 
				
			||||||
 | 
					    tags = ["requires-net:external"],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":bert_tokenizer",
 | 
				
			||||||
 | 
					        ":model_spec",
 | 
				
			||||||
 | 
					        ":model_with_tokenizer",
 | 
				
			||||||
 | 
					        "//mediapipe/model_maker/python/core/utils:hub_loader",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
py_library(
 | 
					py_library(
 | 
				
			||||||
    name = "bert_tokenizer",
 | 
					    name = "bert_tokenizer",
 | 
				
			||||||
    srcs = ["bert_tokenizer.py"],
 | 
					    srcs = ["bert_tokenizer.py"],
 | 
				
			||||||
| 
						 | 
					@ -145,10 +162,12 @@ py_library(
 | 
				
			||||||
    name = "text_classifier",
 | 
					    name = "text_classifier",
 | 
				
			||||||
    srcs = ["text_classifier.py"],
 | 
					    srcs = ["text_classifier.py"],
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":bert_tokenizer",
 | 
				
			||||||
        ":dataset",
 | 
					        ":dataset",
 | 
				
			||||||
        ":hyperparameters",
 | 
					        ":hyperparameters",
 | 
				
			||||||
        ":model_options",
 | 
					        ":model_options",
 | 
				
			||||||
        ":model_spec",
 | 
					        ":model_spec",
 | 
				
			||||||
 | 
					        ":model_with_tokenizer",
 | 
				
			||||||
        ":preprocessor",
 | 
					        ":preprocessor",
 | 
				
			||||||
        ":text_classifier_options",
 | 
					        ":text_classifier_options",
 | 
				
			||||||
        "//mediapipe/model_maker/python/core/data:dataset",
 | 
					        "//mediapipe/model_maker/python/core/data:dataset",
 | 
				
			||||||
| 
						 | 
					@ -165,7 +184,7 @@ py_library(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
py_test(
 | 
					py_test(
 | 
				
			||||||
    name = "text_classifier_test",
 | 
					    name = "text_classifier_test",
 | 
				
			||||||
    size = "large",
 | 
					    size = "enormous",
 | 
				
			||||||
    srcs = ["text_classifier_test.py"],
 | 
					    srcs = ["text_classifier_test.py"],
 | 
				
			||||||
    data = [
 | 
					    data = [
 | 
				
			||||||
        "//mediapipe/model_maker/python/text/text_classifier/testdata",
 | 
					        "//mediapipe/model_maker/python/text/text_classifier/testdata",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -56,6 +56,15 @@ class BertFullTokenizer(BertTokenizer):
 | 
				
			||||||
    self._seq_len = seq_len
 | 
					    self._seq_len = seq_len
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
 | 
					  def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
 | 
				
			||||||
 | 
					    """Processes one input_tensor example.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      input_tensor: A tensor with shape (1, None) of a utf-8 encoded string.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      A dictionary of lists all with shape (1, self._seq_len) containing the
 | 
				
			||||||
 | 
					        keys "input_word_ids", "input_type_ids", and "input_mask".
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
    tokens = self._tokenizer.tokenize(input_tensor.numpy()[0].decode("utf-8"))
 | 
					    tokens = self._tokenizer.tokenize(input_tensor.numpy()[0].decode("utf-8"))
 | 
				
			||||||
    tokens = tokens[0 : (self._seq_len - 2)]  # account for [CLS] and [SEP]
 | 
					    tokens = tokens[0 : (self._seq_len - 2)]  # account for [CLS] and [SEP]
 | 
				
			||||||
    tokens.insert(0, "[CLS]")
 | 
					    tokens.insert(0, "[CLS]")
 | 
				
			||||||
| 
						 | 
					@ -96,7 +105,18 @@ class BertFastTokenizer(BertTokenizer):
 | 
				
			||||||
    self._sep_id = vocab.index("[SEP]")
 | 
					    self._sep_id = vocab.index("[SEP]")
 | 
				
			||||||
    self._pad_id = vocab.index("[PAD]")
 | 
					    self._pad_id = vocab.index("[PAD]")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
 | 
					  def process_fn(self, input_tensor: tf.Tensor) -> Mapping[str, tf.Tensor]:
 | 
				
			||||||
 | 
					    """Tensor implementation of the process function.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    This implementation can be used within a model graph directly since it
 | 
				
			||||||
 | 
					    takes in tensors and outputs tensors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      input_tensor: Input string tensor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      Dictionary of tf.Tensors.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
    input_ids = self._tokenizer.tokenize(input_tensor).flat_values
 | 
					    input_ids = self._tokenizer.tokenize(input_tensor).flat_values
 | 
				
			||||||
    input_ids = input_ids[: (self._seq_len - 2)]
 | 
					    input_ids = input_ids[: (self._seq_len - 2)]
 | 
				
			||||||
    input_ids = tf.concat(
 | 
					    input_ids = tf.concat(
 | 
				
			||||||
| 
						 | 
					@ -112,7 +132,20 @@ class BertFastTokenizer(BertTokenizer):
 | 
				
			||||||
    input_type_ids = tf.zeros(self._seq_len, dtype=tf.int32)
 | 
					    input_type_ids = tf.zeros(self._seq_len, dtype=tf.int32)
 | 
				
			||||||
    input_mask = tf.cast(input_ids != self._pad_id, dtype=tf.int32)
 | 
					    input_mask = tf.cast(input_ids != self._pad_id, dtype=tf.int32)
 | 
				
			||||||
    return {
 | 
					    return {
 | 
				
			||||||
        "input_word_ids": input_ids.numpy().tolist(),
 | 
					        "input_word_ids": input_ids,
 | 
				
			||||||
        "input_type_ids": input_type_ids.numpy().tolist(),
 | 
					        "input_type_ids": input_type_ids,
 | 
				
			||||||
        "input_mask": input_mask.numpy().tolist(),
 | 
					        "input_mask": input_mask,
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
 | 
				
			||||||
 | 
					    """Processes one input_tensor example.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      input_tensor: A tensor with shape (1, None) of a utf-8 encoded string.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      A dictionary of lists all with shape (1, self._seq_len) containing the
 | 
				
			||||||
 | 
					        keys "input_word_ids", "input_type_ids", and "input_mask".
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    result = self.process_fn(input_tensor)
 | 
				
			||||||
 | 
					    return {k: v.numpy().tolist() for k, v in result.items()}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,35 @@
 | 
				
			||||||
 | 
					# Copyright 2023 The MediaPipe Authors.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					"""Text classifier export module library."""
 | 
				
			||||||
 | 
					import tensorflow as tf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ModelWithTokenizer(tf.keras.Model):
 | 
				
			||||||
 | 
					  """A model with the tokenizer included in graph for exporting to TFLite."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def __init__(self, tokenizer, model):
 | 
				
			||||||
 | 
					    super().__init__()
 | 
				
			||||||
 | 
					    self._tokenizer = tokenizer
 | 
				
			||||||
 | 
					    self._model = model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @tf.function(
 | 
				
			||||||
 | 
					      input_signature=[
 | 
				
			||||||
 | 
					          tf.TensorSpec(shape=[None], dtype=tf.string, name="input")
 | 
				
			||||||
 | 
					      ]
 | 
				
			||||||
 | 
					  )
 | 
				
			||||||
 | 
					  def call(self, input_tensor):
 | 
				
			||||||
 | 
					    x = self._tokenizer.process_fn(input_tensor)
 | 
				
			||||||
 | 
					    x = {k: tf.expand_dims(v, axis=0) for k, v in x.items()}
 | 
				
			||||||
 | 
					    x = self._model(x)
 | 
				
			||||||
 | 
					    return x
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,105 @@
 | 
				
			||||||
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import tempfile
 | 
				
			||||||
 | 
					from unittest import mock as unittest_mock
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import tensorflow as tf
 | 
				
			||||||
 | 
					import tensorflow_hub
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.core.utils import hub_loader
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.text.text_classifier import model_spec
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.text.text_classifier import model_with_tokenizer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class BertTokenizerTest(tf.test.TestCase):
 | 
				
			||||||
 | 
					  _SEQ_LEN = 128
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def setUp(self):
 | 
				
			||||||
 | 
					    super().setUp()
 | 
				
			||||||
 | 
					    # Mock tempfile.gettempdir() to be unique for each test to avoid race
 | 
				
			||||||
 | 
					    # condition when downloading model since these tests may run in parallel.
 | 
				
			||||||
 | 
					    mock_gettempdir = unittest_mock.patch.object(
 | 
				
			||||||
 | 
					        tempfile,
 | 
				
			||||||
 | 
					        "gettempdir",
 | 
				
			||||||
 | 
					        return_value=self.create_tempdir(),
 | 
				
			||||||
 | 
					        autospec=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    self.mock_gettempdir = mock_gettempdir.start()
 | 
				
			||||||
 | 
					    self.addCleanup(mock_gettempdir.stop)
 | 
				
			||||||
 | 
					    self._ms = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
 | 
				
			||||||
 | 
					    self._tokenizer = self._create_tokenizer()
 | 
				
			||||||
 | 
					    self._model = self._create_model()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _create_tokenizer(self):
 | 
				
			||||||
 | 
					    vocab_file = os.path.join(
 | 
				
			||||||
 | 
					        tensorflow_hub.resolve(self._ms.get_path()), "assets", "vocab.txt"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    return bert_tokenizer.BertFastTokenizer(vocab_file, True, self._SEQ_LEN)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _create_model(self):
 | 
				
			||||||
 | 
					    encoder_inputs = dict(
 | 
				
			||||||
 | 
					        input_word_ids=tf.keras.layers.Input(
 | 
				
			||||||
 | 
					            shape=(self._SEQ_LEN,),
 | 
				
			||||||
 | 
					            dtype=tf.int32,
 | 
				
			||||||
 | 
					            name="input_word_ids",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        input_mask=tf.keras.layers.Input(
 | 
				
			||||||
 | 
					            shape=(self._SEQ_LEN,),
 | 
				
			||||||
 | 
					            dtype=tf.int32,
 | 
				
			||||||
 | 
					            name="input_mask",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        input_type_ids=tf.keras.layers.Input(
 | 
				
			||||||
 | 
					            shape=(self._SEQ_LEN,),
 | 
				
			||||||
 | 
					            dtype=tf.int32,
 | 
				
			||||||
 | 
					            name="input_type_ids",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    renamed_inputs = dict(
 | 
				
			||||||
 | 
					        input_ids=encoder_inputs["input_word_ids"],
 | 
				
			||||||
 | 
					        input_mask=encoder_inputs["input_mask"],
 | 
				
			||||||
 | 
					        segment_ids=encoder_inputs["input_type_ids"],
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    encoder = hub_loader.HubKerasLayerV1V2(
 | 
				
			||||||
 | 
					        self._ms.get_path(),
 | 
				
			||||||
 | 
					        signature="tokens",
 | 
				
			||||||
 | 
					        output_key="pooled_output",
 | 
				
			||||||
 | 
					        trainable=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    pooled_output = encoder(renamed_inputs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    output = tf.keras.layers.Dropout(rate=0.1)(pooled_output)
 | 
				
			||||||
 | 
					    initializer = tf.keras.initializers.TruncatedNormal(stddev=0.02)
 | 
				
			||||||
 | 
					    output = tf.keras.layers.Dense(
 | 
				
			||||||
 | 
					        2,
 | 
				
			||||||
 | 
					        kernel_initializer=initializer,
 | 
				
			||||||
 | 
					        name="output",
 | 
				
			||||||
 | 
					        activation="softmax",
 | 
				
			||||||
 | 
					        dtype=tf.float32,
 | 
				
			||||||
 | 
					    )(output)
 | 
				
			||||||
 | 
					    return tf.keras.Model(inputs=encoder_inputs, outputs=output)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_model_with_tokenizer(self):
 | 
				
			||||||
 | 
					    model = model_with_tokenizer.ModelWithTokenizer(
 | 
				
			||||||
 | 
					        self._tokenizer, self._model
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    output = model(tf.constant(["Example input".encode("utf-8")]))
 | 
				
			||||||
 | 
					    self.assertAllEqual(output.shape, (1, 2))
 | 
				
			||||||
 | 
					    self.assertEqual(tf.reduce_sum(output), 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					  tf.test.main()
 | 
				
			||||||
| 
						 | 
					@ -368,6 +368,10 @@ class BertClassifierPreprocessor:
 | 
				
			||||||
        tfrecord_cache_files=tfrecord_cache_files,
 | 
					        tfrecord_cache_files=tfrecord_cache_files,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @property
 | 
				
			||||||
 | 
					  def tokenizer(self) -> bert_tokenizer.BertTokenizer:
 | 
				
			||||||
 | 
					    return self._tokenizer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TextClassifierPreprocessor = Union[
 | 
					TextClassifierPreprocessor = Union[
 | 
				
			||||||
    BertClassifierPreprocessor, AverageWordEmbeddingClassifierPreprocessor
 | 
					    BertClassifierPreprocessor, AverageWordEmbeddingClassifierPreprocessor
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -29,10 +29,12 @@ from mediapipe.model_maker.python.core.utils import loss_functions
 | 
				
			||||||
from mediapipe.model_maker.python.core.utils import metrics
 | 
					from mediapipe.model_maker.python.core.utils import metrics
 | 
				
			||||||
from mediapipe.model_maker.python.core.utils import model_util
 | 
					from mediapipe.model_maker.python.core.utils import model_util
 | 
				
			||||||
from mediapipe.model_maker.python.core.utils import quantization
 | 
					from mediapipe.model_maker.python.core.utils import quantization
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
 | 
					from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
 | 
					from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
 | 
					from mediapipe.model_maker.python.text.text_classifier import model_options as mo
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
 | 
					from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.text.text_classifier import model_with_tokenizer
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import preprocessor
 | 
					from mediapipe.model_maker.python.text.text_classifier import preprocessor
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
 | 
					from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
 | 
				
			||||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
 | 
					from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
 | 
				
			||||||
| 
						 | 
					@ -620,3 +622,56 @@ class _BertClassifier(TextClassifier):
 | 
				
			||||||
        ids_name=self._model_spec.tflite_input_name["ids"],
 | 
					        ids_name=self._model_spec.tflite_input_name["ids"],
 | 
				
			||||||
        mask_name=self._model_spec.tflite_input_name["mask"],
 | 
					        mask_name=self._model_spec.tflite_input_name["mask"],
 | 
				
			||||||
        segment_name=self._model_spec.tflite_input_name["segment_ids"])
 | 
					        segment_name=self._model_spec.tflite_input_name["segment_ids"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def export_model_with_tokenizer(
 | 
				
			||||||
 | 
					      self,
 | 
				
			||||||
 | 
					      model_name: str = "model_with_tokenizer.tflite",
 | 
				
			||||||
 | 
					      quantization_config: Optional[quantization.QuantizationConfig] = None,
 | 
				
			||||||
 | 
					  ):
 | 
				
			||||||
 | 
					    """Converts and saves the model to a TFLite file with the tokenizer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Note that unlike the export_model method, this export method will include
 | 
				
			||||||
 | 
					    a FastBertTokenizer in the TFLite graph. The resulting TFLite will not have
 | 
				
			||||||
 | 
					    metadata information to use with MediaPipe Tasks, but can be run directly
 | 
				
			||||||
 | 
					    using TFLite Inference: https://www.tensorflow.org/lite/guide/inference
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    For more information on the tokenizer, see:
 | 
				
			||||||
 | 
					      https://www.tensorflow.org/text/api_docs/python/text/FastBertTokenizer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      model_name: File name to save TFLite model with tokenizer. The full export
 | 
				
			||||||
 | 
					        path is {self._hparams.export_dir}/{model_name}.
 | 
				
			||||||
 | 
					      quantization_config: The configuration for model quantization.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    tf.io.gfile.makedirs(self._hparams.export_dir)
 | 
				
			||||||
 | 
					    tflite_file = os.path.join(self._hparams.export_dir, model_name)
 | 
				
			||||||
 | 
					    if (
 | 
				
			||||||
 | 
					        self._hparams.tokenizer
 | 
				
			||||||
 | 
					        != bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					      print(
 | 
				
			||||||
 | 
					          f"WARNING: This model was trained with {self._hparams.tokenizer} "
 | 
				
			||||||
 | 
					          "tokenizer, but the exported model with tokenizer will have a "
 | 
				
			||||||
 | 
					          f"{bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER} "
 | 
				
			||||||
 | 
					          "tokenizer."
 | 
				
			||||||
 | 
					      )
 | 
				
			||||||
 | 
					      tokenizer = bert_tokenizer.BertFastTokenizer(
 | 
				
			||||||
 | 
					          vocab_file=self._text_preprocessor.get_vocab_file(),
 | 
				
			||||||
 | 
					          do_lower_case=self._model_spec.do_lower_case,
 | 
				
			||||||
 | 
					          seq_len=self._model_options.seq_len,
 | 
				
			||||||
 | 
					      )
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					      tokenizer = self._text_preprocessor.tokenizer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model = model_with_tokenizer.ModelWithTokenizer(tokenizer, self._model)
 | 
				
			||||||
 | 
					    model(tf.constant(["Example input data".encode("utf-8")]))  # build model
 | 
				
			||||||
 | 
					    saved_model_file = os.path.join(
 | 
				
			||||||
 | 
					        self._hparams.export_dir, "saved_model_with_tokenizer"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    model.save(saved_model_file)
 | 
				
			||||||
 | 
					    tflite_model = model_util.convert_to_tflite_from_file(
 | 
				
			||||||
 | 
					        saved_model_file,
 | 
				
			||||||
 | 
					        quantization_config=quantization_config,
 | 
				
			||||||
 | 
					        allow_custom_ops=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    model_util.save_tflite(tflite_model, tflite_file)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -149,6 +149,12 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
 | 
				
			||||||
            output_metadata_file, self._BERT_CLASSIFIER_JSON_FILE, shallow=False
 | 
					            output_metadata_file, self._BERT_CLASSIFIER_JSON_FILE, shallow=False
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					    bert_classifier.export_model_with_tokenizer()
 | 
				
			||||||
 | 
					    output_tflite_with_tokenizer_file = os.path.join(
 | 
				
			||||||
 | 
					        options.hparams.export_dir, 'model_with_tokenizer.tflite'
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    self.assertTrue(os.path.exists(output_tflite_with_tokenizer_file))
 | 
				
			||||||
 | 
					    self.assertGreater(os.path.getsize(output_tflite_with_tokenizer_file), 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def test_label_mismatch(self):
 | 
					  def test_label_mismatch(self):
 | 
				
			||||||
    options = text_classifier.TextClassifierOptions(
 | 
					    options = text_classifier.TextClassifierOptions(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user