diff --git a/mediapipe/model_maker/python/text/text_classifier/BUILD b/mediapipe/model_maker/python/text/text_classifier/BUILD
index e32733e31..2c239e4b0 100644
--- a/mediapipe/model_maker/python/text/text_classifier/BUILD
+++ b/mediapipe/model_maker/python/text/text_classifier/BUILD
@@ -30,10 +30,12 @@ py_library(
     srcs = ["__init__.py"],
     visibility = ["//visibility:public"],
     deps = [
+        ":bert_tokenizer",
         ":dataset",
         ":hyperparameters",
         ":model_options",
         ":model_spec",
+        ":preprocessor",
         ":text_classifier",
         ":text_classifier_options",
     ],
@@ -48,7 +50,10 @@ py_library(
     name = "hyperparameters",
     srcs = ["hyperparameters.py"],
-    deps = ["//mediapipe/model_maker/python/core:hyperparameters"],
+    deps = [
+        ":bert_tokenizer",
+        "//mediapipe/model_maker/python/core:hyperparameters",
+    ],
 )
 
 py_library(
@@ -88,10 +93,26 @@ py_test(
     deps = [":dataset"],
 )
 
+py_library(
+    name = "bert_tokenizer",
+    srcs = ["bert_tokenizer.py"],
+)
+
+py_test(
+    name = "bert_tokenizer_test",
+    srcs = ["bert_tokenizer_test.py"],
+    tags = ["requires-net:external"],
+    deps = [
+        ":bert_tokenizer",
+        ":model_spec",
+    ],
+)
+
 py_library(
     name = "preprocessor",
     srcs = ["preprocessor.py"],
     deps = [
+        ":bert_tokenizer",
         ":dataset",
         "//mediapipe/model_maker/python/core/data:cache_files",
     ],
@@ -102,6 +123,7 @@ py_test(
     srcs = ["preprocessor_test.py"],
     tags = ["requires-net:external"],
     deps = [
+        ":bert_tokenizer",
         ":dataset",
         ":model_spec",
         ":preprocessor",
diff --git a/mediapipe/model_maker/python/text/text_classifier/__init__.py b/mediapipe/model_maker/python/text/text_classifier/__init__.py
index 7eb0f9259..4096bf734 100644
--- a/mediapipe/model_maker/python/text/text_classifier/__init__.py
+++ b/mediapipe/model_maker/python/text/text_classifier/__init__.py
@@ -13,10 +13,12 @@
 # limitations under the License.
 """MediaPipe Public Python API for Text Classifier."""
 
+from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
 from mediapipe.model_maker.python.text.text_classifier import dataset
 from mediapipe.model_maker.python.text.text_classifier import hyperparameters
 from mediapipe.model_maker.python.text.text_classifier import model_options
 from mediapipe.model_maker.python.text.text_classifier import model_spec
+from mediapipe.model_maker.python.text.text_classifier import preprocessor
 from mediapipe.model_maker.python.text.text_classifier import text_classifier
 from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
 
@@ -33,12 +35,14 @@ Dataset = dataset.Dataset
 SupportedModels = model_spec.SupportedModels
 TextClassifier = text_classifier.TextClassifier
 TextClassifierOptions = text_classifier_options.TextClassifierOptions
+SupportedBertTokenizers = bert_tokenizer.SupportedBertTokenizers
 
 # Remove duplicated and non-public API
+del bert_tokenizer
 del hyperparameters
 del dataset
 del model_options
 del model_spec
-del preprocessor  # pylint: disable=undefined-variable
+del preprocessor
 del text_classifier
 del text_classifier_options
diff --git a/mediapipe/model_maker/python/text/text_classifier/bert_tokenizer.py b/mediapipe/model_maker/python/text/text_classifier/bert_tokenizer.py
new file mode 100644
index 000000000..8e92bc29c
--- /dev/null
+++ b/mediapipe/model_maker/python/text/text_classifier/bert_tokenizer.py
@@ -0,0 +1,118 @@
+# Copyright 2023 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Text classifier BERT tokenizer library."""
+import abc
+import enum
+from typing import Mapping, Sequence
+
+import tensorflow as tf
+import tensorflow_text as tf_text
+
+from official.nlp.tools import tokenization
+
+
+@enum.unique
+class SupportedBertTokenizers(enum.Enum):
+  """Supported BERT tokenizers."""
+
+  FULL_TOKENIZER = "fulltokenizer"
+  FAST_BERT_TOKENIZER = "fastberttokenizer"
+
+
+class BertTokenizer(abc.ABC):
+  """Abstract BertTokenizer class."""
+
+  name: str
+
+  @abc.abstractmethod
+  def __init__(self, vocab_file: str, do_lower_case: bool, seq_len: int):
+    pass
+
+  @abc.abstractmethod
+  def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
+    pass
+
+
+class BertFullTokenizer(BertTokenizer):
+  """Tokenizer using the FullTokenizer from tensorflow_models."""
+
+  name = "fulltokenizer"
+
+  def __init__(self, vocab_file: str, do_lower_case: bool, seq_len: int):
+    self._tokenizer = tokenization.FullTokenizer(
+        vocab_file=vocab_file, do_lower_case=do_lower_case
+    )
+    self._seq_len = seq_len
+
+  def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
+    tokens = self._tokenizer.tokenize(input_tensor.numpy()[0].decode("utf-8"))
+    tokens = tokens[0 : (self._seq_len - 2)]  # account for [CLS] and [SEP]
+    tokens.insert(0, "[CLS]")
+    tokens.append("[SEP]")
+    input_ids = self._tokenizer.convert_tokens_to_ids(tokens)
+    input_mask = [1] * len(input_ids)
+    while len(input_ids) < self._seq_len:
+      input_ids.append(0)
+      input_mask.append(0)
+    segment_ids = [0] * self._seq_len
+    return {
+        "input_word_ids": input_ids,
+        "input_type_ids": segment_ids,
+        "input_mask": input_mask,
+    }
+
+
+class BertFastTokenizer(BertTokenizer):
+  """Tokenizer using the FastBertTokenizer from tensorflow_text.
+
+  For more information, see:
+  https://www.tensorflow.org/text/api_docs/python/text/FastBertTokenizer
+  """
+
+  name = "fastberttokenizer"
+
+  def __init__(self, vocab_file: str, do_lower_case: bool, seq_len: int):
+    with tf.io.gfile.GFile(vocab_file, "r") as f:
+      vocab = f.read().splitlines()
+    self._tokenizer = tf_text.FastBertTokenizer(
+        vocab=vocab,
+        token_out_type=tf.int32,
+        support_detokenization=False,
+        lower_case_nfd_strip_accents=do_lower_case,
+    )
+    self._seq_len = seq_len
+    self._cls_id = vocab.index("[CLS]")
+    self._sep_id = vocab.index("[SEP]")
+    self._pad_id = vocab.index("[PAD]")
+
+  def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
+    input_ids = self._tokenizer.tokenize(input_tensor).flat_values
+    input_ids = input_ids[: (self._seq_len - 2)]
+    input_ids = tf.concat(
+        [
+            tf.constant([self._cls_id]),
+            input_ids,
+            tf.constant([self._sep_id]),
+            tf.fill((self._seq_len,), self._pad_id),
+        ],
+        axis=0,
+    )
+    input_ids = input_ids[: self._seq_len]
+    input_type_ids = tf.zeros(self._seq_len, dtype=tf.int32)
+    input_mask = tf.cast(input_ids != self._pad_id, dtype=tf.int32)
+    return {
+        "input_word_ids": input_ids.numpy().tolist(),
+        "input_type_ids": input_type_ids.numpy().tolist(),
+        "input_mask": input_mask.numpy().tolist(),
+    }
diff --git a/mediapipe/model_maker/python/text/text_classifier/bert_tokenizer_test.py b/mediapipe/model_maker/python/text/text_classifier/bert_tokenizer_test.py
new file mode 100644
index 000000000..84139d071
--- /dev/null
+++ b/mediapipe/model_maker/python/text/text_classifier/bert_tokenizer_test.py
@@ -0,0 +1,91 @@
+# Copyright 2022 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import tempfile
+from unittest import mock as unittest_mock
+
+from absl.testing import parameterized
+import tensorflow as tf
+import tensorflow_hub
+
+from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
+from mediapipe.model_maker.python.text.text_classifier import model_spec
+
+
+class BertTokenizerTest(tf.test.TestCase, parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    # Mock tempfile.gettempdir() to be unique for each test to avoid race
+    # condition when downloading model since these tests may run in parallel.
+    mock_gettempdir = unittest_mock.patch.object(
+        tempfile,
+        'gettempdir',
+        return_value=self.create_tempdir(),
+        autospec=True,
+    )
+    self.mock_gettempdir = mock_gettempdir.start()
+    self.addCleanup(mock_gettempdir.stop)
+    ms = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
+    self._vocab_file = os.path.join(
+        tensorflow_hub.resolve(ms.get_path()), 'assets', 'vocab.txt'
+    )
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='fulltokenizer',
+          tokenizer_class=bert_tokenizer.BertFullTokenizer,
+      ),
+      dict(
+          testcase_name='fasttokenizer',
+          tokenizer_class=bert_tokenizer.BertFastTokenizer,
+      ),
+  )
+  def test_bert_tokenizer(self, tokenizer_class):
+    tokenizer = tokenizer_class(self._vocab_file, True, 16)
+    text_input = tf.constant(['this is an éxamplé input ¿foo'.encode('utf-8')])
+    result = tokenizer.process(text_input)
+    self.assertAllEqual(
+        result['input_word_ids'],
+        [
+            101,
+            2023,
+            2003,
+            2019,
+            2742,
+            7953,
+            1094,
+            29379,
+            102,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+        ],
+    )
+    self.assertAllEqual(
+        result['input_mask'], [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
+    )
+    self.assertAllEqual(
+        result['input_type_ids'],
+        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+    )
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/mediapipe/model_maker/python/text/text_classifier/hyperparameters.py b/mediapipe/model_maker/python/text/text_classifier/hyperparameters.py
index 71470edb3..5d16564f5 100644
--- a/mediapipe/model_maker/python/text/text_classifier/hyperparameters.py
+++ b/mediapipe/model_maker/python/text/text_classifier/hyperparameters.py
@@ -18,6 +18,7 @@ import enum
 from typing import Sequence, Union
 
 from mediapipe.model_maker.python.core import hyperparameters as hp
+from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
 
 
 @dataclasses.dataclass
@@ -53,6 +54,8 @@ class BertHParams(hp.BaseHParams):
     on recall. Only supported for binary classification.
     gamma: Gamma parameter for focal loss. To use cross entropy loss, set this
       value to 0. Defaults to 2.0.
+    tokenizer: Tokenizer to use for preprocessing. Must be one of the enum
+      options of SupportedBertTokenizers. Defaults to FULL_TOKENIZER.
   """
 
   learning_rate: float = 3e-5
@@ -68,5 +71,9 @@ class BertHParams(hp.BaseHParams):
 
   gamma: float = 2.0
 
+  tokenizer: bert_tokenizer.SupportedBertTokenizers = (
+      bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER
+  )
+
 
 HParams = Union[BertHParams, AverageWordEmbeddingHParams]
diff --git a/mediapipe/model_maker/python/text/text_classifier/preprocessor.py b/mediapipe/model_maker/python/text/text_classifier/preprocessor.py
index 68a5df2fd..5954f4ca3 100644
--- a/mediapipe/model_maker/python/text/text_classifier/preprocessor.py
+++ b/mediapipe/model_maker/python/text/text_classifier/preprocessor.py
@@ -24,8 +24,8 @@ import tensorflow as tf
 import tensorflow_hub
 
 from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
+from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
 from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
-from official.nlp.tools import tokenization
 
 
 def _validate_text_and_label(text: tf.Tensor, label: tf.Tensor) -> None:
@@ -68,9 +68,9 @@ def _decode_record(
     example[name] = tf.cast(example[name], tf.int32)
 
   bert_features = {
-      "input_word_ids": example["input_ids"],
+      "input_word_ids": example["input_word_ids"],
       "input_mask": example["input_mask"],
-      "input_type_ids": example["segment_ids"]
+      "input_type_ids": example["input_type_ids"],
   }
   return bert_features, example["label_ids"]
 
@@ -224,10 +224,17 @@ class BertClassifierPreprocessor:
    tokenizer: BERT tokenizer.
    model_name: Name of the model provided by the model_spec. Used to associate
      cached files with specific Bert model vocab.
+    tokenizer: Which tokenizer to use. Must be one of the enum values of
+      SupportedBertTokenizers.
  """

  def __init__(
-      self, seq_len: int, do_lower_case: bool, uri: str, model_name: str
+      self,
+      seq_len: int,
+      do_lower_case: bool,
+      uri: str,
+      model_name: str,
+      tokenizer: bert_tokenizer.SupportedBertTokenizers,
  ):
    self._seq_len = seq_len
    # Vocab filepath is tied to the BERT module's URI.
@@ -235,17 +242,27 @@ class BertClassifierPreprocessor:
        tensorflow_hub.resolve(uri), "assets", "vocab.txt"
    )
    self._do_lower_case = do_lower_case
-    self._tokenizer = tokenization.FullTokenizer(
-        self._vocab_file, self._do_lower_case
-    )
+    self._tokenizer: bert_tokenizer.BertTokenizer = None
+    if tokenizer == bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER:
+      self._tokenizer = bert_tokenizer.BertFullTokenizer(
+          self._vocab_file, self._do_lower_case, self._seq_len
+      )
+    elif (
+        tokenizer == bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER
+    ):
+      self._tokenizer = bert_tokenizer.BertFastTokenizer(
+          self._vocab_file, self._do_lower_case, self._seq_len
+      )
+    else:
+      raise ValueError(f"Unsupported tokenizer: {tokenizer}")
    self._model_name = model_name

  def _get_name_to_features(self):
    """Gets the dictionary mapping record keys to feature types."""
    return {
-        "input_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
+        "input_word_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
        "input_mask": tf.io.FixedLenFeature([self._seq_len], tf.int64),
-        "segment_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
+        "input_type_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
        "label_ids": tf.io.FixedLenFeature([], tf.int64),
    }
@@ -269,6 +286,7 @@ class BertClassifierPreprocessor:
      2. model_name
      3. seq_len
      4. do_lower_case
+      5. tokenizer name

    Args:
      ds_cache_files: TFRecordCacheFiles from the original raw dataset object
@@ -282,6 +300,7 @@ class BertClassifierPreprocessor:
    hasher.update(self._model_name.encode("utf-8"))
    hasher.update(str(self._seq_len).encode("utf-8"))
    hasher.update(str(self._do_lower_case).encode("utf-8"))
+    hasher.update(self._tokenizer.name.encode("utf-8"))
    cache_prefix_filename = hasher.hexdigest()
    return cache_files_lib.TFRecordCacheFiles(
        cache_prefix_filename,
@@ -289,23 +308,6 @@ class BertClassifierPreprocessor:
        ds_cache_files.num_shards,
    )

-  def _process_bert_features(self, text: str) -> Mapping[str, Sequence[int]]:
-    tokens = self._tokenizer.tokenize(text)
-    tokens = tokens[0 : (self._seq_len - 2)]  # account for [CLS] and [SEP]
-    tokens.insert(0, "[CLS]")
-    tokens.append("[SEP]")
-    input_ids = self._tokenizer.convert_tokens_to_ids(tokens)
-    input_mask = [1] * len(input_ids)
-    while len(input_ids) < self._seq_len:
-      input_ids.append(0)
-      input_mask.append(0)
-    segment_ids = [0] * self._seq_len
-    return {
-        "input_ids": input_ids,
-        "input_mask": input_mask,
-        "segment_ids": segment_ids,
-    }
-
  def preprocess(
      self, dataset: text_classifier_ds.Dataset
  ) -> text_classifier_ds.Dataset:
@@ -326,18 +328,20 @@ class BertClassifierPreprocessor:
      size = 0
      for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
        _validate_text_and_label(text, label)
-        feature = self._process_bert_features(text.numpy()[0].decode("utf-8"))
+        feature = self._tokenizer.process(text)

        def create_int_feature(values):
-          f = tf.train.Feature(
-              int64_list=tf.train.Int64List(value=list(values))
-          )
+          f = tf.train.Feature(int64_list=tf.train.Int64List(value=values))
          return f

        features = collections.OrderedDict()
-        features["input_ids"] = create_int_feature(feature["input_ids"])
+        features["input_word_ids"] = create_int_feature(
+            feature["input_word_ids"]
+        )
        features["input_mask"] = create_int_feature(feature["input_mask"])
-        features["segment_ids"] = create_int_feature(feature["segment_ids"])
-        features["label_ids"] = create_int_feature([label.numpy()[0]])
+        features["input_type_ids"] = create_int_feature(
+            feature["input_type_ids"]
+        )
+        features["label_ids"] = create_int_feature(label.numpy().tolist())
        tf_example = tf.train.Example(
            features=tf.train.Features(feature=features)
        )
diff --git a/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py b/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py
index ff9015498..2ed1e6101 100644
--- a/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py
+++ b/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py
@@ -18,18 +18,20 @@ import os
 import tempfile
 from unittest import mock as unittest_mock
 
+from absl.testing import parameterized
 import mock
 import numpy as np
 import numpy.testing as npt
 import tensorflow as tf
 
 from mediapipe.model_maker.python.core.data import cache_files
+from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
 from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
 from mediapipe.model_maker.python.text.text_classifier import model_spec
 from mediapipe.model_maker.python.text.text_classifier import preprocessor
 
 
-class PreprocessorTest(tf.test.TestCase):
+class PreprocessorTest(tf.test.TestCase, parameterized.TestCase):
 
   CSV_PARAMS_ = text_classifier_ds.CSVParameters(
       text_column='text', label_column='label')
@@ -83,7 +85,19 @@ class PreprocessorTest(tf.test.TestCase):
     npt.assert_array_equal(
         np.stack(features_list),
         np.array([[1, 3, 3, 3, 3], [1, 5, 6, 0, 0]]))
 
-  def test_bert_preprocessor(self):
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='fulltokenizer',
+          tokenizer=bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER,
+      ),
+      dict(
+          testcase_name='fastberttokenizer',
+          tokenizer=bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER,
+      ),
+  )
+  def test_bert_preprocessor(
+      self, tokenizer: bert_tokenizer.SupportedBertTokenizers
+  ):
     csv_file = self._get_csv_file()
     dataset = text_classifier_ds.Dataset.from_csv(
         filename=csv_file, csv_params=self.CSV_PARAMS_)
@@ -93,6 +107,7 @@ class PreprocessorTest(tf.test.TestCase):
         do_lower_case=bert_spec.do_lower_case,
         uri=bert_spec.get_path(),
         model_name=bert_spec.name,
+        tokenizer=tokenizer,
     )
     preprocessed_dataset = bert_preprocessor.preprocess(dataset)
     labels = []
@@ -122,11 +137,13 @@ class PreprocessorTest(tf.test.TestCase):
         cache_dir=self.get_temp_dir(),
     )
     bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
+    tokenizer = bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER
     bert_preprocessor = preprocessor.BertClassifierPreprocessor(
         seq_len=5,
         do_lower_case=bert_spec.do_lower_case,
         uri=bert_spec.get_path(),
         model_name=bert_spec.name,
+        tokenizer=tokenizer,
     )
     ds_cache_files = dataset.tfrecord_cache_files
     preprocessed_cache_files = bert_preprocessor._get_tfrecord_cache_files(
@@ -149,12 +166,13 @@ class PreprocessorTest(tf.test.TestCase):
         f' {preprocessed_cache_files.cache_prefix}\n',
     )
 
-  def _get_new_prefix(self, cf, bert_spec, seq_len, do_lower_case):
+  def _get_new_prefix(self, cf, bert_spec, seq_len, do_lower_case, tokenizer):
     bert_preprocessor = preprocessor.BertClassifierPreprocessor(
         seq_len=seq_len,
         do_lower_case=do_lower_case,
         uri=bert_spec.get_path(),
         model_name=bert_spec.name,
+        tokenizer=tokenizer,
     )
     new_cf = bert_preprocessor._get_tfrecord_cache_files(cf)
     return new_cf.cache_prefix_filename
@@ -167,19 +185,31 @@ class PreprocessorTest(tf.test.TestCase):
         cache_dir=self.get_temp_dir(),
         num_shards=1,
     )
+    tokenizer = bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER
     mobilebert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
-    all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, True))
-    all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 10, True))
-    all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, False))
+    all_cf_prefixes.add(
+        self._get_new_prefix(cf, mobilebert_spec, 5, True, tokenizer)
+    )
+    all_cf_prefixes.add(
+        self._get_new_prefix(cf, mobilebert_spec, 10, True, tokenizer)
+    )
+    all_cf_prefixes.add(
+        self._get_new_prefix(cf, mobilebert_spec, 5, False, tokenizer)
+    )
     new_cf = cache_files.TFRecordCacheFiles(
         cache_prefix_filename='new_cache_prefix',
         cache_dir=self.get_temp_dir(),
         num_shards=1,
     )
-    all_cf_prefixes.add(self._get_new_prefix(new_cf, mobilebert_spec, 5, True))
-
+    all_cf_prefixes.add(
+        self._get_new_prefix(new_cf, mobilebert_spec, 5, True, tokenizer)
+    )
+    new_tokenizer = bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER
+    all_cf_prefixes.add(
+        self._get_new_prefix(cf, mobilebert_spec, 5, True, new_tokenizer)
+    )
     # Each item of all_cf_prefixes should be unique.
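+    # The new tokenizer variant raises the count from four to five.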
-    self.assertLen(all_cf_prefixes, 4)
+    self.assertLen(all_cf_prefixes, 5)
 
 
 if __name__ == '__main__':
diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py
index 752752230..c067a4ed6 100644
--- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py
+++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py
@@ -437,6 +437,7 @@ class _BertClassifier(TextClassifier):
         do_lower_case=self._model_spec.do_lower_case,
         uri=self._model_spec.get_path(),
         model_name=self._model_spec.name,
+        tokenizer=self._hparams.tokenizer,
     )
     return (
         self._text_preprocessor.preprocess(train_data),
diff --git a/mediapipe/model_maker/requirements.txt b/mediapipe/model_maker/requirements.txt
index a1c975c1e..ff43fa3f0 100644
--- a/mediapipe/model_maker/requirements.txt
+++ b/mediapipe/model_maker/requirements.txt
@@ -6,4 +6,5 @@ tensorflow>=2.10
 tensorflow-addons
 tensorflow-datasets
 tensorflow-hub
+tensorflow-text
 tf-models-official>=2.13.1
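
Usage sketch for the new option (illustrative, not part of the diff itself): the BERT text classifier can now opt into the tensorflow_text FastBertTokenizer through the public aliases re-exported in __init__.py. Here train_data and validation_data stand for text_classifier.Dataset objects loaded elsewhere, and the epoch/batch values are placeholders:

    from mediapipe.model_maker import text_classifier

    # FULL_TOKENIZER remains the default, so existing callers are unaffected;
    # FAST_BERT_TOKENIZER swaps in the tensorflow_text-based implementation.
    hparams = text_classifier.BertHParams(
        epochs=3,  # illustrative value
        batch_size=32,  # illustrative value
        learning_rate=3e-5,
        tokenizer=text_classifier.SupportedBertTokenizers.FAST_BERT_TOKENIZER,
    )
    options = text_classifier.TextClassifierOptions(
        supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
        hparams=hparams,
    )
    # train_data / validation_data are text_classifier.Dataset objects,
    # e.g. loaded with text_classifier.Dataset.from_csv(...).
    model = text_classifier.TextClassifier.create(
        train_data, validation_data, options
    )

Because the tokenizer name is hashed into the TFRecord cache prefix (see _get_tfrecord_cache_files above), switching tokenizers triggers a fresh preprocessing pass rather than reusing a cache produced by the other tokenizer.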