No public description

PiperOrigin-RevId: 566435327
This commit is contained in:
MediaPipe Team 2023-09-18 15:47:29 -07:00 committed by Copybara-Service
parent 36f78f6e4a
commit 94cda40a83
9 changed files with 322 additions and 44 deletions

View File

@ -30,10 +30,12 @@ py_library(
srcs = ["__init__.py"],
visibility = ["//visibility:public"],
deps = [
":bert_tokenizer",
":dataset",
":hyperparameters",
":model_options",
":model_spec",
":preprocessor",
":text_classifier",
":text_classifier_options",
],
@ -48,7 +50,10 @@ py_library(
py_library(
name = "hyperparameters",
srcs = ["hyperparameters.py"],
deps = ["//mediapipe/model_maker/python/core:hyperparameters"],
deps = [
":bert_tokenizer",
"//mediapipe/model_maker/python/core:hyperparameters",
],
)
py_library(
@ -88,10 +93,26 @@ py_test(
deps = [":dataset"],
)
py_library(
name = "bert_tokenizer",
srcs = ["bert_tokenizer.py"],
)
py_test(
name = "bert_tokenizer_test",
srcs = ["bert_tokenizer_test.py"],
tags = ["requires-net:external"],
deps = [
":bert_tokenizer",
":model_spec",
],
)
py_library(
name = "preprocessor",
srcs = ["preprocessor.py"],
deps = [
":bert_tokenizer",
":dataset",
"//mediapipe/model_maker/python/core/data:cache_files",
],
@ -102,6 +123,7 @@ py_test(
srcs = ["preprocessor_test.py"],
tags = ["requires-net:external"],
deps = [
":bert_tokenizer",
":dataset",
":model_spec",
":preprocessor",

View File

@ -13,10 +13,12 @@
# limitations under the License.
"""MediaPipe Public Python API for Text Classifier."""
from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
from mediapipe.model_maker.python.text.text_classifier import dataset
from mediapipe.model_maker.python.text.text_classifier import hyperparameters
from mediapipe.model_maker.python.text.text_classifier import model_options
from mediapipe.model_maker.python.text.text_classifier import model_spec
from mediapipe.model_maker.python.text.text_classifier import preprocessor
from mediapipe.model_maker.python.text.text_classifier import text_classifier
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
@ -33,12 +35,14 @@ Dataset = dataset.Dataset
SupportedModels = model_spec.SupportedModels
TextClassifier = text_classifier.TextClassifier
TextClassifierOptions = text_classifier_options.TextClassifierOptions
SupportedBertTokenizers = bert_tokenizer.SupportedBertTokenizers
# Remove duplicated and non-public API
del bert_tokenizer
del hyperparameters
del dataset
del model_options
del model_spec
del preprocessor # pylint: disable=undefined-variable
del preprocessor
del text_classifier
del text_classifier_options

View File

@ -0,0 +1,118 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text classifier BERT tokenizer library."""
import abc
import enum
from typing import Mapping, Sequence
import tensorflow as tf
import tensorflow_text as tf_text
from official.nlp.tools import tokenization
@enum.unique
class SupportedBertTokenizers(enum.Enum):
"""Supported preprocessors."""
FULL_TOKENIZER = "fulltokenizer"
FAST_BERT_TOKENIZER = "fastberttokenizer"
class BertTokenizer(abc.ABC):
"""Abstract BertTokenizer class."""
name: str
@abc.abstractmethod
def __init__(self, vocab_file: str, do_lower_case: bool, seq_len: int):
pass
@abc.abstractmethod
def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
pass
class BertFullTokenizer(BertTokenizer):
"""Tokenizer using the FullTokenizer from tensorflow_models."""
name = "fulltokenizer"
def __init__(self, vocab_file: str, do_lower_case: bool, seq_len: int):
self._tokenizer = tokenization.FullTokenizer(
vocab_file=vocab_file, do_lower_case=do_lower_case
)
self._seq_len = seq_len
def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
tokens = self._tokenizer.tokenize(input_tensor.numpy()[0].decode("utf-8"))
tokens = tokens[0 : (self._seq_len - 2)] # account for [CLS] and [SEP]
tokens.insert(0, "[CLS]")
tokens.append("[SEP]")
input_ids = self._tokenizer.convert_tokens_to_ids(tokens)
input_mask = [1] * len(input_ids)
while len(input_ids) < self._seq_len:
input_ids.append(0)
input_mask.append(0)
segment_ids = [0] * self._seq_len
return {
"input_word_ids": input_ids,
"input_type_ids": segment_ids,
"input_mask": input_mask,
}
class BertFastTokenizer(BertTokenizer):
"""Tokenizer using the FastBertTokenizer from tensorflow_text.
For more information, see:
https://www.tensorflow.org/text/api_docs/python/text/FastBertTokenizer
"""
name = "fastberttokenizer"
def __init__(self, vocab_file: str, do_lower_case: bool, seq_len: int):
with tf.io.gfile.GFile(vocab_file, "r") as f:
vocab = f.read().splitlines()
self._tokenizer = tf_text.FastBertTokenizer(
vocab=vocab,
token_out_type=tf.int32,
support_detokenization=False,
lower_case_nfd_strip_accents=do_lower_case,
)
self._seq_len = seq_len
self._cls_id = vocab.index("[CLS]")
self._sep_id = vocab.index("[SEP]")
self._pad_id = vocab.index("[PAD]")
def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
input_ids = self._tokenizer.tokenize(input_tensor).flat_values
input_ids = input_ids[: (self._seq_len - 2)]
input_ids = tf.concat(
[
tf.constant([self._cls_id]),
input_ids,
tf.constant([self._sep_id]),
tf.fill((self._seq_len,), self._pad_id),
],
axis=0,
)
input_ids = input_ids[: self._seq_len]
input_type_ids = tf.zeros(self._seq_len, dtype=tf.int32)
input_mask = tf.cast(input_ids != self._pad_id, dtype=tf.int32)
return {
"input_word_ids": input_ids.numpy().tolist(),
"input_type_ids": input_type_ids.numpy().tolist(),
"input_mask": input_mask.numpy().tolist(),
}

View File

@ -0,0 +1,91 @@
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
from unittest import mock as unittest_mock
from absl.testing import parameterized
import tensorflow as tf
import tensorflow_hub
from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
from mediapipe.model_maker.python.text.text_classifier import model_spec
class BertTokenizerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
# Mock tempfile.gettempdir() to be unique for each test to avoid race
# condition when downloading model since these tests may run in parallel.
mock_gettempdir = unittest_mock.patch.object(
tempfile,
'gettempdir',
return_value=self.create_tempdir(),
autospec=True,
)
self.mock_gettempdir = mock_gettempdir.start()
self.addCleanup(mock_gettempdir.stop)
ms = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
self._vocab_file = os.path.join(
tensorflow_hub.resolve(ms.get_path()), 'assets', 'vocab.txt'
)
@parameterized.named_parameters(
dict(
testcase_name='fulltokenizer',
tokenizer_class=bert_tokenizer.BertFullTokenizer,
),
dict(
testcase_name='fasttokenizer',
tokenizer_class=bert_tokenizer.BertFastTokenizer,
),
)
def test_bert_full_tokenizer(self, tokenizer_class):
tokenizer = tokenizer_class(self._vocab_file, True, 16)
text_input = tf.constant(['this is an éxamplé input ¿foo'.encode('utf-8')])
result = tokenizer.process(text_input)
self.assertAllEqual(
result['input_word_ids'],
[
101,
2023,
2003,
2019,
2742,
7953,
1094,
29379,
102,
0,
0,
0,
0,
0,
0,
0,
],
)
self.assertAllEqual(
result['input_mask'], [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
)
self.assertAllEqual(
result['input_type_ids'],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
)
if __name__ == '__main__':
tf.test.main()

View File

@ -18,6 +18,7 @@ import enum
from typing import Sequence, Union
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
@dataclasses.dataclass
@ -53,6 +54,8 @@ class BertHParams(hp.BaseHParams):
on recall. Only supported for binary classification.
gamma: Gamma parameter for focal loss. To use cross entropy loss, set this
value to 0. Defaults to 2.0.
tokenizer: Tokenizer to use for preprocessing. Must be one of the enum
options of SupportedBertTokenizers. Defaults to FULL_TOKENIZER.
"""
learning_rate: float = 3e-5
@ -68,5 +71,9 @@ class BertHParams(hp.BaseHParams):
gamma: float = 2.0
tokenizer: bert_tokenizer.SupportedBertTokenizers = (
bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER
)
HParams = Union[BertHParams, AverageWordEmbeddingHParams]

View File

@ -24,8 +24,8 @@ import tensorflow as tf
import tensorflow_hub
from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
from official.nlp.tools import tokenization
def _validate_text_and_label(text: tf.Tensor, label: tf.Tensor) -> None:
@ -68,9 +68,9 @@ def _decode_record(
example[name] = tf.cast(example[name], tf.int32)
bert_features = {
"input_word_ids": example["input_ids"],
"input_word_ids": example["input_word_ids"],
"input_mask": example["input_mask"],
"input_type_ids": example["segment_ids"]
"input_type_ids": example["input_type_ids"],
}
return bert_features, example["label_ids"]
@ -224,10 +224,17 @@ class BertClassifierPreprocessor:
tokenizer: BERT tokenizer.
model_name: Name of the model provided by the model_spec. Used to associate
cached files with specific Bert model vocab.
preprocessor: Which preprocessor to use. Must be one of the enum values of
SupportedBertPreprocessors.
"""
def __init__(
self, seq_len: int, do_lower_case: bool, uri: str, model_name: str
self,
seq_len: int,
do_lower_case: bool,
uri: str,
model_name: str,
tokenizer: bert_tokenizer.SupportedBertTokenizers,
):
self._seq_len = seq_len
# Vocab filepath is tied to the BERT module's URI.
@ -235,17 +242,27 @@ class BertClassifierPreprocessor:
tensorflow_hub.resolve(uri), "assets", "vocab.txt"
)
self._do_lower_case = do_lower_case
self._tokenizer = tokenization.FullTokenizer(
self._vocab_file, self._do_lower_case
self._tokenizer: bert_tokenizer.BertTokenizer = None
if tokenizer == bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER:
self._tokenizer = bert_tokenizer.BertFullTokenizer(
self._vocab_file, self._do_lower_case, self._seq_len
)
elif (
tokenizer == bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER
):
self._tokenizer = bert_tokenizer.BertFastTokenizer(
self._vocab_file, self._do_lower_case, self._seq_len
)
else:
raise ValueError(f"Unsupported tokenizer: {tokenizer}")
self._model_name = model_name
def _get_name_to_features(self):
"""Gets the dictionary mapping record keys to feature types."""
return {
"input_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
"input_word_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
"input_mask": tf.io.FixedLenFeature([self._seq_len], tf.int64),
"segment_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
"input_type_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
"label_ids": tf.io.FixedLenFeature([], tf.int64),
}
@ -269,6 +286,7 @@ class BertClassifierPreprocessor:
2. model_name
3. seq_len
4. do_lower_case
5. tokenizer name
Args:
ds_cache_files: TFRecordCacheFiles from the original raw dataset object
@ -282,6 +300,7 @@ class BertClassifierPreprocessor:
hasher.update(self._model_name.encode("utf-8"))
hasher.update(str(self._seq_len).encode("utf-8"))
hasher.update(str(self._do_lower_case).encode("utf-8"))
hasher.update(self._tokenizer.name.encode("utf-8"))
cache_prefix_filename = hasher.hexdigest()
return cache_files_lib.TFRecordCacheFiles(
cache_prefix_filename,
@ -289,23 +308,6 @@ class BertClassifierPreprocessor:
ds_cache_files.num_shards,
)
def _process_bert_features(self, text: str) -> Mapping[str, Sequence[int]]:
tokens = self._tokenizer.tokenize(text)
tokens = tokens[0 : (self._seq_len - 2)] # account for [CLS] and [SEP]
tokens.insert(0, "[CLS]")
tokens.append("[SEP]")
input_ids = self._tokenizer.convert_tokens_to_ids(tokens)
input_mask = [1] * len(input_ids)
while len(input_ids) < self._seq_len:
input_ids.append(0)
input_mask.append(0)
segment_ids = [0] * self._seq_len
return {
"input_ids": input_ids,
"input_mask": input_mask,
"segment_ids": segment_ids,
}
def preprocess(
self, dataset: text_classifier_ds.Dataset
) -> text_classifier_ds.Dataset:
@ -326,18 +328,20 @@ class BertClassifierPreprocessor:
size = 0
for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
_validate_text_and_label(text, label)
feature = self._process_bert_features(text.numpy()[0].decode("utf-8"))
feature = self._tokenizer.process(text)
def create_int_feature(values):
f = tf.train.Feature(
int64_list=tf.train.Int64List(value=list(values))
)
f = tf.train.Feature(int64_list=tf.train.Int64List(value=values))
return f
features = collections.OrderedDict()
features["input_ids"] = create_int_feature(feature["input_ids"])
features["input_word_ids"] = create_int_feature(
feature["input_word_ids"]
)
features["input_mask"] = create_int_feature(feature["input_mask"])
features["segment_ids"] = create_int_feature(feature["segment_ids"])
features["label_ids"] = create_int_feature([label.numpy()[0]])
features["input_type_ids"] = create_int_feature(
feature["input_type_ids"]
)
features["label_ids"] = create_int_feature(label.numpy().tolist())
tf_example = tf.train.Example(
features=tf.train.Features(feature=features)
)

View File

@ -18,18 +18,20 @@ import os
import tempfile
from unittest import mock as unittest_mock
from absl.testing import parameterized
import mock
import numpy as np
import numpy.testing as npt
import tensorflow as tf
from mediapipe.model_maker.python.core.data import cache_files
from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
from mediapipe.model_maker.python.text.text_classifier import model_spec
from mediapipe.model_maker.python.text.text_classifier import preprocessor
class PreprocessorTest(tf.test.TestCase):
class PreprocessorTest(tf.test.TestCase, parameterized.TestCase):
CSV_PARAMS_ = text_classifier_ds.CSVParameters(
text_column='text', label_column='label')
@ -83,7 +85,19 @@ class PreprocessorTest(tf.test.TestCase):
npt.assert_array_equal(
np.stack(features_list), np.array([[1, 3, 3, 3, 3], [1, 5, 6, 0, 0]]))
def test_bert_preprocessor(self):
@parameterized.named_parameters(
dict(
testcase_name='fulltokenizer',
tokenizer=bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER,
),
dict(
testcase_name='fastberttokenizer',
tokenizer=bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER,
),
)
def test_bert_preprocessor(
self, tokenizer: bert_tokenizer.SupportedBertTokenizers
):
csv_file = self._get_csv_file()
dataset = text_classifier_ds.Dataset.from_csv(
filename=csv_file, csv_params=self.CSV_PARAMS_)
@ -93,6 +107,7 @@ class PreprocessorTest(tf.test.TestCase):
do_lower_case=bert_spec.do_lower_case,
uri=bert_spec.get_path(),
model_name=bert_spec.name,
tokenizer=tokenizer,
)
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
labels = []
@ -122,11 +137,13 @@ class PreprocessorTest(tf.test.TestCase):
cache_dir=self.get_temp_dir(),
)
bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
tokenizer = bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=5,
do_lower_case=bert_spec.do_lower_case,
uri=bert_spec.get_path(),
model_name=bert_spec.name,
tokenizer=tokenizer,
)
ds_cache_files = dataset.tfrecord_cache_files
preprocessed_cache_files = bert_preprocessor._get_tfrecord_cache_files(
@ -149,12 +166,13 @@ class PreprocessorTest(tf.test.TestCase):
f' {preprocessed_cache_files.cache_prefix}\n',
)
def _get_new_prefix(self, cf, bert_spec, seq_len, do_lower_case):
def _get_new_prefix(self, cf, bert_spec, seq_len, do_lower_case, tokenizer):
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=seq_len,
do_lower_case=do_lower_case,
uri=bert_spec.get_path(),
model_name=bert_spec.name,
tokenizer=tokenizer,
)
new_cf = bert_preprocessor._get_tfrecord_cache_files(cf)
return new_cf.cache_prefix_filename
@ -167,19 +185,31 @@ class PreprocessorTest(tf.test.TestCase):
cache_dir=self.get_temp_dir(),
num_shards=1,
)
tokenizer = bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER
mobilebert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, True))
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 10, True))
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, False))
all_cf_prefixes.add(
self._get_new_prefix(cf, mobilebert_spec, 5, True, tokenizer)
)
all_cf_prefixes.add(
self._get_new_prefix(cf, mobilebert_spec, 10, True, tokenizer)
)
all_cf_prefixes.add(
self._get_new_prefix(cf, mobilebert_spec, 5, False, tokenizer)
)
new_cf = cache_files.TFRecordCacheFiles(
cache_prefix_filename='new_cache_prefix',
cache_dir=self.get_temp_dir(),
num_shards=1,
)
all_cf_prefixes.add(self._get_new_prefix(new_cf, mobilebert_spec, 5, True))
all_cf_prefixes.add(
self._get_new_prefix(new_cf, mobilebert_spec, 5, True, tokenizer)
)
new_tokenizer = bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER
all_cf_prefixes.add(
self._get_new_prefix(cf, mobilebert_spec, 5, True, new_tokenizer)
)
# Each item of all_cf_prefixes should be unique.
self.assertLen(all_cf_prefixes, 4)
self.assertLen(all_cf_prefixes, 5)
if __name__ == '__main__':

View File

@ -437,6 +437,7 @@ class _BertClassifier(TextClassifier):
do_lower_case=self._model_spec.do_lower_case,
uri=self._model_spec.get_path(),
model_name=self._model_spec.name,
tokenizer=self._hparams.tokenizer,
)
return (
self._text_preprocessor.preprocess(train_data),

View File

@ -6,4 +6,5 @@ tensorflow>=2.10
tensorflow-addons
tensorflow-datasets
tensorflow-hub
tensorflow-text
tf-models-official>=2.13.1