No public description

PiperOrigin-RevId: 566435327
MediaPipe Team authored 2023-09-18 15:47:29 -07:00, committed by Copybara-Service
commit 94cda40a83 (parent 36f78f6e4a)
9 changed files with 322 additions and 44 deletions

File: mediapipe/model_maker/python/text/text_classifier/BUILD

@@ -30,10 +30,12 @@ py_library(
     srcs = ["__init__.py"],
     visibility = ["//visibility:public"],
     deps = [
+        ":bert_tokenizer",
         ":dataset",
         ":hyperparameters",
         ":model_options",
         ":model_spec",
+        ":preprocessor",
         ":text_classifier",
         ":text_classifier_options",
     ],
@@ -48,7 +50,10 @@ py_library(
 py_library(
     name = "hyperparameters",
     srcs = ["hyperparameters.py"],
-    deps = ["//mediapipe/model_maker/python/core:hyperparameters"],
+    deps = [
+        ":bert_tokenizer",
+        "//mediapipe/model_maker/python/core:hyperparameters",
+    ],
 )
 
 py_library(
@@ -88,10 +93,26 @@ py_test(
     deps = [":dataset"],
 )
 
+py_library(
+    name = "bert_tokenizer",
+    srcs = ["bert_tokenizer.py"],
+)
+
+py_test(
+    name = "bert_tokenizer_test",
+    srcs = ["bert_tokenizer_test.py"],
+    tags = ["requires-net:external"],
+    deps = [
+        ":bert_tokenizer",
+        ":model_spec",
+    ],
+)
+
 py_library(
     name = "preprocessor",
     srcs = ["preprocessor.py"],
     deps = [
+        ":bert_tokenizer",
         ":dataset",
         "//mediapipe/model_maker/python/core/data:cache_files",
     ],
@@ -102,6 +123,7 @@ py_test(
     srcs = ["preprocessor_test.py"],
     tags = ["requires-net:external"],
     deps = [
+        ":bert_tokenizer",
        ":dataset",
        ":model_spec",
        ":preprocessor",

File: mediapipe/model_maker/python/text/text_classifier/__init__.py

@@ -13,10 +13,12 @@
 # limitations under the License.
 
 """MediaPipe Public Python API for Text Classifier."""
 
+from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
 from mediapipe.model_maker.python.text.text_classifier import dataset
 from mediapipe.model_maker.python.text.text_classifier import hyperparameters
 from mediapipe.model_maker.python.text.text_classifier import model_options
 from mediapipe.model_maker.python.text.text_classifier import model_spec
+from mediapipe.model_maker.python.text.text_classifier import preprocessor
 from mediapipe.model_maker.python.text.text_classifier import text_classifier
 from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
@@ -33,12 +35,14 @@ Dataset = dataset.Dataset
 SupportedModels = model_spec.SupportedModels
 TextClassifier = text_classifier.TextClassifier
 TextClassifierOptions = text_classifier_options.TextClassifierOptions
+SupportedBertTokenizers = bert_tokenizer.SupportedBertTokenizers
 
 # Remove duplicated and non-public API
+del bert_tokenizer
 del hyperparameters
 del dataset
 del model_options
 del model_spec
-del preprocessor  # pylint: disable=undefined-variable
+del preprocessor
 del text_classifier
 del text_classifier_options

File: mediapipe/model_maker/python/text/text_classifier/bert_tokenizer.py (new file)

@@ -0,0 +1,118 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text classifier BERT tokenizer library."""

import abc
import enum
from typing import Mapping, Sequence

import tensorflow as tf
import tensorflow_text as tf_text

from official.nlp.tools import tokenization


@enum.unique
class SupportedBertTokenizers(enum.Enum):
  """Supported preprocessors."""

  FULL_TOKENIZER = "fulltokenizer"
  FAST_BERT_TOKENIZER = "fastberttokenizer"


class BertTokenizer(abc.ABC):
  """Abstract BertTokenizer class."""

  name: str

  @abc.abstractmethod
  def __init__(self, vocab_file: str, do_lower_case: bool, seq_len: int):
    pass

  @abc.abstractmethod
  def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
    pass


class BertFullTokenizer(BertTokenizer):
  """Tokenizer using the FullTokenizer from tensorflow_models."""

  name = "fulltokenizer"

  def __init__(self, vocab_file: str, do_lower_case: bool, seq_len: int):
    self._tokenizer = tokenization.FullTokenizer(
        vocab_file=vocab_file, do_lower_case=do_lower_case
    )
    self._seq_len = seq_len

  def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
    tokens = self._tokenizer.tokenize(input_tensor.numpy()[0].decode("utf-8"))
    tokens = tokens[0 : (self._seq_len - 2)]  # account for [CLS] and [SEP]
    tokens.insert(0, "[CLS]")
    tokens.append("[SEP]")
    input_ids = self._tokenizer.convert_tokens_to_ids(tokens)
    input_mask = [1] * len(input_ids)
    while len(input_ids) < self._seq_len:
      input_ids.append(0)
      input_mask.append(0)
    segment_ids = [0] * self._seq_len
    return {
        "input_word_ids": input_ids,
        "input_type_ids": segment_ids,
        "input_mask": input_mask,
    }


class BertFastTokenizer(BertTokenizer):
  """Tokenizer using the FastBertTokenizer from tensorflow_text.

  For more information, see:
  https://www.tensorflow.org/text/api_docs/python/text/FastBertTokenizer
  """

  name = "fastberttokenizer"

  def __init__(self, vocab_file: str, do_lower_case: bool, seq_len: int):
    with tf.io.gfile.GFile(vocab_file, "r") as f:
      vocab = f.read().splitlines()
    self._tokenizer = tf_text.FastBertTokenizer(
        vocab=vocab,
        token_out_type=tf.int32,
        support_detokenization=False,
        lower_case_nfd_strip_accents=do_lower_case,
    )
    self._seq_len = seq_len
    self._cls_id = vocab.index("[CLS]")
    self._sep_id = vocab.index("[SEP]")
    self._pad_id = vocab.index("[PAD]")

  def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
    input_ids = self._tokenizer.tokenize(input_tensor).flat_values
    input_ids = input_ids[: (self._seq_len - 2)]
    input_ids = tf.concat(
        [
            tf.constant([self._cls_id]),
            input_ids,
            tf.constant([self._sep_id]),
            tf.fill((self._seq_len,), self._pad_id),
        ],
        axis=0,
    )
    input_ids = input_ids[: self._seq_len]
    input_type_ids = tf.zeros(self._seq_len, dtype=tf.int32)
    input_mask = tf.cast(input_ids != self._pad_id, dtype=tf.int32)
    return {
        "input_word_ids": input_ids.numpy().tolist(),
        "input_type_ids": input_type_ids.numpy().tolist(),
        "input_mask": input_mask.numpy().tolist(),
    }
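
To make the new API concrete, here is a small usage sketch that is not part of the commit. It assumes a BERT `vocab.txt` is already available locally (in practice it is resolved from the BERT model's assets, as the test below does); the path used here is purely illustrative.

import tensorflow as tf

from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer

# Illustrative path; normally this is the vocab.txt bundled with the BERT
# model, e.g. under tensorflow_hub.resolve(<model uri>)/assets/vocab.txt.
vocab_file = "/tmp/mobilebert/assets/vocab.txt"

# Both tokenizers share the same constructor signature and emit the same keys.
fast = bert_tokenizer.BertFastTokenizer(vocab_file, do_lower_case=True, seq_len=16)
full = bert_tokenizer.BertFullTokenizer(vocab_file, do_lower_case=True, seq_len=16)

# process() expects a batch-of-one string tensor, matching how the
# preprocessor feeds text from the dataset.
text = tf.constant(["this is an example input".encode("utf-8")])
features = fast.process(text)
# features contains "input_word_ids", "input_type_ids" and "input_mask",
# each a Python list of length seq_len.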

File: mediapipe/model_maker/python/text/text_classifier/bert_tokenizer_test.py (new file)

@@ -0,0 +1,91 @@
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
from unittest import mock as unittest_mock

from absl.testing import parameterized
import tensorflow as tf
import tensorflow_hub

from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
from mediapipe.model_maker.python.text.text_classifier import model_spec


class BertTokenizerTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    # Mock tempfile.gettempdir() to be unique for each test to avoid race
    # condition when downloading model since these tests may run in parallel.
    mock_gettempdir = unittest_mock.patch.object(
        tempfile,
        'gettempdir',
        return_value=self.create_tempdir(),
        autospec=True,
    )
    self.mock_gettempdir = mock_gettempdir.start()
    self.addCleanup(mock_gettempdir.stop)
    ms = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
    self._vocab_file = os.path.join(
        tensorflow_hub.resolve(ms.get_path()), 'assets', 'vocab.txt'
    )

  @parameterized.named_parameters(
      dict(
          testcase_name='fulltokenizer',
          tokenizer_class=bert_tokenizer.BertFullTokenizer,
      ),
      dict(
          testcase_name='fasttokenizer',
          tokenizer_class=bert_tokenizer.BertFastTokenizer,
      ),
  )
  def test_bert_full_tokenizer(self, tokenizer_class):
    tokenizer = tokenizer_class(self._vocab_file, True, 16)
    text_input = tf.constant(['this is an éxamplé input ¿foo'.encode('utf-8')])
    result = tokenizer.process(text_input)
    self.assertAllEqual(
        result['input_word_ids'],
        [
            101,
            2023,
            2003,
            2019,
            2742,
            7953,
            1094,
            29379,
            102,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
        ],
    )
    self.assertAllEqual(
        result['input_mask'], [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
    )
    self.assertAllEqual(
        result['input_type_ids'],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    )


if __name__ == '__main__':
  tf.test.main()

File: mediapipe/model_maker/python/text/text_classifier/hyperparameters.py

@@ -18,6 +18,7 @@ import enum
 from typing import Sequence, Union
 
 from mediapipe.model_maker.python.core import hyperparameters as hp
+from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
 
 
 @dataclasses.dataclass
@@ -53,6 +54,8 @@ class BertHParams(hp.BaseHParams):
       on recall. Only supported for binary classification.
     gamma: Gamma parameter for focal loss. To use cross entropy loss, set this
       value to 0. Defaults to 2.0.
+    tokenizer: Tokenizer to use for preprocessing. Must be one of the enum
+      options of SupportedBertTokenizers. Defaults to FULL_TOKENIZER.
   """
 
   learning_rate: float = 3e-5
@@ -68,5 +71,9 @@ class BertHParams(hp.BaseHParams):
   gamma: float = 2.0
 
+  tokenizer: bert_tokenizer.SupportedBertTokenizers = (
+      bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER
+  )
+
 
 HParams = Union[BertHParams, AverageWordEmbeddingHParams]
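
As a usage sketch (not part of the diff): with the new field, opting into the fast tokenizer is a one-line change when building the hyperparameters. The numeric values below are illustrative, not defaults mandated by this commit.

from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
from mediapipe.model_maker.python.text.text_classifier import hyperparameters

hparams = hyperparameters.BertHParams(
    learning_rate=3e-5,
    epochs=3,
    batch_size=48,
    # New in this change; FULL_TOKENIZER is used when this is omitted.
    tokenizer=bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER,
)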

File: mediapipe/model_maker/python/text/text_classifier/preprocessor.py

@@ -24,8 +24,8 @@ import tensorflow as tf
 import tensorflow_hub
 
 from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
+from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
 from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
-from official.nlp.tools import tokenization
 
 
 def _validate_text_and_label(text: tf.Tensor, label: tf.Tensor) -> None:
@@ -68,9 +68,9 @@ def _decode_record(
     example[name] = tf.cast(example[name], tf.int32)
 
   bert_features = {
-      "input_word_ids": example["input_ids"],
+      "input_word_ids": example["input_word_ids"],
       "input_mask": example["input_mask"],
-      "input_type_ids": example["segment_ids"]
+      "input_type_ids": example["input_type_ids"],
   }
   return bert_features, example["label_ids"]
 
@@ -224,10 +224,17 @@ class BertClassifierPreprocessor:
     tokenizer: BERT tokenizer.
     model_name: Name of the model provided by the model_spec. Used to associate
       cached files with specific Bert model vocab.
+    preprocessor: Which preprocessor to use. Must be one of the enum values of
+      SupportedBertPreprocessors.
   """
 
   def __init__(
-      self, seq_len: int, do_lower_case: bool, uri: str, model_name: str
+      self,
+      seq_len: int,
+      do_lower_case: bool,
+      uri: str,
+      model_name: str,
+      tokenizer: bert_tokenizer.SupportedBertTokenizers,
   ):
     self._seq_len = seq_len
     # Vocab filepath is tied to the BERT module's URI.
@@ -235,17 +242,27 @@ class BertClassifierPreprocessor:
         tensorflow_hub.resolve(uri), "assets", "vocab.txt"
     )
     self._do_lower_case = do_lower_case
-    self._tokenizer = tokenization.FullTokenizer(
-        self._vocab_file, self._do_lower_case
-    )
+    self._tokenizer: bert_tokenizer.BertTokenizer = None
+    if tokenizer == bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER:
+      self._tokenizer = bert_tokenizer.BertFullTokenizer(
+          self._vocab_file, self._do_lower_case, self._seq_len
+      )
+    elif (
+        tokenizer == bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER
+    ):
+      self._tokenizer = bert_tokenizer.BertFastTokenizer(
+          self._vocab_file, self._do_lower_case, self._seq_len
+      )
+    else:
+      raise ValueError(f"Unsupported tokenizer: {tokenizer}")
     self._model_name = model_name
 
   def _get_name_to_features(self):
     """Gets the dictionary mapping record keys to feature types."""
     return {
-        "input_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
+        "input_word_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
         "input_mask": tf.io.FixedLenFeature([self._seq_len], tf.int64),
-        "segment_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
+        "input_type_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
         "label_ids": tf.io.FixedLenFeature([], tf.int64),
     }
@@ -269,6 +286,7 @@ class BertClassifierPreprocessor:
       2. model_name
       3. seq_len
      4. do_lower_case
+      5. tokenizer name
 
     Args:
       ds_cache_files: TFRecordCacheFiles from the original raw dataset object
@@ -282,6 +300,7 @@ class BertClassifierPreprocessor:
     hasher.update(self._model_name.encode("utf-8"))
     hasher.update(str(self._seq_len).encode("utf-8"))
     hasher.update(str(self._do_lower_case).encode("utf-8"))
+    hasher.update(self._tokenizer.name.encode("utf-8"))
     cache_prefix_filename = hasher.hexdigest()
     return cache_files_lib.TFRecordCacheFiles(
         cache_prefix_filename,
@@ -289,23 +308,6 @@ class BertClassifierPreprocessor:
         ds_cache_files.num_shards,
     )
 
-  def _process_bert_features(self, text: str) -> Mapping[str, Sequence[int]]:
-    tokens = self._tokenizer.tokenize(text)
-    tokens = tokens[0 : (self._seq_len - 2)]  # account for [CLS] and [SEP]
-    tokens.insert(0, "[CLS]")
-    tokens.append("[SEP]")
-    input_ids = self._tokenizer.convert_tokens_to_ids(tokens)
-    input_mask = [1] * len(input_ids)
-    while len(input_ids) < self._seq_len:
-      input_ids.append(0)
-      input_mask.append(0)
-    segment_ids = [0] * self._seq_len
-    return {
-        "input_ids": input_ids,
-        "input_mask": input_mask,
-        "segment_ids": segment_ids,
-    }
-
   def preprocess(
       self, dataset: text_classifier_ds.Dataset
   ) -> text_classifier_ds.Dataset:
@@ -326,18 +328,20 @@ class BertClassifierPreprocessor:
       size = 0
      for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
         _validate_text_and_label(text, label)
-        feature = self._process_bert_features(text.numpy()[0].decode("utf-8"))
+        feature = self._tokenizer.process(text)
 
         def create_int_feature(values):
-          f = tf.train.Feature(
-              int64_list=tf.train.Int64List(value=list(values))
-          )
+          f = tf.train.Feature(int64_list=tf.train.Int64List(value=values))
           return f
 
         features = collections.OrderedDict()
-        features["input_ids"] = create_int_feature(feature["input_ids"])
+        features["input_word_ids"] = create_int_feature(
+            feature["input_word_ids"]
+        )
         features["input_mask"] = create_int_feature(feature["input_mask"])
-        features["segment_ids"] = create_int_feature(feature["segment_ids"])
-        features["label_ids"] = create_int_feature([label.numpy()[0]])
+        features["input_type_ids"] = create_int_feature(
+            feature["input_type_ids"]
+        )
+        features["label_ids"] = create_int_feature(label.numpy().tolist())
         tf_example = tf.train.Example(
             features=tf.train.Features(feature=features)
         )
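
Wiring the preprocessor up now looks roughly like the sketch below, mirroring the updated test and the `_BertClassifier` call site further down; `seq_len=128` is illustrative and not part of the commit.

from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
from mediapipe.model_maker.python.text.text_classifier import model_spec
from mediapipe.model_maker.python.text.text_classifier import preprocessor

spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
    seq_len=128,
    do_lower_case=spec.do_lower_case,
    uri=spec.get_path(),
    model_name=spec.name,
    tokenizer=bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER,
)
# preprocessed_dataset = bert_preprocessor.preprocess(dataset)
# The cached TFRecords now use the "input_word_ids", "input_type_ids" and
# "input_mask" feature keys, and the cache prefix hash includes the
# tokenizer name, so switching tokenizers never reuses stale records.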

File: mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py

@@ -18,18 +18,20 @@ import os
 import tempfile
 from unittest import mock as unittest_mock
 
+from absl.testing import parameterized
 import mock
 import numpy as np
 import numpy.testing as npt
 import tensorflow as tf
 
 from mediapipe.model_maker.python.core.data import cache_files
+from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
 from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
 from mediapipe.model_maker.python.text.text_classifier import model_spec
 from mediapipe.model_maker.python.text.text_classifier import preprocessor
 
 
-class PreprocessorTest(tf.test.TestCase):
+class PreprocessorTest(tf.test.TestCase, parameterized.TestCase):
   CSV_PARAMS_ = text_classifier_ds.CSVParameters(
       text_column='text', label_column='label')
@@ -83,7 +85,19 @@ class PreprocessorTest(tf.test.TestCase):
     npt.assert_array_equal(
         np.stack(features_list), np.array([[1, 3, 3, 3, 3], [1, 5, 6, 0, 0]]))
 
-  def test_bert_preprocessor(self):
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='fulltokenizer',
+          tokenizer=bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER,
+      ),
+      dict(
+          testcase_name='fastberttokenizer',
+          tokenizer=bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER,
+      ),
+  )
+  def test_bert_preprocessor(
+      self, tokenizer: bert_tokenizer.SupportedBertTokenizers
+  ):
     csv_file = self._get_csv_file()
     dataset = text_classifier_ds.Dataset.from_csv(
         filename=csv_file, csv_params=self.CSV_PARAMS_)
@@ -93,6 +107,7 @@ class PreprocessorTest(tf.test.TestCase):
         do_lower_case=bert_spec.do_lower_case,
         uri=bert_spec.get_path(),
         model_name=bert_spec.name,
+        tokenizer=tokenizer,
     )
     preprocessed_dataset = bert_preprocessor.preprocess(dataset)
     labels = []
@@ -122,11 +137,13 @@ class PreprocessorTest(tf.test.TestCase):
         cache_dir=self.get_temp_dir(),
     )
     bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
+    tokenizer = bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER
     bert_preprocessor = preprocessor.BertClassifierPreprocessor(
         seq_len=5,
         do_lower_case=bert_spec.do_lower_case,
         uri=bert_spec.get_path(),
         model_name=bert_spec.name,
+        tokenizer=tokenizer,
     )
     ds_cache_files = dataset.tfrecord_cache_files
     preprocessed_cache_files = bert_preprocessor._get_tfrecord_cache_files(
@@ -149,12 +166,13 @@ class PreprocessorTest(tf.test.TestCase):
         f' {preprocessed_cache_files.cache_prefix}\n',
     )
 
-  def _get_new_prefix(self, cf, bert_spec, seq_len, do_lower_case):
+  def _get_new_prefix(self, cf, bert_spec, seq_len, do_lower_case, tokenizer):
     bert_preprocessor = preprocessor.BertClassifierPreprocessor(
         seq_len=seq_len,
         do_lower_case=do_lower_case,
         uri=bert_spec.get_path(),
         model_name=bert_spec.name,
+        tokenizer=tokenizer,
     )
     new_cf = bert_preprocessor._get_tfrecord_cache_files(cf)
     return new_cf.cache_prefix_filename
@@ -167,19 +185,31 @@ class PreprocessorTest(tf.test.TestCase):
         cache_dir=self.get_temp_dir(),
         num_shards=1,
     )
+    tokenizer = bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER
     mobilebert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
-    all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, True))
-    all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 10, True))
-    all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, False))
+    all_cf_prefixes.add(
+        self._get_new_prefix(cf, mobilebert_spec, 5, True, tokenizer)
+    )
+    all_cf_prefixes.add(
+        self._get_new_prefix(cf, mobilebert_spec, 10, True, tokenizer)
+    )
+    all_cf_prefixes.add(
+        self._get_new_prefix(cf, mobilebert_spec, 5, False, tokenizer)
+    )
     new_cf = cache_files.TFRecordCacheFiles(
         cache_prefix_filename='new_cache_prefix',
         cache_dir=self.get_temp_dir(),
         num_shards=1,
     )
-    all_cf_prefixes.add(self._get_new_prefix(new_cf, mobilebert_spec, 5, True))
+    all_cf_prefixes.add(
+        self._get_new_prefix(new_cf, mobilebert_spec, 5, True, tokenizer)
+    )
+    new_tokenizer = bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER
+    all_cf_prefixes.add(
+        self._get_new_prefix(cf, mobilebert_spec, 5, True, new_tokenizer)
+    )
 
     # Each item of all_cf_prefixes should be unique.
-    self.assertLen(all_cf_prefixes, 4)
+    self.assertLen(all_cf_prefixes, 5)
 
 
 if __name__ == '__main__':

File: mediapipe/model_maker/python/text/text_classifier/text_classifier.py

@@ -437,6 +437,7 @@ class _BertClassifier(TextClassifier):
         do_lower_case=self._model_spec.do_lower_case,
         uri=self._model_spec.get_path(),
         model_name=self._model_spec.name,
+        tokenizer=self._hparams.tokenizer,
     )
     return (
         self._text_preprocessor.preprocess(train_data),
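
End to end, the tokenizer choice therefore flows from the public hyperparameters into the preprocessor. A hedged sketch, assuming the usual Model Maker entry points (`TextClassifierOptions`, `TextClassifier.create`, and the public `BertHParams` alias), none of which are changed by this commit:

from mediapipe.model_maker.python.text import text_classifier

options = text_classifier.TextClassifierOptions(
    supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
    hparams=text_classifier.BertHParams(
        epochs=3,
        batch_size=48,
        learning_rate=3e-5,
        tokenizer=text_classifier.SupportedBertTokenizers.FAST_BERT_TOKENIZER,
    ),
)
# Assuming train_data and validation_data are text_classifier.Dataset objects:
# model = text_classifier.TextClassifier.create(train_data, validation_data, options)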

File: requirements.txt (Model Maker)

@@ -6,4 +6,5 @@ tensorflow>=2.10
 tensorflow-addons
 tensorflow-datasets
 tensorflow-hub
+tensorflow-text
 tf-models-official>=2.13.1