Add export_model_with_tokenizer to Text Classifier API.
PiperOrigin-RevId: 567744604
This commit is contained in:
parent
9d85141227
commit
573fdad173
|
@ -93,6 +93,23 @@ py_test(
|
|||
deps = [":dataset"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "model_with_tokenizer",
|
||||
srcs = ["model_with_tokenizer.py"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "model_with_tokenizer_test",
|
||||
srcs = ["model_with_tokenizer_test.py"],
|
||||
tags = ["requires-net:external"],
|
||||
deps = [
|
||||
":bert_tokenizer",
|
||||
":model_spec",
|
||||
":model_with_tokenizer",
|
||||
"//mediapipe/model_maker/python/core/utils:hub_loader",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "bert_tokenizer",
|
||||
srcs = ["bert_tokenizer.py"],
|
||||
|
@ -145,10 +162,12 @@ py_library(
|
|||
name = "text_classifier",
|
||||
srcs = ["text_classifier.py"],
|
||||
deps = [
|
||||
":bert_tokenizer",
|
||||
":dataset",
|
||||
":hyperparameters",
|
||||
":model_options",
|
||||
":model_spec",
|
||||
":model_with_tokenizer",
|
||||
":preprocessor",
|
||||
":text_classifier_options",
|
||||
"//mediapipe/model_maker/python/core/data:dataset",
|
||||
|
@ -165,7 +184,7 @@ py_library(
|
|||
|
||||
py_test(
|
||||
name = "text_classifier_test",
|
||||
size = "large",
|
||||
size = "enormous",
|
||||
srcs = ["text_classifier_test.py"],
|
||||
data = [
|
||||
"//mediapipe/model_maker/python/text/text_classifier/testdata",
|
||||
|
|
|
@ -56,6 +56,15 @@ class BertFullTokenizer(BertTokenizer):
|
|||
self._seq_len = seq_len
|
||||
|
||||
def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
|
||||
"""Processes one input_tensor example.
|
||||
|
||||
Args:
|
||||
input_tensor: A tensor with shape (1, None) of a utf-8 encoded string.
|
||||
|
||||
Returns:
|
||||
A dictionary of lists all with shape (1, self._seq_len) containing the
|
||||
keys "input_word_ids", "input_type_ids", and "input_mask".
|
||||
"""
|
||||
tokens = self._tokenizer.tokenize(input_tensor.numpy()[0].decode("utf-8"))
|
||||
tokens = tokens[0 : (self._seq_len - 2)] # account for [CLS] and [SEP]
|
||||
tokens.insert(0, "[CLS]")
|
||||
|
@ -96,7 +105,18 @@ class BertFastTokenizer(BertTokenizer):
|
|||
self._sep_id = vocab.index("[SEP]")
|
||||
self._pad_id = vocab.index("[PAD]")
|
||||
|
||||
def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
|
||||
def process_fn(self, input_tensor: tf.Tensor) -> Mapping[str, tf.Tensor]:
|
||||
"""Tensor implementation of the process function.
|
||||
|
||||
This implementation can be used within a model graph directly since it
|
||||
takes in tensors and outputs tensors.
|
||||
|
||||
Args:
|
||||
input_tensor: Input string tensor
|
||||
|
||||
Returns:
|
||||
Dictionary of tf.Tensors.
|
||||
"""
|
||||
input_ids = self._tokenizer.tokenize(input_tensor).flat_values
|
||||
input_ids = input_ids[: (self._seq_len - 2)]
|
||||
input_ids = tf.concat(
|
||||
|
@ -112,7 +132,20 @@ class BertFastTokenizer(BertTokenizer):
|
|||
input_type_ids = tf.zeros(self._seq_len, dtype=tf.int32)
|
||||
input_mask = tf.cast(input_ids != self._pad_id, dtype=tf.int32)
|
||||
return {
|
||||
"input_word_ids": input_ids.numpy().tolist(),
|
||||
"input_type_ids": input_type_ids.numpy().tolist(),
|
||||
"input_mask": input_mask.numpy().tolist(),
|
||||
"input_word_ids": input_ids,
|
||||
"input_type_ids": input_type_ids,
|
||||
"input_mask": input_mask,
|
||||
}
|
||||
|
||||
def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
|
||||
"""Processes one input_tensor example.
|
||||
|
||||
Args:
|
||||
input_tensor: A tensor with shape (1, None) of a utf-8 encoded string.
|
||||
|
||||
Returns:
|
||||
A dictionary of lists all with shape (1, self._seq_len) containing the
|
||||
keys "input_word_ids", "input_type_ids", and "input_mask".
|
||||
"""
|
||||
result = self.process_fn(input_tensor)
|
||||
return {k: v.numpy().tolist() for k, v in result.items()}
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Text classifier export module library."""
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class ModelWithTokenizer(tf.keras.Model):
|
||||
"""A model with the tokenizer included in graph for exporting to TFLite."""
|
||||
|
||||
def __init__(self, tokenizer, model):
|
||||
super().__init__()
|
||||
self._tokenizer = tokenizer
|
||||
self._model = model
|
||||
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
tf.TensorSpec(shape=[None], dtype=tf.string, name="input")
|
||||
]
|
||||
)
|
||||
def call(self, input_tensor):
|
||||
x = self._tokenizer.process_fn(input_tensor)
|
||||
x = {k: tf.expand_dims(v, axis=0) for k, v in x.items()}
|
||||
x = self._model(x)
|
||||
return x
|
|
@ -0,0 +1,105 @@
|
|||
# Copyright 2022 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from unittest import mock as unittest_mock
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow_hub
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import hub_loader
|
||||
from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_with_tokenizer
|
||||
|
||||
|
||||
class BertTokenizerTest(tf.test.TestCase):
|
||||
_SEQ_LEN = 128
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||
# condition when downloading model since these tests may run in parallel.
|
||||
mock_gettempdir = unittest_mock.patch.object(
|
||||
tempfile,
|
||||
"gettempdir",
|
||||
return_value=self.create_tempdir(),
|
||||
autospec=True,
|
||||
)
|
||||
self.mock_gettempdir = mock_gettempdir.start()
|
||||
self.addCleanup(mock_gettempdir.stop)
|
||||
self._ms = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||
self._tokenizer = self._create_tokenizer()
|
||||
self._model = self._create_model()
|
||||
|
||||
def _create_tokenizer(self):
|
||||
vocab_file = os.path.join(
|
||||
tensorflow_hub.resolve(self._ms.get_path()), "assets", "vocab.txt"
|
||||
)
|
||||
return bert_tokenizer.BertFastTokenizer(vocab_file, True, self._SEQ_LEN)
|
||||
|
||||
def _create_model(self):
|
||||
encoder_inputs = dict(
|
||||
input_word_ids=tf.keras.layers.Input(
|
||||
shape=(self._SEQ_LEN,),
|
||||
dtype=tf.int32,
|
||||
name="input_word_ids",
|
||||
),
|
||||
input_mask=tf.keras.layers.Input(
|
||||
shape=(self._SEQ_LEN,),
|
||||
dtype=tf.int32,
|
||||
name="input_mask",
|
||||
),
|
||||
input_type_ids=tf.keras.layers.Input(
|
||||
shape=(self._SEQ_LEN,),
|
||||
dtype=tf.int32,
|
||||
name="input_type_ids",
|
||||
),
|
||||
)
|
||||
renamed_inputs = dict(
|
||||
input_ids=encoder_inputs["input_word_ids"],
|
||||
input_mask=encoder_inputs["input_mask"],
|
||||
segment_ids=encoder_inputs["input_type_ids"],
|
||||
)
|
||||
encoder = hub_loader.HubKerasLayerV1V2(
|
||||
self._ms.get_path(),
|
||||
signature="tokens",
|
||||
output_key="pooled_output",
|
||||
trainable=True,
|
||||
)
|
||||
pooled_output = encoder(renamed_inputs)
|
||||
|
||||
output = tf.keras.layers.Dropout(rate=0.1)(pooled_output)
|
||||
initializer = tf.keras.initializers.TruncatedNormal(stddev=0.02)
|
||||
output = tf.keras.layers.Dense(
|
||||
2,
|
||||
kernel_initializer=initializer,
|
||||
name="output",
|
||||
activation="softmax",
|
||||
dtype=tf.float32,
|
||||
)(output)
|
||||
return tf.keras.Model(inputs=encoder_inputs, outputs=output)
|
||||
|
||||
def test_model_with_tokenizer(self):
|
||||
model = model_with_tokenizer.ModelWithTokenizer(
|
||||
self._tokenizer, self._model
|
||||
)
|
||||
output = model(tf.constant(["Example input".encode("utf-8")]))
|
||||
self.assertAllEqual(output.shape, (1, 2))
|
||||
self.assertEqual(tf.reduce_sum(output), 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
|
@ -368,6 +368,10 @@ class BertClassifierPreprocessor:
|
|||
tfrecord_cache_files=tfrecord_cache_files,
|
||||
)
|
||||
|
||||
@property
|
||||
def tokenizer(self) -> bert_tokenizer.BertTokenizer:
|
||||
return self._tokenizer
|
||||
|
||||
|
||||
TextClassifierPreprocessor = Union[
|
||||
BertClassifierPreprocessor, AverageWordEmbeddingClassifierPreprocessor
|
||||
|
|
|
@ -29,10 +29,12 @@ from mediapipe.model_maker.python.core.utils import loss_functions
|
|||
from mediapipe.model_maker.python.core.utils import metrics
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
|
||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
|
||||
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_with_tokenizer
|
||||
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
|
@ -620,3 +622,56 @@ class _BertClassifier(TextClassifier):
|
|||
ids_name=self._model_spec.tflite_input_name["ids"],
|
||||
mask_name=self._model_spec.tflite_input_name["mask"],
|
||||
segment_name=self._model_spec.tflite_input_name["segment_ids"])
|
||||
|
||||
def export_model_with_tokenizer(
|
||||
self,
|
||||
model_name: str = "model_with_tokenizer.tflite",
|
||||
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
||||
):
|
||||
"""Converts and saves the model to a TFLite file with the tokenizer.
|
||||
|
||||
Note that unlike the export_model method, this export method will include
|
||||
a FastBertTokenizer in the TFLite graph. The resulting TFLite will not have
|
||||
metadata information to use with MediaPipe Tasks, but can be run directly
|
||||
using TFLite Inference: https://www.tensorflow.org/lite/guide/inference
|
||||
|
||||
For more information on the tokenizer, see:
|
||||
https://www.tensorflow.org/text/api_docs/python/text/FastBertTokenizer
|
||||
|
||||
Args:
|
||||
model_name: File name to save TFLite model with tokenizer. The full export
|
||||
path is {self._hparams.export_dir}/{model_name}.
|
||||
quantization_config: The configuration for model quantization.
|
||||
"""
|
||||
tf.io.gfile.makedirs(self._hparams.export_dir)
|
||||
tflite_file = os.path.join(self._hparams.export_dir, model_name)
|
||||
if (
|
||||
self._hparams.tokenizer
|
||||
!= bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER
|
||||
):
|
||||
print(
|
||||
f"WARNING: This model was trained with {self._hparams.tokenizer} "
|
||||
"tokenizer, but the exported model with tokenizer will have a "
|
||||
f"{bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER} "
|
||||
"tokenizer."
|
||||
)
|
||||
tokenizer = bert_tokenizer.BertFastTokenizer(
|
||||
vocab_file=self._text_preprocessor.get_vocab_file(),
|
||||
do_lower_case=self._model_spec.do_lower_case,
|
||||
seq_len=self._model_options.seq_len,
|
||||
)
|
||||
else:
|
||||
tokenizer = self._text_preprocessor.tokenizer
|
||||
|
||||
model = model_with_tokenizer.ModelWithTokenizer(tokenizer, self._model)
|
||||
model(tf.constant(["Example input data".encode("utf-8")])) # build model
|
||||
saved_model_file = os.path.join(
|
||||
self._hparams.export_dir, "saved_model_with_tokenizer"
|
||||
)
|
||||
model.save(saved_model_file)
|
||||
tflite_model = model_util.convert_to_tflite_from_file(
|
||||
saved_model_file,
|
||||
quantization_config=quantization_config,
|
||||
allow_custom_ops=True,
|
||||
)
|
||||
model_util.save_tflite(tflite_model, tflite_file)
|
||||
|
|
|
@ -149,6 +149,12 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
|||
output_metadata_file, self._BERT_CLASSIFIER_JSON_FILE, shallow=False
|
||||
)
|
||||
)
|
||||
bert_classifier.export_model_with_tokenizer()
|
||||
output_tflite_with_tokenizer_file = os.path.join(
|
||||
options.hparams.export_dir, 'model_with_tokenizer.tflite'
|
||||
)
|
||||
self.assertTrue(os.path.exists(output_tflite_with_tokenizer_file))
|
||||
self.assertGreater(os.path.getsize(output_tflite_with_tokenizer_file), 0)
|
||||
|
||||
def test_label_mismatch(self):
|
||||
options = text_classifier.TextClassifierOptions(
|
||||
|
|
Loading…
Reference in New Issue
Block a user