Add export_model_with_tokenizer to Text Classifier API.

PiperOrigin-RevId: 567744604
Author: MediaPipe Team, 2023-09-22 16:28:36 -07:00 (committed by Copybara-Service)
parent 9d85141227
commit 573fdad173
7 changed files with 262 additions and 5 deletions
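
For orientation, here is a hedged sketch of how the new export fits into the existing Model Maker text classifier workflow. The CSV paths, column names, and option values are illustrative placeholders, and the high-level mediapipe_model_maker calls are assumed to match the current public API; only export_model_with_tokenizer() is introduced by this commit.

from mediapipe_model_maker import text_classifier

# Load training/validation data (paths and columns are placeholders).
csv_params = text_classifier.CSVParams(text_column="text", label_column="label")
train_data = text_classifier.Dataset.from_csv(
    filename="train.csv", csv_params=csv_params)
validation_data = text_classifier.Dataset.from_csv(
    filename="dev.csv", csv_params=csv_params)

options = text_classifier.TextClassifierOptions(
    supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER)
model = text_classifier.TextClassifier.create(
    train_data, validation_data, options)

# Existing export: TFLite with MediaPipe Tasks metadata, tokenizer not in graph.
model.export_model()
# New in this commit: TFLite with a FastBertTokenizer embedded in the graph,
# written to {export_dir}/model_with_tokenizer.tflite by default.
model.export_model_with_tokenizer()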

mediapipe/model_maker/python/text/text_classifier/BUILD

@@ -93,6 +93,23 @@ py_test(
     deps = [":dataset"],
 )
 
+py_library(
+    name = "model_with_tokenizer",
+    srcs = ["model_with_tokenizer.py"],
+)
+
+py_test(
+    name = "model_with_tokenizer_test",
+    srcs = ["model_with_tokenizer_test.py"],
+    tags = ["requires-net:external"],
+    deps = [
+        ":bert_tokenizer",
+        ":model_spec",
+        ":model_with_tokenizer",
+        "//mediapipe/model_maker/python/core/utils:hub_loader",
+    ],
+)
+
 py_library(
     name = "bert_tokenizer",
     srcs = ["bert_tokenizer.py"],
@@ -145,10 +162,12 @@ py_library(
     name = "text_classifier",
     srcs = ["text_classifier.py"],
     deps = [
+        ":bert_tokenizer",
         ":dataset",
         ":hyperparameters",
         ":model_options",
         ":model_spec",
+        ":model_with_tokenizer",
         ":preprocessor",
         ":text_classifier_options",
         "//mediapipe/model_maker/python/core/data:dataset",
@@ -165,7 +184,7 @@ py_library(
 py_test(
     name = "text_classifier_test",
-    size = "large",
+    size = "enormous",
     srcs = ["text_classifier_test.py"],
     data = [
         "//mediapipe/model_maker/python/text/text_classifier/testdata",

mediapipe/model_maker/python/text/text_classifier/bert_tokenizer.py

@@ -56,6 +56,15 @@ class BertFullTokenizer(BertTokenizer):
     self._seq_len = seq_len
 
   def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
+    """Processes one input_tensor example.
+
+    Args:
+      input_tensor: A tensor with shape (1, None) of a utf-8 encoded string.
+
+    Returns:
+      A dictionary of lists all with shape (1, self._seq_len) containing the
+      keys "input_word_ids", "input_type_ids", and "input_mask".
+    """
     tokens = self._tokenizer.tokenize(input_tensor.numpy()[0].decode("utf-8"))
     tokens = tokens[0 : (self._seq_len - 2)]  # account for [CLS] and [SEP]
     tokens.insert(0, "[CLS]")
@@ -96,7 +105,18 @@ class BertFastTokenizer(BertTokenizer):
     self._sep_id = vocab.index("[SEP]")
     self._pad_id = vocab.index("[PAD]")
 
-  def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
+  def process_fn(self, input_tensor: tf.Tensor) -> Mapping[str, tf.Tensor]:
+    """Tensor implementation of the process function.
+
+    This implementation can be used within a model graph directly since it
+    takes in tensors and outputs tensors.
+
+    Args:
+      input_tensor: Input string tensor
+
+    Returns:
+      Dictionary of tf.Tensors.
+    """
     input_ids = self._tokenizer.tokenize(input_tensor).flat_values
     input_ids = input_ids[: (self._seq_len - 2)]
     input_ids = tf.concat(
@@ -112,7 +132,20 @@ class BertFastTokenizer(BertTokenizer):
     input_type_ids = tf.zeros(self._seq_len, dtype=tf.int32)
     input_mask = tf.cast(input_ids != self._pad_id, dtype=tf.int32)
     return {
-        "input_word_ids": input_ids.numpy().tolist(),
-        "input_type_ids": input_type_ids.numpy().tolist(),
-        "input_mask": input_mask.numpy().tolist(),
+        "input_word_ids": input_ids,
+        "input_type_ids": input_type_ids,
+        "input_mask": input_mask,
     }
+
+  def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
+    """Processes one input_tensor example.
+
+    Args:
+      input_tensor: A tensor with shape (1, None) of a utf-8 encoded string.
+
+    Returns:
+      A dictionary of lists all with shape (1, self._seq_len) containing the
+      keys "input_word_ids", "input_type_ids", and "input_mask".
+    """
+    result = self.process_fn(input_tensor)
+    return {k: v.numpy().tolist() for k, v in result.items()}
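
The split between process and process_fn is the crux of this change: process_fn stays inside the TF graph (tensors in, tensors out), while process keeps the old eager, Python-list behavior by delegating to it. A minimal sketch, assuming a BERT vocab.txt is available locally (the class and constructor arguments are taken from the diff above; the vocab path is a placeholder):

import tensorflow as tf
from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer

tokenizer = bert_tokenizer.BertFastTokenizer(
    vocab_file="vocab.txt", do_lower_case=True, seq_len=128)

example = tf.constant(["hello world".encode("utf-8")])

# Graph-friendly path: a dict of tf.Tensors, each of shape (seq_len,).
tensors = tokenizer.process_fn(example)

# Eager path used during preprocessing: same keys, Python lists of ints.
lists = tokenizer.process(example)
assert lists["input_word_ids"] == tensors["input_word_ids"].numpy().tolist()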

mediapipe/model_maker/python/text/text_classifier/model_with_tokenizer.py (new file)

@@ -0,0 +1,35 @@
+# Copyright 2023 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Text classifier export module library."""
+
+import tensorflow as tf
+
+
+class ModelWithTokenizer(tf.keras.Model):
+  """A model with the tokenizer included in graph for exporting to TFLite."""
+
+  def __init__(self, tokenizer, model):
+    super().__init__()
+    self._tokenizer = tokenizer
+    self._model = model
+
+  @tf.function(
+      input_signature=[
+          tf.TensorSpec(shape=[None], dtype=tf.string, name="input")
+      ]
+  )
+  def call(self, input_tensor):
+    x = self._tokenizer.process_fn(input_tensor)
+    x = {k: tf.expand_dims(v, axis=0) for k, v in x.items()}
+    x = self._model(x)
+    return x

mediapipe/model_maker/python/text/text_classifier/model_with_tokenizer_test.py (new file)

@@ -0,0 +1,105 @@
+# Copyright 2022 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import tempfile
+from unittest import mock as unittest_mock
+
+import tensorflow as tf
+import tensorflow_hub
+
+from mediapipe.model_maker.python.core.utils import hub_loader
+from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
+from mediapipe.model_maker.python.text.text_classifier import model_spec
+from mediapipe.model_maker.python.text.text_classifier import model_with_tokenizer
+
+
+class BertTokenizerTest(tf.test.TestCase):
+  _SEQ_LEN = 128
+
+  def setUp(self):
+    super().setUp()
+    # Mock tempfile.gettempdir() to be unique for each test to avoid race
+    # condition when downloading model since these tests may run in parallel.
+    mock_gettempdir = unittest_mock.patch.object(
+        tempfile,
+        "gettempdir",
+        return_value=self.create_tempdir(),
+        autospec=True,
+    )
+    self.mock_gettempdir = mock_gettempdir.start()
+    self.addCleanup(mock_gettempdir.stop)
+    self._ms = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
+    self._tokenizer = self._create_tokenizer()
+    self._model = self._create_model()
+
+  def _create_tokenizer(self):
+    vocab_file = os.path.join(
+        tensorflow_hub.resolve(self._ms.get_path()), "assets", "vocab.txt"
+    )
+    return bert_tokenizer.BertFastTokenizer(vocab_file, True, self._SEQ_LEN)
+
+  def _create_model(self):
+    encoder_inputs = dict(
+        input_word_ids=tf.keras.layers.Input(
+            shape=(self._SEQ_LEN,),
+            dtype=tf.int32,
+            name="input_word_ids",
+        ),
+        input_mask=tf.keras.layers.Input(
+            shape=(self._SEQ_LEN,),
+            dtype=tf.int32,
+            name="input_mask",
+        ),
+        input_type_ids=tf.keras.layers.Input(
+            shape=(self._SEQ_LEN,),
+            dtype=tf.int32,
+            name="input_type_ids",
+        ),
+    )
+    renamed_inputs = dict(
+        input_ids=encoder_inputs["input_word_ids"],
+        input_mask=encoder_inputs["input_mask"],
+        segment_ids=encoder_inputs["input_type_ids"],
+    )
+    encoder = hub_loader.HubKerasLayerV1V2(
+        self._ms.get_path(),
+        signature="tokens",
+        output_key="pooled_output",
+        trainable=True,
+    )
+    pooled_output = encoder(renamed_inputs)
+    output = tf.keras.layers.Dropout(rate=0.1)(pooled_output)
+    initializer = tf.keras.initializers.TruncatedNormal(stddev=0.02)
+    output = tf.keras.layers.Dense(
+        2,
+        kernel_initializer=initializer,
+        name="output",
+        activation="softmax",
+        dtype=tf.float32,
+    )(output)
+    return tf.keras.Model(inputs=encoder_inputs, outputs=output)
+
+  def test_model_with_tokenizer(self):
+    model = model_with_tokenizer.ModelWithTokenizer(
+        self._tokenizer, self._model
+    )
+    output = model(tf.constant(["Example input".encode("utf-8")]))
+    self.assertAllEqual(output.shape, (1, 2))
+    self.assertEqual(tf.reduce_sum(output), 1)
+
+
+if __name__ == "__main__":
+  tf.test.main()

mediapipe/model_maker/python/text/text_classifier/preprocessor.py

@@ -368,6 +368,10 @@ class BertClassifierPreprocessor:
         tfrecord_cache_files=tfrecord_cache_files,
     )
 
+  @property
+  def tokenizer(self) -> bert_tokenizer.BertTokenizer:
+    return self._tokenizer
+
 
 TextClassifierPreprocessor = Union[
     BertClassifierPreprocessor, AverageWordEmbeddingClassifierPreprocessor

mediapipe/model_maker/python/text/text_classifier/text_classifier.py

@@ -29,10 +29,12 @@ from mediapipe.model_maker.python.core.utils import loss_functions
 from mediapipe.model_maker.python.core.utils import metrics
 from mediapipe.model_maker.python.core.utils import model_util
 from mediapipe.model_maker.python.core.utils import quantization
+from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
 from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
 from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
 from mediapipe.model_maker.python.text.text_classifier import model_options as mo
 from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
+from mediapipe.model_maker.python.text.text_classifier import model_with_tokenizer
 from mediapipe.model_maker.python.text.text_classifier import preprocessor
 from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
 from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
@@ -620,3 +622,56 @@ class _BertClassifier(TextClassifier):
         ids_name=self._model_spec.tflite_input_name["ids"],
         mask_name=self._model_spec.tflite_input_name["mask"],
         segment_name=self._model_spec.tflite_input_name["segment_ids"])
+
+  def export_model_with_tokenizer(
+      self,
+      model_name: str = "model_with_tokenizer.tflite",
+      quantization_config: Optional[quantization.QuantizationConfig] = None,
+  ):
+    """Converts and saves the model to a TFLite file with the tokenizer.
+
+    Note that unlike the export_model method, this export method will include
+    a FastBertTokenizer in the TFLite graph. The resulting TFLite will not have
+    metadata information to use with MediaPipe Tasks, but can be run directly
+    using TFLite Inference: https://www.tensorflow.org/lite/guide/inference
+
+    For more information on the tokenizer, see:
+    https://www.tensorflow.org/text/api_docs/python/text/FastBertTokenizer
+
+    Args:
+      model_name: File name to save TFLite model with tokenizer. The full
+        export path is {self._hparams.export_dir}/{model_name}.
+      quantization_config: The configuration for model quantization.
+    """
+    tf.io.gfile.makedirs(self._hparams.export_dir)
+    tflite_file = os.path.join(self._hparams.export_dir, model_name)
+    if (
+        self._hparams.tokenizer
+        != bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER
+    ):
+      print(
+          f"WARNING: This model was trained with {self._hparams.tokenizer} "
+          "tokenizer, but the exported model with tokenizer will have a "
+          f"{bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER} "
+          "tokenizer."
+      )
+      tokenizer = bert_tokenizer.BertFastTokenizer(
+          vocab_file=self._text_preprocessor.get_vocab_file(),
+          do_lower_case=self._model_spec.do_lower_case,
+          seq_len=self._model_options.seq_len,
+      )
+    else:
+      tokenizer = self._text_preprocessor.tokenizer
+    model = model_with_tokenizer.ModelWithTokenizer(tokenizer, self._model)
+    model(tf.constant(["Example input data".encode("utf-8")]))  # build model
+    saved_model_file = os.path.join(
+        self._hparams.export_dir, "saved_model_with_tokenizer"
+    )
+    model.save(saved_model_file)
+    tflite_model = model_util.convert_to_tflite_from_file(
+        saved_model_file,
+        quantization_config=quantization_config,
+        allow_custom_ops=True,
+    )
+    model_util.save_tflite(tflite_model, tflite_file)
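
Because the exported graph contains TensorFlow Text's FastBertTokenizer ops (hence allow_custom_ops=True above), running the resulting .tflite outside MediaPipe Tasks requires registering those ops with the interpreter. A sketch following the TensorFlow Text TFLite guide, not part of this commit; the registrar and signature names should be verified against the installed tensorflow_text version, and the model path is a placeholder:

import numpy as np
import tensorflow_text as tf_text
from tensorflow.lite.python import interpreter as interpreter_lib

# Register the TF Text custom ops (FastBertTokenizer lowers to these).
interp = interpreter_lib.InterpreterWithCustomOps(
    model_path="exported_model/model_with_tokenizer.tflite",
    custom_op_registerers=tf_text.tflite_registrar.SELECT_TFTEXT_OPS,
)
interp.allocate_tensors()

# The wrapper's call signature takes a string tensor named "input"
# (see ModelWithTokenizer.call); invoke it through the default signature.
classify = interp.get_signature_runner("serving_default")
scores = classify(input=np.array([b"This is a test"]))
print(scores)  # dict with one (1, num_classes) softmax output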

mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py

@@ -149,6 +149,12 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
             output_metadata_file, self._BERT_CLASSIFIER_JSON_FILE, shallow=False
         )
     )
+    bert_classifier.export_model_with_tokenizer()
+    output_tflite_with_tokenizer_file = os.path.join(
+        options.hparams.export_dir, 'model_with_tokenizer.tflite'
+    )
+    self.assertTrue(os.path.exists(output_tflite_with_tokenizer_file))
+    self.assertGreater(os.path.getsize(output_tflite_with_tokenizer_file), 0)
 
   def test_label_mismatch(self):
     options = text_classifier.TextClassifierOptions(