Refactor text_classifier preprocessor to move away from using classifier_data_lib
PiperOrigin-RevId: 556859900
commit c8ad606e7c
parent 3ac3b03ed5
@@ -25,7 +25,6 @@ import tensorflow_hub
 from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
 from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
-from official.nlp.data import classifier_data_lib
 from official.nlp.tools import tokenization



@@ -290,6 +289,23 @@ class BertClassifierPreprocessor:
         ds_cache_files.num_shards,
     )
 
+  def _process_bert_features(self, text: str) -> Mapping[str, Sequence[int]]:
+    tokens = self._tokenizer.tokenize(text)
+    tokens = tokens[0 : (self._seq_len - 2)]  # account for [CLS] and [SEP]
+    tokens.insert(0, "[CLS]")
+    tokens.append("[SEP]")
+    input_ids = self._tokenizer.convert_tokens_to_ids(tokens)
+    input_mask = [1] * len(input_ids)
+    while len(input_ids) < self._seq_len:
+      input_ids.append(0)
+      input_mask.append(0)
+    segment_ids = [0] * self._seq_len
+    return {
+        "input_ids": input_ids,
+        "input_mask": input_mask,
+        "segment_ids": segment_ids,
+    }
+
   def preprocess(
       self, dataset: text_classifier_ds.Dataset
   ) -> text_classifier_ds.Dataset:
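For intuition, here is a minimal standalone sketch of the tokenize/truncate/pad logic that the added _process_bert_features implements. ToyTokenizer is a hypothetical stand-in for the BERT FullTokenizer the preprocessor actually holds in self._tokenizer; the rest mirrors the added method.

# Sketch of the new feature logic with a hypothetical toy tokenizer
# standing in for tokenization.FullTokenizer.
from typing import List, Mapping, Sequence


class ToyTokenizer:
  _vocab = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "hello": 3, "world": 4}

  def tokenize(self, text: str) -> List[str]:
    return text.lower().split()

  def convert_tokens_to_ids(self, tokens: Sequence[str]) -> List[int]:
    return [self._vocab.get(t, 0) for t in tokens]


def process_bert_features(
    tokenizer: ToyTokenizer, text: str, seq_len: int
) -> Mapping[str, Sequence[int]]:
  tokens = tokenizer.tokenize(text)
  tokens = tokens[0 : (seq_len - 2)]  # leave room for [CLS] and [SEP]
  tokens = ["[CLS]"] + tokens + ["[SEP]"]
  input_ids = tokenizer.convert_tokens_to_ids(tokens)
  input_mask = [1] * len(input_ids)  # 1 marks real tokens, 0 marks padding
  input_ids += [0] * (seq_len - len(input_ids))
  input_mask += [0] * (seq_len - len(input_mask))
  return {
      "input_ids": input_ids,
      "input_mask": input_mask,
      "segment_ids": [0] * seq_len,  # single-segment input: all zeros
  }


print(process_bert_features(ToyTokenizer(), "hello world", 8))
# {'input_ids': [1, 3, 4, 2, 0, 0, 0, 0],
#  'input_mask': [1, 1, 1, 1, 0, 0, 0, 0],
#  'segment_ids': [0, 0, 0, 0, 0, 0, 0, 0]}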
@@ -310,18 +326,7 @@ class BertClassifierPreprocessor:
     size = 0
     for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
       _validate_text_and_label(text, label)
-      example = classifier_data_lib.InputExample(
-          guid=str(index),
-          text_a=text.numpy()[0].decode("utf-8"),
-          text_b=None,
-          # InputExample expects the label name rather than the int ID
-          # label=dataset.label_names[label.numpy()[0]])
-          label=label.numpy()[0],
-      )
-      feature = classifier_data_lib.convert_single_example(
-          index, example, None, self._seq_len, self._tokenizer
-      )
-
+      feature = self._process_bert_features(text.numpy()[0].decode("utf-8"))
       def create_int_feature(values):
         f = tf.train.Feature(
             int64_list=tf.train.Int64List(value=list(values))
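Since the new helper is meant to be a drop-in replacement for the removed classifier_data_lib path, a one-off check can confirm the two produce identical features. This is only a hedged sketch, not part of the commit: it assumes tf-models-official is still installed, and it pokes at the preprocessor's private _tokenizer and _process_bert_features members.

from official.nlp.data import classifier_data_lib


def check_feature_equivalence(preprocessor, text: str, seq_len: int) -> None:
  # Old path: route the text through tf-models' classifier_data_lib,
  # exactly as the removed code did (label_list=None, dummy label).
  example = classifier_data_lib.InputExample(
      guid="0", text_a=text, text_b=None, label=0
  )
  old = classifier_data_lib.convert_single_example(
      0, example, None, seq_len, preprocessor._tokenizer
  )
  # New path: the in-repo helper added by this commit.
  new = preprocessor._process_bert_features(text)
  assert list(old.input_ids) == list(new["input_ids"])
  assert list(old.input_mask) == list(new["input_mask"])
  assert list(old.segment_ids) == list(new["segment_ids"])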
@@ -329,10 +334,10 @@ class BertClassifierPreprocessor:
         return f
 
       features = collections.OrderedDict()
-      features["input_ids"] = create_int_feature(feature.input_ids)
-      features["input_mask"] = create_int_feature(feature.input_mask)
-      features["segment_ids"] = create_int_feature(feature.segment_ids)
-      features["label_ids"] = create_int_feature([feature.label_id])
+      features["input_ids"] = create_int_feature(feature["input_ids"])
+      features["input_mask"] = create_int_feature(feature["input_mask"])
+      features["segment_ids"] = create_int_feature(feature["segment_ids"])
+      features["label_ids"] = create_int_feature([label.numpy()[0]])
       tf_example = tf.train.Example(
           features=tf.train.Features(feature=features)
       )
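For reference, the surrounding pattern, which this commit leaves unchanged, packs the feature dict into a tf.train.Example for TFRecord storage. A minimal self-contained sketch using the real TensorFlow proto APIs, with made-up placeholder feature values:

import collections

import tensorflow as tf


def create_int_feature(values):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))


feature = {  # placeholder output of _process_bert_features
    "input_ids": [1, 3, 4, 2, 0, 0, 0, 0],
    "input_mask": [1, 1, 1, 1, 0, 0, 0, 0],
    "segment_ids": [0] * 8,
}
features = collections.OrderedDict()
features["input_ids"] = create_int_feature(feature["input_ids"])
features["input_mask"] = create_int_feature(feature["input_mask"])
features["segment_ids"] = create_int_feature(feature["segment_ids"])
features["label_ids"] = create_int_feature([1])  # int label from the dataset
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
record_bytes = tf_example.SerializeToString()  # ready for a TFRecordWriter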