No public description
PiperOrigin-RevId: 566435327
This commit is contained in:
parent
36f78f6e4a
commit
94cda40a83
|
@ -30,10 +30,12 @@ py_library(
|
||||||
srcs = ["__init__.py"],
|
srcs = ["__init__.py"],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":bert_tokenizer",
|
||||||
":dataset",
|
":dataset",
|
||||||
":hyperparameters",
|
":hyperparameters",
|
||||||
":model_options",
|
":model_options",
|
||||||
":model_spec",
|
":model_spec",
|
||||||
|
":preprocessor",
|
||||||
":text_classifier",
|
":text_classifier",
|
||||||
":text_classifier_options",
|
":text_classifier_options",
|
||||||
],
|
],
|
||||||
|
@ -48,7 +50,10 @@ py_library(
|
||||||
py_library(
|
py_library(
|
||||||
name = "hyperparameters",
|
name = "hyperparameters",
|
||||||
srcs = ["hyperparameters.py"],
|
srcs = ["hyperparameters.py"],
|
||||||
deps = ["//mediapipe/model_maker/python/core:hyperparameters"],
|
deps = [
|
||||||
|
":bert_tokenizer",
|
||||||
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
|
@ -88,10 +93,26 @@ py_test(
|
||||||
deps = [":dataset"],
|
deps = [":dataset"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "bert_tokenizer",
|
||||||
|
srcs = ["bert_tokenizer.py"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "bert_tokenizer_test",
|
||||||
|
srcs = ["bert_tokenizer_test.py"],
|
||||||
|
tags = ["requires-net:external"],
|
||||||
|
deps = [
|
||||||
|
":bert_tokenizer",
|
||||||
|
":model_spec",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "preprocessor",
|
name = "preprocessor",
|
||||||
srcs = ["preprocessor.py"],
|
srcs = ["preprocessor.py"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":bert_tokenizer",
|
||||||
":dataset",
|
":dataset",
|
||||||
"//mediapipe/model_maker/python/core/data:cache_files",
|
"//mediapipe/model_maker/python/core/data:cache_files",
|
||||||
],
|
],
|
||||||
|
@ -102,6 +123,7 @@ py_test(
|
||||||
srcs = ["preprocessor_test.py"],
|
srcs = ["preprocessor_test.py"],
|
||||||
tags = ["requires-net:external"],
|
tags = ["requires-net:external"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":bert_tokenizer",
|
||||||
":dataset",
|
":dataset",
|
||||||
":model_spec",
|
":model_spec",
|
||||||
":preprocessor",
|
":preprocessor",
|
||||||
|
|
|
@ -13,10 +13,12 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""MediaPipe Public Python API for Text Classifier."""
|
"""MediaPipe Public Python API for Text Classifier."""
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
|
||||||
from mediapipe.model_maker.python.text.text_classifier import dataset
|
from mediapipe.model_maker.python.text.text_classifier import dataset
|
||||||
from mediapipe.model_maker.python.text.text_classifier import hyperparameters
|
from mediapipe.model_maker.python.text.text_classifier import hyperparameters
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_options
|
from mediapipe.model_maker.python.text.text_classifier import model_options
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
||||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier
|
from mediapipe.model_maker.python.text.text_classifier import text_classifier
|
||||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
||||||
|
|
||||||
|
@ -33,12 +35,14 @@ Dataset = dataset.Dataset
|
||||||
SupportedModels = model_spec.SupportedModels
|
SupportedModels = model_spec.SupportedModels
|
||||||
TextClassifier = text_classifier.TextClassifier
|
TextClassifier = text_classifier.TextClassifier
|
||||||
TextClassifierOptions = text_classifier_options.TextClassifierOptions
|
TextClassifierOptions = text_classifier_options.TextClassifierOptions
|
||||||
|
SupportedBertTokenizers = bert_tokenizer.SupportedBertTokenizers
|
||||||
|
|
||||||
# Remove duplicated and non-public API
|
# Remove duplicated and non-public API
|
||||||
|
del bert_tokenizer
|
||||||
del hyperparameters
|
del hyperparameters
|
||||||
del dataset
|
del dataset
|
||||||
del model_options
|
del model_options
|
||||||
del model_spec
|
del model_spec
|
||||||
del preprocessor # pylint: disable=undefined-variable
|
del preprocessor
|
||||||
del text_classifier
|
del text_classifier
|
||||||
del text_classifier_options
|
del text_classifier_options
|
||||||
|
|
|
@ -0,0 +1,118 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Text classifier BERT tokenizer library."""
|
||||||
|
import abc
|
||||||
|
import enum
|
||||||
|
from typing import Mapping, Sequence
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow_text as tf_text
|
||||||
|
|
||||||
|
from official.nlp.tools import tokenization
|
||||||
|
|
||||||
|
|
||||||
|
@enum.unique
|
||||||
|
class SupportedBertTokenizers(enum.Enum):
|
||||||
|
"""Supported preprocessors."""
|
||||||
|
|
||||||
|
FULL_TOKENIZER = "fulltokenizer"
|
||||||
|
FAST_BERT_TOKENIZER = "fastberttokenizer"
|
||||||
|
|
||||||
|
|
||||||
|
class BertTokenizer(abc.ABC):
|
||||||
|
"""Abstract BertTokenizer class."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def __init__(self, vocab_file: str, do_lower_case: bool, seq_len: int):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BertFullTokenizer(BertTokenizer):
|
||||||
|
"""Tokenizer using the FullTokenizer from tensorflow_models."""
|
||||||
|
|
||||||
|
name = "fulltokenizer"
|
||||||
|
|
||||||
|
def __init__(self, vocab_file: str, do_lower_case: bool, seq_len: int):
|
||||||
|
self._tokenizer = tokenization.FullTokenizer(
|
||||||
|
vocab_file=vocab_file, do_lower_case=do_lower_case
|
||||||
|
)
|
||||||
|
self._seq_len = seq_len
|
||||||
|
|
||||||
|
def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
|
||||||
|
tokens = self._tokenizer.tokenize(input_tensor.numpy()[0].decode("utf-8"))
|
||||||
|
tokens = tokens[0 : (self._seq_len - 2)] # account for [CLS] and [SEP]
|
||||||
|
tokens.insert(0, "[CLS]")
|
||||||
|
tokens.append("[SEP]")
|
||||||
|
input_ids = self._tokenizer.convert_tokens_to_ids(tokens)
|
||||||
|
input_mask = [1] * len(input_ids)
|
||||||
|
while len(input_ids) < self._seq_len:
|
||||||
|
input_ids.append(0)
|
||||||
|
input_mask.append(0)
|
||||||
|
segment_ids = [0] * self._seq_len
|
||||||
|
return {
|
||||||
|
"input_word_ids": input_ids,
|
||||||
|
"input_type_ids": segment_ids,
|
||||||
|
"input_mask": input_mask,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class BertFastTokenizer(BertTokenizer):
|
||||||
|
"""Tokenizer using the FastBertTokenizer from tensorflow_text.
|
||||||
|
|
||||||
|
For more information, see:
|
||||||
|
https://www.tensorflow.org/text/api_docs/python/text/FastBertTokenizer
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "fastberttokenizer"
|
||||||
|
|
||||||
|
def __init__(self, vocab_file: str, do_lower_case: bool, seq_len: int):
|
||||||
|
with tf.io.gfile.GFile(vocab_file, "r") as f:
|
||||||
|
vocab = f.read().splitlines()
|
||||||
|
self._tokenizer = tf_text.FastBertTokenizer(
|
||||||
|
vocab=vocab,
|
||||||
|
token_out_type=tf.int32,
|
||||||
|
support_detokenization=False,
|
||||||
|
lower_case_nfd_strip_accents=do_lower_case,
|
||||||
|
)
|
||||||
|
self._seq_len = seq_len
|
||||||
|
self._cls_id = vocab.index("[CLS]")
|
||||||
|
self._sep_id = vocab.index("[SEP]")
|
||||||
|
self._pad_id = vocab.index("[PAD]")
|
||||||
|
|
||||||
|
def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
|
||||||
|
input_ids = self._tokenizer.tokenize(input_tensor).flat_values
|
||||||
|
input_ids = input_ids[: (self._seq_len - 2)]
|
||||||
|
input_ids = tf.concat(
|
||||||
|
[
|
||||||
|
tf.constant([self._cls_id]),
|
||||||
|
input_ids,
|
||||||
|
tf.constant([self._sep_id]),
|
||||||
|
tf.fill((self._seq_len,), self._pad_id),
|
||||||
|
],
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
input_ids = input_ids[: self._seq_len]
|
||||||
|
input_type_ids = tf.zeros(self._seq_len, dtype=tf.int32)
|
||||||
|
input_mask = tf.cast(input_ids != self._pad_id, dtype=tf.int32)
|
||||||
|
return {
|
||||||
|
"input_word_ids": input_ids.numpy().tolist(),
|
||||||
|
"input_type_ids": input_type_ids.numpy().tolist(),
|
||||||
|
"input_mask": input_mask.numpy().tolist(),
|
||||||
|
}
|
|
@ -0,0 +1,91 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from unittest import mock as unittest_mock
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow_hub
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
||||||
|
|
||||||
|
|
||||||
|
class BertTokenizerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||||
|
# condition when downloading model since these tests may run in parallel.
|
||||||
|
mock_gettempdir = unittest_mock.patch.object(
|
||||||
|
tempfile,
|
||||||
|
'gettempdir',
|
||||||
|
return_value=self.create_tempdir(),
|
||||||
|
autospec=True,
|
||||||
|
)
|
||||||
|
self.mock_gettempdir = mock_gettempdir.start()
|
||||||
|
self.addCleanup(mock_gettempdir.stop)
|
||||||
|
ms = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||||
|
self._vocab_file = os.path.join(
|
||||||
|
tensorflow_hub.resolve(ms.get_path()), 'assets', 'vocab.txt'
|
||||||
|
)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
dict(
|
||||||
|
testcase_name='fulltokenizer',
|
||||||
|
tokenizer_class=bert_tokenizer.BertFullTokenizer,
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
testcase_name='fasttokenizer',
|
||||||
|
tokenizer_class=bert_tokenizer.BertFastTokenizer,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def test_bert_full_tokenizer(self, tokenizer_class):
|
||||||
|
tokenizer = tokenizer_class(self._vocab_file, True, 16)
|
||||||
|
text_input = tf.constant(['this is an éxamplé input ¿foo'.encode('utf-8')])
|
||||||
|
result = tokenizer.process(text_input)
|
||||||
|
self.assertAllEqual(
|
||||||
|
result['input_word_ids'],
|
||||||
|
[
|
||||||
|
101,
|
||||||
|
2023,
|
||||||
|
2003,
|
||||||
|
2019,
|
||||||
|
2742,
|
||||||
|
7953,
|
||||||
|
1094,
|
||||||
|
29379,
|
||||||
|
102,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
self.assertAllEqual(
|
||||||
|
result['input_mask'], [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
|
||||||
|
)
|
||||||
|
self.assertAllEqual(
|
||||||
|
result['input_type_ids'],
|
||||||
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
|
@ -18,6 +18,7 @@ import enum
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
@ -53,6 +54,8 @@ class BertHParams(hp.BaseHParams):
|
||||||
on recall. Only supported for binary classification.
|
on recall. Only supported for binary classification.
|
||||||
gamma: Gamma parameter for focal loss. To use cross entropy loss, set this
|
gamma: Gamma parameter for focal loss. To use cross entropy loss, set this
|
||||||
value to 0. Defaults to 2.0.
|
value to 0. Defaults to 2.0.
|
||||||
|
tokenizer: Tokenizer to use for preprocessing. Must be one of the enum
|
||||||
|
options of SupportedBertTokenizers. Defaults to FULL_TOKENIZER.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
learning_rate: float = 3e-5
|
learning_rate: float = 3e-5
|
||||||
|
@ -68,5 +71,9 @@ class BertHParams(hp.BaseHParams):
|
||||||
|
|
||||||
gamma: float = 2.0
|
gamma: float = 2.0
|
||||||
|
|
||||||
|
tokenizer: bert_tokenizer.SupportedBertTokenizers = (
|
||||||
|
bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
HParams = Union[BertHParams, AverageWordEmbeddingHParams]
|
HParams = Union[BertHParams, AverageWordEmbeddingHParams]
|
||||||
|
|
|
@ -24,8 +24,8 @@ import tensorflow as tf
|
||||||
import tensorflow_hub
|
import tensorflow_hub
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
|
from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
|
||||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
|
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
|
||||||
from official.nlp.tools import tokenization
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_text_and_label(text: tf.Tensor, label: tf.Tensor) -> None:
|
def _validate_text_and_label(text: tf.Tensor, label: tf.Tensor) -> None:
|
||||||
|
@ -68,9 +68,9 @@ def _decode_record(
|
||||||
example[name] = tf.cast(example[name], tf.int32)
|
example[name] = tf.cast(example[name], tf.int32)
|
||||||
|
|
||||||
bert_features = {
|
bert_features = {
|
||||||
"input_word_ids": example["input_ids"],
|
"input_word_ids": example["input_word_ids"],
|
||||||
"input_mask": example["input_mask"],
|
"input_mask": example["input_mask"],
|
||||||
"input_type_ids": example["segment_ids"]
|
"input_type_ids": example["input_type_ids"],
|
||||||
}
|
}
|
||||||
return bert_features, example["label_ids"]
|
return bert_features, example["label_ids"]
|
||||||
|
|
||||||
|
@ -224,10 +224,17 @@ class BertClassifierPreprocessor:
|
||||||
tokenizer: BERT tokenizer.
|
tokenizer: BERT tokenizer.
|
||||||
model_name: Name of the model provided by the model_spec. Used to associate
|
model_name: Name of the model provided by the model_spec. Used to associate
|
||||||
cached files with specific Bert model vocab.
|
cached files with specific Bert model vocab.
|
||||||
|
preprocessor: Which preprocessor to use. Must be one of the enum values of
|
||||||
|
SupportedBertPreprocessors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, seq_len: int, do_lower_case: bool, uri: str, model_name: str
|
self,
|
||||||
|
seq_len: int,
|
||||||
|
do_lower_case: bool,
|
||||||
|
uri: str,
|
||||||
|
model_name: str,
|
||||||
|
tokenizer: bert_tokenizer.SupportedBertTokenizers,
|
||||||
):
|
):
|
||||||
self._seq_len = seq_len
|
self._seq_len = seq_len
|
||||||
# Vocab filepath is tied to the BERT module's URI.
|
# Vocab filepath is tied to the BERT module's URI.
|
||||||
|
@ -235,17 +242,27 @@ class BertClassifierPreprocessor:
|
||||||
tensorflow_hub.resolve(uri), "assets", "vocab.txt"
|
tensorflow_hub.resolve(uri), "assets", "vocab.txt"
|
||||||
)
|
)
|
||||||
self._do_lower_case = do_lower_case
|
self._do_lower_case = do_lower_case
|
||||||
self._tokenizer = tokenization.FullTokenizer(
|
self._tokenizer: bert_tokenizer.BertTokenizer = None
|
||||||
self._vocab_file, self._do_lower_case
|
if tokenizer == bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER:
|
||||||
|
self._tokenizer = bert_tokenizer.BertFullTokenizer(
|
||||||
|
self._vocab_file, self._do_lower_case, self._seq_len
|
||||||
)
|
)
|
||||||
|
elif (
|
||||||
|
tokenizer == bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER
|
||||||
|
):
|
||||||
|
self._tokenizer = bert_tokenizer.BertFastTokenizer(
|
||||||
|
self._vocab_file, self._do_lower_case, self._seq_len
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported tokenizer: {tokenizer}")
|
||||||
self._model_name = model_name
|
self._model_name = model_name
|
||||||
|
|
||||||
def _get_name_to_features(self):
|
def _get_name_to_features(self):
|
||||||
"""Gets the dictionary mapping record keys to feature types."""
|
"""Gets the dictionary mapping record keys to feature types."""
|
||||||
return {
|
return {
|
||||||
"input_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
|
"input_word_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
|
||||||
"input_mask": tf.io.FixedLenFeature([self._seq_len], tf.int64),
|
"input_mask": tf.io.FixedLenFeature([self._seq_len], tf.int64),
|
||||||
"segment_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
|
"input_type_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
|
||||||
"label_ids": tf.io.FixedLenFeature([], tf.int64),
|
"label_ids": tf.io.FixedLenFeature([], tf.int64),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -269,6 +286,7 @@ class BertClassifierPreprocessor:
|
||||||
2. model_name
|
2. model_name
|
||||||
3. seq_len
|
3. seq_len
|
||||||
4. do_lower_case
|
4. do_lower_case
|
||||||
|
5. tokenizer name
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ds_cache_files: TFRecordCacheFiles from the original raw dataset object
|
ds_cache_files: TFRecordCacheFiles from the original raw dataset object
|
||||||
|
@ -282,6 +300,7 @@ class BertClassifierPreprocessor:
|
||||||
hasher.update(self._model_name.encode("utf-8"))
|
hasher.update(self._model_name.encode("utf-8"))
|
||||||
hasher.update(str(self._seq_len).encode("utf-8"))
|
hasher.update(str(self._seq_len).encode("utf-8"))
|
||||||
hasher.update(str(self._do_lower_case).encode("utf-8"))
|
hasher.update(str(self._do_lower_case).encode("utf-8"))
|
||||||
|
hasher.update(self._tokenizer.name.encode("utf-8"))
|
||||||
cache_prefix_filename = hasher.hexdigest()
|
cache_prefix_filename = hasher.hexdigest()
|
||||||
return cache_files_lib.TFRecordCacheFiles(
|
return cache_files_lib.TFRecordCacheFiles(
|
||||||
cache_prefix_filename,
|
cache_prefix_filename,
|
||||||
|
@ -289,23 +308,6 @@ class BertClassifierPreprocessor:
|
||||||
ds_cache_files.num_shards,
|
ds_cache_files.num_shards,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _process_bert_features(self, text: str) -> Mapping[str, Sequence[int]]:
|
|
||||||
tokens = self._tokenizer.tokenize(text)
|
|
||||||
tokens = tokens[0 : (self._seq_len - 2)] # account for [CLS] and [SEP]
|
|
||||||
tokens.insert(0, "[CLS]")
|
|
||||||
tokens.append("[SEP]")
|
|
||||||
input_ids = self._tokenizer.convert_tokens_to_ids(tokens)
|
|
||||||
input_mask = [1] * len(input_ids)
|
|
||||||
while len(input_ids) < self._seq_len:
|
|
||||||
input_ids.append(0)
|
|
||||||
input_mask.append(0)
|
|
||||||
segment_ids = [0] * self._seq_len
|
|
||||||
return {
|
|
||||||
"input_ids": input_ids,
|
|
||||||
"input_mask": input_mask,
|
|
||||||
"segment_ids": segment_ids,
|
|
||||||
}
|
|
||||||
|
|
||||||
def preprocess(
|
def preprocess(
|
||||||
self, dataset: text_classifier_ds.Dataset
|
self, dataset: text_classifier_ds.Dataset
|
||||||
) -> text_classifier_ds.Dataset:
|
) -> text_classifier_ds.Dataset:
|
||||||
|
@ -326,18 +328,20 @@ class BertClassifierPreprocessor:
|
||||||
size = 0
|
size = 0
|
||||||
for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
|
for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
|
||||||
_validate_text_and_label(text, label)
|
_validate_text_and_label(text, label)
|
||||||
feature = self._process_bert_features(text.numpy()[0].decode("utf-8"))
|
feature = self._tokenizer.process(text)
|
||||||
def create_int_feature(values):
|
def create_int_feature(values):
|
||||||
f = tf.train.Feature(
|
f = tf.train.Feature(int64_list=tf.train.Int64List(value=values))
|
||||||
int64_list=tf.train.Int64List(value=list(values))
|
|
||||||
)
|
|
||||||
return f
|
return f
|
||||||
|
|
||||||
features = collections.OrderedDict()
|
features = collections.OrderedDict()
|
||||||
features["input_ids"] = create_int_feature(feature["input_ids"])
|
features["input_word_ids"] = create_int_feature(
|
||||||
|
feature["input_word_ids"]
|
||||||
|
)
|
||||||
features["input_mask"] = create_int_feature(feature["input_mask"])
|
features["input_mask"] = create_int_feature(feature["input_mask"])
|
||||||
features["segment_ids"] = create_int_feature(feature["segment_ids"])
|
features["input_type_ids"] = create_int_feature(
|
||||||
features["label_ids"] = create_int_feature([label.numpy()[0]])
|
feature["input_type_ids"]
|
||||||
|
)
|
||||||
|
features["label_ids"] = create_int_feature(label.numpy().tolist())
|
||||||
tf_example = tf.train.Example(
|
tf_example = tf.train.Example(
|
||||||
features=tf.train.Features(feature=features)
|
features=tf.train.Features(feature=features)
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,18 +18,20 @@ import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from unittest import mock as unittest_mock
|
from unittest import mock as unittest_mock
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
import mock
|
import mock
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.testing as npt
|
import numpy.testing as npt
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core.data import cache_files
|
from mediapipe.model_maker.python.core.data import cache_files
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
|
||||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
|
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
||||||
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
||||||
|
|
||||||
|
|
||||||
class PreprocessorTest(tf.test.TestCase):
|
class PreprocessorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
CSV_PARAMS_ = text_classifier_ds.CSVParameters(
|
CSV_PARAMS_ = text_classifier_ds.CSVParameters(
|
||||||
text_column='text', label_column='label')
|
text_column='text', label_column='label')
|
||||||
|
|
||||||
|
@ -83,7 +85,19 @@ class PreprocessorTest(tf.test.TestCase):
|
||||||
npt.assert_array_equal(
|
npt.assert_array_equal(
|
||||||
np.stack(features_list), np.array([[1, 3, 3, 3, 3], [1, 5, 6, 0, 0]]))
|
np.stack(features_list), np.array([[1, 3, 3, 3, 3], [1, 5, 6, 0, 0]]))
|
||||||
|
|
||||||
def test_bert_preprocessor(self):
|
@parameterized.named_parameters(
|
||||||
|
dict(
|
||||||
|
testcase_name='fulltokenizer',
|
||||||
|
tokenizer=bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER,
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
testcase_name='fastberttokenizer',
|
||||||
|
tokenizer=bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def test_bert_preprocessor(
|
||||||
|
self, tokenizer: bert_tokenizer.SupportedBertTokenizers
|
||||||
|
):
|
||||||
csv_file = self._get_csv_file()
|
csv_file = self._get_csv_file()
|
||||||
dataset = text_classifier_ds.Dataset.from_csv(
|
dataset = text_classifier_ds.Dataset.from_csv(
|
||||||
filename=csv_file, csv_params=self.CSV_PARAMS_)
|
filename=csv_file, csv_params=self.CSV_PARAMS_)
|
||||||
|
@ -93,6 +107,7 @@ class PreprocessorTest(tf.test.TestCase):
|
||||||
do_lower_case=bert_spec.do_lower_case,
|
do_lower_case=bert_spec.do_lower_case,
|
||||||
uri=bert_spec.get_path(),
|
uri=bert_spec.get_path(),
|
||||||
model_name=bert_spec.name,
|
model_name=bert_spec.name,
|
||||||
|
tokenizer=tokenizer,
|
||||||
)
|
)
|
||||||
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
|
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
|
||||||
labels = []
|
labels = []
|
||||||
|
@ -122,11 +137,13 @@ class PreprocessorTest(tf.test.TestCase):
|
||||||
cache_dir=self.get_temp_dir(),
|
cache_dir=self.get_temp_dir(),
|
||||||
)
|
)
|
||||||
bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||||
|
tokenizer = bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER
|
||||||
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
|
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
|
||||||
seq_len=5,
|
seq_len=5,
|
||||||
do_lower_case=bert_spec.do_lower_case,
|
do_lower_case=bert_spec.do_lower_case,
|
||||||
uri=bert_spec.get_path(),
|
uri=bert_spec.get_path(),
|
||||||
model_name=bert_spec.name,
|
model_name=bert_spec.name,
|
||||||
|
tokenizer=tokenizer,
|
||||||
)
|
)
|
||||||
ds_cache_files = dataset.tfrecord_cache_files
|
ds_cache_files = dataset.tfrecord_cache_files
|
||||||
preprocessed_cache_files = bert_preprocessor._get_tfrecord_cache_files(
|
preprocessed_cache_files = bert_preprocessor._get_tfrecord_cache_files(
|
||||||
|
@ -149,12 +166,13 @@ class PreprocessorTest(tf.test.TestCase):
|
||||||
f' {preprocessed_cache_files.cache_prefix}\n',
|
f' {preprocessed_cache_files.cache_prefix}\n',
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_new_prefix(self, cf, bert_spec, seq_len, do_lower_case):
|
def _get_new_prefix(self, cf, bert_spec, seq_len, do_lower_case, tokenizer):
|
||||||
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
|
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
|
||||||
seq_len=seq_len,
|
seq_len=seq_len,
|
||||||
do_lower_case=do_lower_case,
|
do_lower_case=do_lower_case,
|
||||||
uri=bert_spec.get_path(),
|
uri=bert_spec.get_path(),
|
||||||
model_name=bert_spec.name,
|
model_name=bert_spec.name,
|
||||||
|
tokenizer=tokenizer,
|
||||||
)
|
)
|
||||||
new_cf = bert_preprocessor._get_tfrecord_cache_files(cf)
|
new_cf = bert_preprocessor._get_tfrecord_cache_files(cf)
|
||||||
return new_cf.cache_prefix_filename
|
return new_cf.cache_prefix_filename
|
||||||
|
@ -167,19 +185,31 @@ class PreprocessorTest(tf.test.TestCase):
|
||||||
cache_dir=self.get_temp_dir(),
|
cache_dir=self.get_temp_dir(),
|
||||||
num_shards=1,
|
num_shards=1,
|
||||||
)
|
)
|
||||||
|
tokenizer = bert_tokenizer.SupportedBertTokenizers.FULL_TOKENIZER
|
||||||
mobilebert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
mobilebert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||||
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, True))
|
all_cf_prefixes.add(
|
||||||
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 10, True))
|
self._get_new_prefix(cf, mobilebert_spec, 5, True, tokenizer)
|
||||||
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, False))
|
)
|
||||||
|
all_cf_prefixes.add(
|
||||||
|
self._get_new_prefix(cf, mobilebert_spec, 10, True, tokenizer)
|
||||||
|
)
|
||||||
|
all_cf_prefixes.add(
|
||||||
|
self._get_new_prefix(cf, mobilebert_spec, 5, False, tokenizer)
|
||||||
|
)
|
||||||
new_cf = cache_files.TFRecordCacheFiles(
|
new_cf = cache_files.TFRecordCacheFiles(
|
||||||
cache_prefix_filename='new_cache_prefix',
|
cache_prefix_filename='new_cache_prefix',
|
||||||
cache_dir=self.get_temp_dir(),
|
cache_dir=self.get_temp_dir(),
|
||||||
num_shards=1,
|
num_shards=1,
|
||||||
)
|
)
|
||||||
all_cf_prefixes.add(self._get_new_prefix(new_cf, mobilebert_spec, 5, True))
|
all_cf_prefixes.add(
|
||||||
|
self._get_new_prefix(new_cf, mobilebert_spec, 5, True, tokenizer)
|
||||||
|
)
|
||||||
|
new_tokenizer = bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER
|
||||||
|
all_cf_prefixes.add(
|
||||||
|
self._get_new_prefix(cf, mobilebert_spec, 5, True, new_tokenizer)
|
||||||
|
)
|
||||||
# Each item of all_cf_prefixes should be unique.
|
# Each item of all_cf_prefixes should be unique.
|
||||||
self.assertLen(all_cf_prefixes, 4)
|
self.assertLen(all_cf_prefixes, 5)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -437,6 +437,7 @@ class _BertClassifier(TextClassifier):
|
||||||
do_lower_case=self._model_spec.do_lower_case,
|
do_lower_case=self._model_spec.do_lower_case,
|
||||||
uri=self._model_spec.get_path(),
|
uri=self._model_spec.get_path(),
|
||||||
model_name=self._model_spec.name,
|
model_name=self._model_spec.name,
|
||||||
|
tokenizer=self._hparams.tokenizer,
|
||||||
)
|
)
|
||||||
return (
|
return (
|
||||||
self._text_preprocessor.preprocess(train_data),
|
self._text_preprocessor.preprocess(train_data),
|
||||||
|
|
|
@ -6,4 +6,5 @@ tensorflow>=2.10
|
||||||
tensorflow-addons
|
tensorflow-addons
|
||||||
tensorflow-datasets
|
tensorflow-datasets
|
||||||
tensorflow-hub
|
tensorflow-hub
|
||||||
|
tensorflow-text
|
||||||
tf-models-official>=2.13.1
|
tf-models-official>=2.13.1
|
||||||
|
|
Loading…
Reference in New Issue
Block a user