Add export_model_with_tokenizer to Text Classifier API.
PiperOrigin-RevId: 567744604
This commit is contained in:
parent
9d85141227
commit
573fdad173
|
@ -93,6 +93,23 @@ py_test(
|
||||||
deps = [":dataset"],
|
deps = [":dataset"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "model_with_tokenizer",
|
||||||
|
srcs = ["model_with_tokenizer.py"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "model_with_tokenizer_test",
|
||||||
|
srcs = ["model_with_tokenizer_test.py"],
|
||||||
|
tags = ["requires-net:external"],
|
||||||
|
deps = [
|
||||||
|
":bert_tokenizer",
|
||||||
|
":model_spec",
|
||||||
|
":model_with_tokenizer",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:hub_loader",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "bert_tokenizer",
|
name = "bert_tokenizer",
|
||||||
srcs = ["bert_tokenizer.py"],
|
srcs = ["bert_tokenizer.py"],
|
||||||
|
@ -145,10 +162,12 @@ py_library(
|
||||||
name = "text_classifier",
|
name = "text_classifier",
|
||||||
srcs = ["text_classifier.py"],
|
srcs = ["text_classifier.py"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":bert_tokenizer",
|
||||||
":dataset",
|
":dataset",
|
||||||
":hyperparameters",
|
":hyperparameters",
|
||||||
":model_options",
|
":model_options",
|
||||||
":model_spec",
|
":model_spec",
|
||||||
|
":model_with_tokenizer",
|
||||||
":preprocessor",
|
":preprocessor",
|
||||||
":text_classifier_options",
|
":text_classifier_options",
|
||||||
"//mediapipe/model_maker/python/core/data:dataset",
|
"//mediapipe/model_maker/python/core/data:dataset",
|
||||||
|
@ -165,7 +184,7 @@ py_library(
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "text_classifier_test",
|
name = "text_classifier_test",
|
||||||
size = "large",
|
size = "enormous",
|
||||||
srcs = ["text_classifier_test.py"],
|
srcs = ["text_classifier_test.py"],
|
||||||
data = [
|
data = [
|
||||||
"//mediapipe/model_maker/python/text/text_classifier/testdata",
|
"//mediapipe/model_maker/python/text/text_classifier/testdata",
|
||||||
|
|
|
@ -56,6 +56,15 @@ class BertFullTokenizer(BertTokenizer):
|
||||||
self._seq_len = seq_len
|
self._seq_len = seq_len
|
||||||
|
|
||||||
def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
|
def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
|
||||||
|
"""Processes one input_tensor example.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_tensor: A tensor with shape (1, None) of a utf-8 encoded string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of lists all with shape (1, self._seq_len) containing the
|
||||||
|
keys "input_word_ids", "input_type_ids", and "input_mask".
|
||||||
|
"""
|
||||||
tokens = self._tokenizer.tokenize(input_tensor.numpy()[0].decode("utf-8"))
|
tokens = self._tokenizer.tokenize(input_tensor.numpy()[0].decode("utf-8"))
|
||||||
tokens = tokens[0 : (self._seq_len - 2)] # account for [CLS] and [SEP]
|
tokens = tokens[0 : (self._seq_len - 2)] # account for [CLS] and [SEP]
|
||||||
tokens.insert(0, "[CLS]")
|
tokens.insert(0, "[CLS]")
|
||||||
|
@ -96,7 +105,18 @@ class BertFastTokenizer(BertTokenizer):
|
||||||
self._sep_id = vocab.index("[SEP]")
|
self._sep_id = vocab.index("[SEP]")
|
||||||
self._pad_id = vocab.index("[PAD]")
|
self._pad_id = vocab.index("[PAD]")
|
||||||
|
|
||||||
def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
|
def process_fn(self, input_tensor: tf.Tensor) -> Mapping[str, tf.Tensor]:
|
||||||
|
"""Tensor implementation of the process function.
|
||||||
|
|
||||||
|
This implementation can be used within a model graph directly since it
|
||||||
|
takes in tensors and outputs tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_tensor: Input string tensor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of tf.Tensors.
|
||||||
|
"""
|
||||||
input_ids = self._tokenizer.tokenize(input_tensor).flat_values
|
input_ids = self._tokenizer.tokenize(input_tensor).flat_values
|
||||||
input_ids = input_ids[: (self._seq_len - 2)]
|
input_ids = input_ids[: (self._seq_len - 2)]
|
||||||
input_ids = tf.concat(
|
input_ids = tf.concat(
|
||||||
|
@ -112,7 +132,20 @@ class BertFastTokenizer(BertTokenizer):
|
||||||
input_type_ids = tf.zeros(self._seq_len, dtype=tf.int32)
|
input_type_ids = tf.zeros(self._seq_len, dtype=tf.int32)
|
||||||
input_mask = tf.cast(input_ids != self._pad_id, dtype=tf.int32)
|
input_mask = tf.cast(input_ids != self._pad_id, dtype=tf.int32)
|
||||||
return {
|
return {
|
||||||
"input_word_ids": input_ids.numpy().tolist(),
|
"input_word_ids": input_ids,
|
||||||
"input_type_ids": input_type_ids.numpy().tolist(),
|
"input_type_ids": input_type_ids,
|
||||||
"input_mask": input_mask.numpy().tolist(),
|
"input_mask": input_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
|
||||||
|
"""Processes one input_tensor example.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_tensor: A tensor with shape (1, None) of a utf-8 encoded string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of lists all with shape (1, self._seq_len) containing the
|
||||||
|
keys "input_word_ids", "input_type_ids", and "input_mask".
|
||||||
|
"""
|
||||||
|
result = self.process_fn(input_tensor)
|
||||||
|
return {k: v.numpy().tolist() for k, v in result.items()}
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Text classifier export module library."""
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
class ModelWithTokenizer(tf.keras.Model):
|
||||||
|
"""A model with the tokenizer included in graph for exporting to TFLite."""
|
||||||
|
|
||||||
|
def __init__(self, tokenizer, model):
|
||||||
|
super().__init__()
|
||||||
|
self._tokenizer = tokenizer
|
||||||
|
self._model = model
|
||||||
|
|
||||||
|
@tf.function(
|
||||||
|
input_signature=[
|
||||||
|
tf.TensorSpec(shape=[None], dtype=tf.string, name="input")
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def call(self, input_tensor):
|
||||||
|
x = self._tokenizer.process_fn(input_tensor)
|
||||||
|
x = {k: tf.expand_dims(v, axis=0) for k, v in x.items()}
|
||||||
|
x = self._model(x)
|
||||||
|
return x
|
|
@ -0,0 +1,105 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from unittest import mock as unittest_mock
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow_hub
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.utils import hub_loader
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_with_tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class BertTokenizerTest(tf.test.TestCase):
|
||||||
|
_SEQ_LEN = 128
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||||
|
# condition when downloading model since these tests may run in parallel.
|
||||||
|
mock_gettempdir = unittest_mock.patch.object(
|
||||||
|
tempfile,
|
||||||
|
"gettempdir",
|
||||||
|
return_value=self.create_tempdir(),
|
||||||
|
autospec=True,
|
||||||
|
)
|
||||||
|
self.mock_gettempdir = mock_gettempdir.start()
|
||||||
|
self.addCleanup(mock_gettempdir.stop)
|
||||||
|
self._ms = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||||
|
self._tokenizer = self._create_tokenizer()
|
||||||
|
self._model = self._create_model()
|
||||||
|
|
||||||
|
def _create_tokenizer(self):
|
||||||
|
vocab_file = os.path.join(
|
||||||
|
tensorflow_hub.resolve(self._ms.get_path()), "assets", "vocab.txt"
|
||||||
|
)
|
||||||
|
return bert_tokenizer.BertFastTokenizer(vocab_file, True, self._SEQ_LEN)
|
||||||
|
|
||||||
|
def _create_model(self):
|
||||||
|
encoder_inputs = dict(
|
||||||
|
input_word_ids=tf.keras.layers.Input(
|
||||||
|
shape=(self._SEQ_LEN,),
|
||||||
|
dtype=tf.int32,
|
||||||
|
name="input_word_ids",
|
||||||
|
),
|
||||||
|
input_mask=tf.keras.layers.Input(
|
||||||
|
shape=(self._SEQ_LEN,),
|
||||||
|
dtype=tf.int32,
|
||||||
|
name="input_mask",
|
||||||
|
),
|
||||||
|
input_type_ids=tf.keras.layers.Input(
|
||||||
|
shape=(self._SEQ_LEN,),
|
||||||
|
dtype=tf.int32,
|
||||||
|
name="input_type_ids",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
renamed_inputs = dict(
|
||||||
|
input_ids=encoder_inputs["input_word_ids"],
|
||||||
|
input_mask=encoder_inputs["input_mask"],
|
||||||
|
segment_ids=encoder_inputs["input_type_ids"],
|
||||||
|
)
|
||||||
|
encoder = hub_loader.HubKerasLayerV1V2(
|
||||||
|
self._ms.get_path(),
|
||||||
|
signature="tokens",
|
||||||
|
output_key="pooled_output",
|
||||||
|
trainable=True,
|
||||||
|
)
|
||||||
|
pooled_output = encoder(renamed_inputs)
|
||||||
|
|
||||||
|
output = tf.keras.layers.Dropout(rate=0.1)(pooled_output)
|
||||||
|
initializer = tf.keras.initializers.TruncatedNormal(stddev=0.02)
|
||||||
|
output = tf.keras.layers.Dense(
|
||||||
|
2,
|
||||||
|
kernel_initializer=initializer,
|
||||||
|
name="output",
|
||||||
|
activation="softmax",
|
||||||
|
dtype=tf.float32,
|
||||||
|
)(output)
|
||||||
|
return tf.keras.Model(inputs=encoder_inputs, outputs=output)
|
||||||
|
|
||||||
|
def test_model_with_tokenizer(self):
|
||||||
|
model = model_with_tokenizer.ModelWithTokenizer(
|
||||||
|
self._tokenizer, self._model
|
||||||
|
)
|
||||||
|
output = model(tf.constant(["Example input".encode("utf-8")]))
|
||||||
|
self.assertAllEqual(output.shape, (1, 2))
|
||||||
|
self.assertEqual(tf.reduce_sum(output), 1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.test.main()
|
|
@ -368,6 +368,10 @@ class BertClassifierPreprocessor:
|
||||||
tfrecord_cache_files=tfrecord_cache_files,
|
tfrecord_cache_files=tfrecord_cache_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tokenizer(self) -> bert_tokenizer.BertTokenizer:
|
||||||
|
return self._tokenizer
|
||||||
|
|
||||||
|
|
||||||
TextClassifierPreprocessor = Union[
|
TextClassifierPreprocessor = Union[
|
||||||
BertClassifierPreprocessor, AverageWordEmbeddingClassifierPreprocessor
|
BertClassifierPreprocessor, AverageWordEmbeddingClassifierPreprocessor
|
||||||
|
|
|
@ -29,10 +29,12 @@ from mediapipe.model_maker.python.core.utils import loss_functions
|
||||||
from mediapipe.model_maker.python.core.utils import metrics
|
from mediapipe.model_maker.python.core.utils import metrics
|
||||||
from mediapipe.model_maker.python.core.utils import model_util
|
from mediapipe.model_maker.python.core.utils import model_util
|
||||||
from mediapipe.model_maker.python.core.utils import quantization
|
from mediapipe.model_maker.python.core.utils import quantization
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
|
||||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
|
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
|
||||||
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
|
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_with_tokenizer
|
||||||
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
||||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
||||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||||
|
@ -620,3 +622,56 @@ class _BertClassifier(TextClassifier):
|
||||||
ids_name=self._model_spec.tflite_input_name["ids"],
|
ids_name=self._model_spec.tflite_input_name["ids"],
|
||||||
mask_name=self._model_spec.tflite_input_name["mask"],
|
mask_name=self._model_spec.tflite_input_name["mask"],
|
||||||
segment_name=self._model_spec.tflite_input_name["segment_ids"])
|
segment_name=self._model_spec.tflite_input_name["segment_ids"])
|
||||||
|
|
||||||
|
def export_model_with_tokenizer(
|
||||||
|
self,
|
||||||
|
model_name: str = "model_with_tokenizer.tflite",
|
||||||
|
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
"""Converts and saves the model to a TFLite file with the tokenizer.
|
||||||
|
|
||||||
|
Note that unlike the export_model method, this export method will include
|
||||||
|
a FastBertTokenizer in the TFLite graph. The resulting TFLite will not have
|
||||||
|
metadata information to use with MediaPipe Tasks, but can be run directly
|
||||||
|
using TFLite Inference: https://www.tensorflow.org/lite/guide/inference
|
||||||
|
|
||||||
|
For more information on the tokenizer, see:
|
||||||
|
https://www.tensorflow.org/text/api_docs/python/text/FastBertTokenizer
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: File name to save TFLite model with tokenizer. The full export
|
||||||
|
path is {self._hparams.export_dir}/{model_name}.
|
||||||
|
quantization_config: The configuration for model quantization.
|
||||||
|
"""
|
||||||
|
tf.io.gfile.makedirs(self._hparams.export_dir)
|
||||||
|
tflite_file = os.path.join(self._hparams.export_dir, model_name)
|
||||||
|
if (
|
||||||
|
self._hparams.tokenizer
|
||||||
|
!= bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER
|
||||||
|
):
|
||||||
|
print(
|
||||||
|
f"WARNING: This model was trained with {self._hparams.tokenizer} "
|
||||||
|
"tokenizer, but the exported model with tokenizer will have a "
|
||||||
|
f"{bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER} "
|
||||||
|
"tokenizer."
|
||||||
|
)
|
||||||
|
tokenizer = bert_tokenizer.BertFastTokenizer(
|
||||||
|
vocab_file=self._text_preprocessor.get_vocab_file(),
|
||||||
|
do_lower_case=self._model_spec.do_lower_case,
|
||||||
|
seq_len=self._model_options.seq_len,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tokenizer = self._text_preprocessor.tokenizer
|
||||||
|
|
||||||
|
model = model_with_tokenizer.ModelWithTokenizer(tokenizer, self._model)
|
||||||
|
model(tf.constant(["Example input data".encode("utf-8")])) # build model
|
||||||
|
saved_model_file = os.path.join(
|
||||||
|
self._hparams.export_dir, "saved_model_with_tokenizer"
|
||||||
|
)
|
||||||
|
model.save(saved_model_file)
|
||||||
|
tflite_model = model_util.convert_to_tflite_from_file(
|
||||||
|
saved_model_file,
|
||||||
|
quantization_config=quantization_config,
|
||||||
|
allow_custom_ops=True,
|
||||||
|
)
|
||||||
|
model_util.save_tflite(tflite_model, tflite_file)
|
||||||
|
|
|
@ -149,6 +149,12 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
output_metadata_file, self._BERT_CLASSIFIER_JSON_FILE, shallow=False
|
output_metadata_file, self._BERT_CLASSIFIER_JSON_FILE, shallow=False
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
bert_classifier.export_model_with_tokenizer()
|
||||||
|
output_tflite_with_tokenizer_file = os.path.join(
|
||||||
|
options.hparams.export_dir, 'model_with_tokenizer.tflite'
|
||||||
|
)
|
||||||
|
self.assertTrue(os.path.exists(output_tflite_with_tokenizer_file))
|
||||||
|
self.assertGreater(os.path.getsize(output_tflite_with_tokenizer_file), 0)
|
||||||
|
|
||||||
def test_label_mismatch(self):
|
def test_label_mismatch(self):
|
||||||
options = text_classifier.TextClassifierOptions(
|
options = text_classifier.TextClassifierOptions(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user