No public description
PiperOrigin-RevId: 566435327
This commit is contained in:
parent
36f78f6e4a
commit
94cda40a83
|
@ -30,10 +30,12 @@ py_library(
|
|||
srcs = ["__init__.py"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":bert_tokenizer",
|
||||
":dataset",
|
||||
":hyperparameters",
|
||||
":model_options",
|
||||
":model_spec",
|
||||
":preprocessor",
|
||||
":text_classifier",
|
||||
":text_classifier_options",
|
||||
],
|
||||
|
@ -48,7 +50,10 @@ py_library(
|
|||
py_library(
|
||||
name = "hyperparameters",
|
||||
srcs = ["hyperparameters.py"],
|
||||
deps = ["//mediapipe/model_maker/python/core:hyperparameters"],
|
||||
deps = [
|
||||
":bert_tokenizer",
|
||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
|
@ -88,10 +93,26 @@ py_test(
|
|||
deps = [":dataset"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "bert_tokenizer",
|
||||
srcs = ["bert_tokenizer.py"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "bert_tokenizer_test",
|
||||
srcs = ["bert_tokenizer_test.py"],
|
||||
tags = ["requires-net:external"],
|
||||
deps = [
|
||||
":bert_tokenizer",
|
||||
":model_spec",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "preprocessor",
|
||||
srcs = ["preprocessor.py"],
|
||||
deps = [
|
||||
":bert_tokenizer",
|
||||
":dataset",
|
||||
"//mediapipe/model_maker/python/core/data:cache_files",
|
||||
],
|
||||
|
@ -102,6 +123,7 @@ py_test(
|
|||
srcs = ["preprocessor_test.py"],
|
||||
tags = ["requires-net:external"],
|
||||
deps = [
|
||||
":bert_tokenizer",
|
||||
":dataset",
|
||||
":model_spec",
|
||||
":preprocessor",
|
||||
|
|
|
@ -13,10 +13,12 @@
|
|||
# limitations under the License.
|
||||
"""MediaPipe Public Python API for Text Classifier."""
|
||||
|
||||
from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
|
||||
from mediapipe.model_maker.python.text.text_classifier import dataset
|
||||
from mediapipe.model_maker.python.text.text_classifier import hyperparameters
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_options
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
||||
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier
|
||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
||||
|
||||
|
@ -33,12 +35,14 @@ Dataset = dataset.Dataset
|
|||
SupportedModels = model_spec.SupportedModels
|
||||
TextClassifier = text_classifier.TextClassifier
|
||||
TextClassifierOptions = text_classifier_options.TextClassifierOptions
|
||||
SupportedBertTokenizers = bert_tokenizer.SupportedBertTokenizers
|
||||
|
||||
# Remove duplicated and non-public API
|
||||
del bert_tokenizer
|
||||
del hyperparameters
|
||||
del dataset
|
||||
del model_options
|
||||
del model_spec
|
||||
del preprocessor # pylint: disable=undefined-variable
|
||||
del preprocessor
|
||||
del text_classifier
|
||||
del text_classifier_options
|
||||
|
|
|
@ -0,0 +1,118 @@
|
|||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Text classifier BERT tokenizer library."""
|
||||
import abc
|
||||
import enum
|
||||
from typing import Mapping, Sequence
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow_text as tf_text
|
||||
|
||||
from official.nlp.tools import tokenization
|
||||
|
||||
|
||||
@enum.unique
|
||||
class SupportedBertTokenizers(enum.Enum):
|
||||
"""Supported preprocessors."""
|
||||
|
||||
FULL_TOKENIZER = "fulltokenizer"
|
||||
FAST_BERT_TOKENIZER = "fastberttokenizer"
|
||||
|
||||
|
||||
class BertTokenizer(abc.ABC):
|
||||
"""Abstract BertTokenizer class."""
|
||||
|
||||
name: str
|
||||
|
||||
@abc.abstractmethod
|
||||
def __init__(self, vocab_file: str, do_lower_case: bool, seq_len: int):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
|
||||
pass
|
||||
|
||||
|
||||
class BertFullTokenizer(BertTokenizer):
|
||||
"""Tokenizer using the FullTokenizer from tensorflow_models."""
|
||||
|
||||
name = "fulltokenizer"
|
||||
|
||||
def __init__(self, vocab_file: str, do_lower_case: bool, seq_len: int):
|
||||
self._tokenizer = tokenization.FullTokenizer(
|
||||
vocab_file=vocab_file, do_lower_case=do_lower_case
|
||||
)
|
||||
self._seq_len = seq_len
|
||||
|
||||
def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
|
||||
tokens = self._tokenizer.tokenize(input_tensor.numpy()[0].decode("utf-8"))
|
||||
tokens = tokens[0 : (self._seq_len - 2)] # account for [CLS] and [SEP]
|
||||
tokens.insert(0, "[CLS]")
|
||||
tokens.append("[SEP]")
|
||||
input_ids = self._tokenizer.convert_tokens_to_ids(tokens)
|
||||
input_mask = [1] * len(input_ids)
|
||||
while len(input_ids) < self._seq_len:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids = [0] * self._seq_len
|
||||
return {
|
||||
"input_word_ids": input_ids,
|
||||
"input_type_ids": segment_ids,
|
||||
"input_mask": input_mask,
|
||||
}
|
||||
|
||||
|
||||
class BertFastTokenizer(BertTokenizer):
|
||||
"""Tokenizer using the FastBertTokenizer from tensorflow_text.
|
||||
|
||||
For more information, see:
|
||||
https://www.tensorflow.org/text/api_docs/python/text/FastBertTokenizer
|
||||
"""
|
||||
|
||||
name = "fastberttokenizer"
|
||||
|
||||
def __init__(self, vocab_file: str, do_lower_case: bool, seq_len: int):
|
||||
with tf.io.gfile.GFile(vocab_file, "r") as f:
|
||||
vocab = f.read().splitlines()
|
||||
self._tokenizer = tf_text.FastBertTokenizer(
|
||||
vocab=vocab,
|
||||
token_out_type=tf.int32,
|
||||
support_detokenization=False,
|
||||
lower_case_nfd_strip_accents=do_lower_case,
|
||||
)
|
||||
self._seq_len = seq_len
|
||||
self._cls_id = vocab.index("[CLS]")
|
||||
self._sep_id = vocab.index("[SEP]")
|
||||
self._pad_id = vocab.index("[PAD]")
|
||||
|
||||
def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
|
||||
input_ids = self._tokenizer.tokenize(input_tensor).flat_values
|
||||
input_ids = input_ids[: (self._seq_len - 2)]
|
||||
input_ids = tf.concat(
|
||||
[
|
||||
tf.constant([self._cls_id]),
|
||||
input_ids,
|
||||
tf.constant([self._sep_id]),
|
||||
tf.fill((self._seq_len,), self._pad_id),
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
input_ids = input_ids[: self._seq_len]
|
||||
input_type_ids = tf.zeros(self._seq_len, dtype=tf.int32)
|
||||
input_mask = tf.cast(input_ids != self._pad_id, dtype=tf.int32)
|
||||
return {
|
||||
"input_word_ids": input_ids.numpy().tolist(),
|
||||
"input_type_ids": input_type_ids.numpy().tolist(),
|
||||
"input_mask": input_mask.numpy().tolist(),
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
# Copyright 2022 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from unittest import mock as unittest_mock
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
import tensorflow_hub
|
||||
|
||||
from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
||||
|
||||
|
||||
class BertTokenizerTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||
# condition when downloading model since these tests may run in parallel.
|
||||
mock_gettempdir = unittest_mock.patch.object(
|
||||
tempfile,
|
||||
'gettempdir',
|
||||
return_value=self.create_tempdir(),
|
||||
autospec=True,
|
||||
)
|
||||
self.mock_gettempdir = mock_gettempdir.start()
|
||||
self.addCleanup(mock_gettempdir.stop)
|
||||
ms = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||
self._vocab_file = os.path.join(
|
||||
tensorflow_hub.resolve(ms.get_path()), 'assets', 'vocab.txt'
|
||||
)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name='fulltokenizer',
|
||||
tokenizer_class=bert_tokenizer.BertFullTokenizer,
|
||||
),
|
||||
dict(
|
||||
testcase_name='fasttokenizer',
|
||||
tokenizer_class=bert_tokenizer.BertFastTokenizer,
|
||||
),
|
||||
)
|
||||
def test_bert_full_tokenizer(self, tokenizer_class):
|
||||
tokenizer = tokenizer_class(self._vocab_file, True, 16)
|
||||
text_input = tf.constant(['this is an éxamplé input ¿foo'.encode('utf-8')])
|
||||
result = tokenizer.process(text_input)
|
||||
self.assertAllEqual(
|
||||
result['input_word_ids'],
|
||||
[
|
||||
101,
|
||||
2023,
|
||||
2003,
|
||||
2019,
|
||||
2742,
|
||||
7953,
|
||||
1094,
|
||||
29379,
|
||||
102,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
],
|
||||
)
|
||||
self.assertAllEqual(
|
||||
result['input_mask'], [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
|
||||
)
|
||||
self.assertAllEqual(
|
||||
result['input_type_ids'],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -18,6 +18,7 @@ import enum
|
|||
from typing import Sequence, Union
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -53,6 +54,8 @@ class BertHParams(hp.BaseHParams):
|
|||
on recall. Only supported for binary classification.
|
||||
gamma: Gamma parameter for focal loss. To use cross entropy loss, set this
|
||||
value to 0. Defaults to 2.0.
|
||||
tokenizer: Tokenizer to use for preprocessing. Must be one of the enum
|
||||
options of SupportedBertTokenizers. Defaults to FULL_TOKENIZER.
|
||||
"""
|
||||
|
||||
learning_rate: float = 3e-5
|
||||
|
@ -68,5 +71,9 @@ class BertHParams(hp.BaseHParams):
|
|||
|
||||
gamma: float = 2.0
|
||||
|
||||
tokenizer: bert_tokenizer.SupportedBertTokenizers = (
|
||||
bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER
|
||||
)
|
||||
|
||||
|
||||
HParams = Union[BertHParams, AverageWordEmbeddingHParams]
|
||||
|
|
|
@ -24,8 +24,8 @@ import tensorflow as tf
|
|||
import tensorflow_hub
|
||||
|
||||
from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
|
||||
from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
|
||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
|
||||
from official.nlp.tools import tokenization
|
||||
|
||||
|
||||
def _validate_text_and_label(text: tf.Tensor, label: tf.Tensor) -> None:
|
||||
|
@ -68,9 +68,9 @@ def _decode_record(
|
|||
example[name] = tf.cast(example[name], tf.int32)
|
||||
|
||||
bert_features = {
|
||||
"input_word_ids": example["input_ids"],
|
||||
"input_word_ids": example["input_word_ids"],
|
||||
"input_mask": example["input_mask"],
|
||||
"input_type_ids": example["segment_ids"]
|
||||
"input_type_ids": example["input_type_ids"],
|
||||
}
|
||||
return bert_features, example["label_ids"]
|
||||
|
||||
|
@ -224,10 +224,17 @@ class BertClassifierPreprocessor:
|
|||
tokenizer: BERT tokenizer.
|
||||
model_name: Name of the model provided by the model_spec. Used to associate
|
||||
cached files with specific Bert model vocab.
|
||||
preprocessor: Which preprocessor to use. Must be one of the enum values of
|
||||
SupportedBertPreprocessors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, seq_len: int, do_lower_case: bool, uri: str, model_name: str
|
||||
self,
|
||||
seq_len: int,
|
||||
do_lower_case: bool,
|
||||
uri: str,
|
||||
model_name: str,
|
||||
tokenizer: bert_tokenizer.SupportedBertTokenizers,
|
||||
):
|
||||
self._seq_len = seq_len
|
||||
# Vocab filepath is tied to the BERT module's URI.
|
||||
|
@ -235,17 +242,27 @@ class BertClassifierPreprocessor:
|
|||
tensorflow_hub.resolve(uri), "assets", "vocab.txt"
|
||||
)
|
||||
self._do_lower_case = do_lower_case
|
||||
self._tokenizer = tokenization.FullTokenizer(
|
||||
self._vocab_file, self._do_lower_case
|
||||
self._tokenizer: bert_tokenizer.BertTokenizer = None
|
||||
if tokenizer == bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER:
|
||||
self._tokenizer = bert_tokenizer.BertFullTokenizer(
|
||||
self._vocab_file, self._do_lower_case, self._seq_len
|
||||
)
|
||||
elif (
|
||||
tokenizer == bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER
|
||||
):
|
||||
self._tokenizer = bert_tokenizer.BertFastTokenizer(
|
||||
self._vocab_file, self._do_lower_case, self._seq_len
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported tokenizer: {tokenizer}")
|
||||
self._model_name = model_name
|
||||
|
||||
def _get_name_to_features(self):
|
||||
"""Gets the dictionary mapping record keys to feature types."""
|
||||
return {
|
||||
"input_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
|
||||
"input_word_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
|
||||
"input_mask": tf.io.FixedLenFeature([self._seq_len], tf.int64),
|
||||
"segment_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
|
||||
"input_type_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
|
||||
"label_ids": tf.io.FixedLenFeature([], tf.int64),
|
||||
}
|
||||
|
||||
|
@ -269,6 +286,7 @@ class BertClassifierPreprocessor:
|
|||
2. model_name
|
||||
3. seq_len
|
||||
4. do_lower_case
|
||||
5. tokenizer name
|
||||
|
||||
Args:
|
||||
ds_cache_files: TFRecordCacheFiles from the original raw dataset object
|
||||
|
@ -282,6 +300,7 @@ class BertClassifierPreprocessor:
|
|||
hasher.update(self._model_name.encode("utf-8"))
|
||||
hasher.update(str(self._seq_len).encode("utf-8"))
|
||||
hasher.update(str(self._do_lower_case).encode("utf-8"))
|
||||
hasher.update(self._tokenizer.name.encode("utf-8"))
|
||||
cache_prefix_filename = hasher.hexdigest()
|
||||
return cache_files_lib.TFRecordCacheFiles(
|
||||
cache_prefix_filename,
|
||||
|
@ -289,23 +308,6 @@ class BertClassifierPreprocessor:
|
|||
ds_cache_files.num_shards,
|
||||
)
|
||||
|
||||
def _process_bert_features(self, text: str) -> Mapping[str, Sequence[int]]:
|
||||
tokens = self._tokenizer.tokenize(text)
|
||||
tokens = tokens[0 : (self._seq_len - 2)] # account for [CLS] and [SEP]
|
||||
tokens.insert(0, "[CLS]")
|
||||
tokens.append("[SEP]")
|
||||
input_ids = self._tokenizer.convert_tokens_to_ids(tokens)
|
||||
input_mask = [1] * len(input_ids)
|
||||
while len(input_ids) < self._seq_len:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids = [0] * self._seq_len
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"input_mask": input_mask,
|
||||
"segment_ids": segment_ids,
|
||||
}
|
||||
|
||||
def preprocess(
|
||||
self, dataset: text_classifier_ds.Dataset
|
||||
) -> text_classifier_ds.Dataset:
|
||||
|
@ -326,18 +328,20 @@ class BertClassifierPreprocessor:
|
|||
size = 0
|
||||
for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
|
||||
_validate_text_and_label(text, label)
|
||||
feature = self._process_bert_features(text.numpy()[0].decode("utf-8"))
|
||||
feature = self._tokenizer.process(text)
|
||||
def create_int_feature(values):
|
||||
f = tf.train.Feature(
|
||||
int64_list=tf.train.Int64List(value=list(values))
|
||||
)
|
||||
f = tf.train.Feature(int64_list=tf.train.Int64List(value=values))
|
||||
return f
|
||||
|
||||
features = collections.OrderedDict()
|
||||
features["input_ids"] = create_int_feature(feature["input_ids"])
|
||||
features["input_word_ids"] = create_int_feature(
|
||||
feature["input_word_ids"]
|
||||
)
|
||||
features["input_mask"] = create_int_feature(feature["input_mask"])
|
||||
features["segment_ids"] = create_int_feature(feature["segment_ids"])
|
||||
features["label_ids"] = create_int_feature([label.numpy()[0]])
|
||||
features["input_type_ids"] = create_int_feature(
|
||||
feature["input_type_ids"]
|
||||
)
|
||||
features["label_ids"] = create_int_feature(label.numpy().tolist())
|
||||
tf_example = tf.train.Example(
|
||||
features=tf.train.Features(feature=features)
|
||||
)
|
||||
|
|
|
@ -18,18 +18,20 @@ import os
|
|||
import tempfile
|
||||
from unittest import mock as unittest_mock
|
||||
|
||||
from absl.testing import parameterized
|
||||
import mock
|
||||
import numpy as np
|
||||
import numpy.testing as npt
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import cache_files
|
||||
from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
|
||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
||||
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
||||
|
||||
|
||||
class PreprocessorTest(tf.test.TestCase):
|
||||
class PreprocessorTest(tf.test.TestCase, parameterized.TestCase):
|
||||
CSV_PARAMS_ = text_classifier_ds.CSVParameters(
|
||||
text_column='text', label_column='label')
|
||||
|
||||
|
@ -83,7 +85,19 @@ class PreprocessorTest(tf.test.TestCase):
|
|||
npt.assert_array_equal(
|
||||
np.stack(features_list), np.array([[1, 3, 3, 3, 3], [1, 5, 6, 0, 0]]))
|
||||
|
||||
def test_bert_preprocessor(self):
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name='fulltokenizer',
|
||||
tokenizer=bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER,
|
||||
),
|
||||
dict(
|
||||
testcase_name='fastberttokenizer',
|
||||
tokenizer=bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER,
|
||||
),
|
||||
)
|
||||
def test_bert_preprocessor(
|
||||
self, tokenizer: bert_tokenizer.SupportedBertTokenizers
|
||||
):
|
||||
csv_file = self._get_csv_file()
|
||||
dataset = text_classifier_ds.Dataset.from_csv(
|
||||
filename=csv_file, csv_params=self.CSV_PARAMS_)
|
||||
|
@ -93,6 +107,7 @@ class PreprocessorTest(tf.test.TestCase):
|
|||
do_lower_case=bert_spec.do_lower_case,
|
||||
uri=bert_spec.get_path(),
|
||||
model_name=bert_spec.name,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
|
||||
labels = []
|
||||
|
@ -122,11 +137,13 @@ class PreprocessorTest(tf.test.TestCase):
|
|||
cache_dir=self.get_temp_dir(),
|
||||
)
|
||||
bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||
tokenizer = bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER
|
||||
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
|
||||
seq_len=5,
|
||||
do_lower_case=bert_spec.do_lower_case,
|
||||
uri=bert_spec.get_path(),
|
||||
model_name=bert_spec.name,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
ds_cache_files = dataset.tfrecord_cache_files
|
||||
preprocessed_cache_files = bert_preprocessor._get_tfrecord_cache_files(
|
||||
|
@ -149,12 +166,13 @@ class PreprocessorTest(tf.test.TestCase):
|
|||
f' {preprocessed_cache_files.cache_prefix}\n',
|
||||
)
|
||||
|
||||
def _get_new_prefix(self, cf, bert_spec, seq_len, do_lower_case):
|
||||
def _get_new_prefix(self, cf, bert_spec, seq_len, do_lower_case, tokenizer):
|
||||
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
|
||||
seq_len=seq_len,
|
||||
do_lower_case=do_lower_case,
|
||||
uri=bert_spec.get_path(),
|
||||
model_name=bert_spec.name,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
new_cf = bert_preprocessor._get_tfrecord_cache_files(cf)
|
||||
return new_cf.cache_prefix_filename
|
||||
|
@ -167,19 +185,31 @@ class PreprocessorTest(tf.test.TestCase):
|
|||
cache_dir=self.get_temp_dir(),
|
||||
num_shards=1,
|
||||
)
|
||||
tokenizer = bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER
|
||||
mobilebert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, True))
|
||||
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 10, True))
|
||||
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, False))
|
||||
all_cf_prefixes.add(
|
||||
self._get_new_prefix(cf, mobilebert_spec, 5, True, tokenizer)
|
||||
)
|
||||
all_cf_prefixes.add(
|
||||
self._get_new_prefix(cf, mobilebert_spec, 10, True, tokenizer)
|
||||
)
|
||||
all_cf_prefixes.add(
|
||||
self._get_new_prefix(cf, mobilebert_spec, 5, False, tokenizer)
|
||||
)
|
||||
new_cf = cache_files.TFRecordCacheFiles(
|
||||
cache_prefix_filename='new_cache_prefix',
|
||||
cache_dir=self.get_temp_dir(),
|
||||
num_shards=1,
|
||||
)
|
||||
all_cf_prefixes.add(self._get_new_prefix(new_cf, mobilebert_spec, 5, True))
|
||||
|
||||
all_cf_prefixes.add(
|
||||
self._get_new_prefix(new_cf, mobilebert_spec, 5, True, tokenizer)
|
||||
)
|
||||
new_tokenizer = bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER
|
||||
all_cf_prefixes.add(
|
||||
self._get_new_prefix(cf, mobilebert_spec, 5, True, new_tokenizer)
|
||||
)
|
||||
# Each item of all_cf_prefixes should be unique.
|
||||
self.assertLen(all_cf_prefixes, 4)
|
||||
self.assertLen(all_cf_prefixes, 5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -437,6 +437,7 @@ class _BertClassifier(TextClassifier):
|
|||
do_lower_case=self._model_spec.do_lower_case,
|
||||
uri=self._model_spec.get_path(),
|
||||
model_name=self._model_spec.name,
|
||||
tokenizer=self._hparams.tokenizer,
|
||||
)
|
||||
return (
|
||||
self._text_preprocessor.preprocess(train_data),
|
||||
|
|
|
@ -6,4 +6,5 @@ tensorflow>=2.10
|
|||
tensorflow-addons
|
||||
tensorflow-datasets
|
||||
tensorflow-hub
|
||||
tensorflow-text
|
||||
tf-models-official>=2.13.1
|
||||
|
|
Loading…
Reference in New Issue
Block a user