commit c8ad606e7c
parent 3ac3b03ed5

Refactor text_classifier preprocessor to move away from using classifier_data_lib

PiperOrigin-RevId: 556859900
@@ -25,7 +25,6 @@ import tensorflow_hub
 
 from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
 from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
-from official.nlp.data import classifier_data_lib
 from official.nlp.tools import tokenization
 
 
@@ -290,6 +289,23 @@ class BertClassifierPreprocessor:
         ds_cache_files.num_shards,
     )
 
+  def _process_bert_features(self, text: str) -> Mapping[str, Sequence[int]]:
+    tokens = self._tokenizer.tokenize(text)
+    tokens = tokens[0 : (self._seq_len - 2)]  # account for [CLS] and [SEP]
+    tokens.insert(0, "[CLS]")
+    tokens.append("[SEP]")
+    input_ids = self._tokenizer.convert_tokens_to_ids(tokens)
+    input_mask = [1] * len(input_ids)
+    while len(input_ids) < self._seq_len:
+      input_ids.append(0)
+      input_mask.append(0)
+    segment_ids = [0] * self._seq_len
+    return {
+        "input_ids": input_ids,
+        "input_mask": input_mask,
+        "segment_ids": segment_ids,
+    }
+
   def preprocess(
       self, dataset: text_classifier_ds.Dataset
   ) -> text_classifier_ds.Dataset:
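Note: the helper added above reproduces classifier_data_lib's single-sentence feature packing by hand (truncate, wrap in [CLS]/[SEP], pad to seq_len). A minimal standalone sketch of the same logic, with a hypothetical whitespace tokenizer standing in for the real tokenization.FullTokenizer, so the token IDs are illustrative only:

from typing import List, Mapping, Sequence


class ToyTokenizer:
  """Hypothetical stand-in for tokenization.FullTokenizer (whitespace only)."""

  _vocab = {"[CLS]": 101, "[SEP]": 102}

  def tokenize(self, text: str) -> List[str]:
    return text.lower().split()

  def convert_tokens_to_ids(self, tokens: Sequence[str]) -> List[int]:
    # Real BERT uses a WordPiece vocab; unknown tokens here map to a dummy ID.
    return [self._vocab.get(t, 999) for t in tokens]


def process_bert_features(
    tokenizer, text: str, seq_len: int
) -> Mapping[str, Sequence[int]]:
  """Mirrors the new _process_bert_features: truncate, wrap, pad."""
  tokens = tokenizer.tokenize(text)
  tokens = tokens[0 : (seq_len - 2)]  # leave room for [CLS] and [SEP]
  tokens.insert(0, "[CLS]")
  tokens.append("[SEP]")
  input_ids = tokenizer.convert_tokens_to_ids(tokens)
  input_mask = [1] * len(input_ids)  # 1 marks real tokens, 0 marks padding
  while len(input_ids) < seq_len:
    input_ids.append(0)
    input_mask.append(0)
  segment_ids = [0] * seq_len  # single sentence, so one segment throughout
  return {
      "input_ids": input_ids,
      "input_mask": input_mask,
      "segment_ids": segment_ids,
  }


features = process_bert_features(ToyTokenizer(), "hello world", seq_len=8)
assert features["input_ids"][:2] == [101, 999]  # [CLS], then "hello"
assert features["input_mask"] == [1, 1, 1, 1, 0, 0, 0, 0]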
@@ -310,18 +326,7 @@ class BertClassifierPreprocessor:
     size = 0
     for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
       _validate_text_and_label(text, label)
-      example = classifier_data_lib.InputExample(
-          guid=str(index),
-          text_a=text.numpy()[0].decode("utf-8"),
-          text_b=None,
-          # InputExample expects the label name rather than the int ID
-          # label=dataset.label_names[label.numpy()[0]])
-          label=label.numpy()[0],
-      )
-      feature = classifier_data_lib.convert_single_example(
-          index, example, None, self._seq_len, self._tokenizer
-      )
-
+      feature = self._process_bert_features(text.numpy()[0].decode("utf-8"))
       def create_int_feature(values):
         f = tf.train.Feature(
             int64_list=tf.train.Int64List(value=list(values))
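For context, the decode at the new call site is needed because gen_tf_dataset() yields batched tf.string tensors, whose .numpy() values surface in Python as bytes. A self-contained illustration (not MediaPipe code):

import tensorflow as tf

text = tf.constant(["hello world"])  # a batch of one tf.string element
raw = text.numpy()[0]                # tf.string surfaces as bytes: b"hello world"
assert raw.decode("utf-8") == "hello world"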
@@ -329,10 +334,10 @@ class BertClassifierPreprocessor:
         return f
 
       features = collections.OrderedDict()
-      features["input_ids"] = create_int_feature(feature.input_ids)
-      features["input_mask"] = create_int_feature(feature.input_mask)
-      features["segment_ids"] = create_int_feature(feature.segment_ids)
-      features["label_ids"] = create_int_feature([feature.label_id])
+      features["input_ids"] = create_int_feature(feature["input_ids"])
+      features["input_mask"] = create_int_feature(feature["input_mask"])
+      features["segment_ids"] = create_int_feature(feature["segment_ids"])
+      features["label_ids"] = create_int_feature([label.numpy()[0]])
       tf_example = tf.train.Example(
           features=tf.train.Features(feature=features)
       )
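A quick way to sanity-check that the dict-based features serialize exactly as the old classifier_data_lib object did is a round trip through tf.train.Example. The feature values below are made up, and the fixed lengths assume a toy seq_len of 4:

import collections

import tensorflow as tf


def create_int_feature(values):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))


# Hypothetical output of _process_bert_features, plus a label ID.
feature = {
    "input_ids": [101, 7, 102, 0],
    "input_mask": [1, 1, 1, 0],
    "segment_ids": [0, 0, 0, 0],
}
label_id = 1

features = collections.OrderedDict()
features["input_ids"] = create_int_feature(feature["input_ids"])
features["input_mask"] = create_int_feature(feature["input_mask"])
features["segment_ids"] = create_int_feature(feature["segment_ids"])
features["label_ids"] = create_int_feature([label_id])
tf_example = tf.train.Example(features=tf.train.Features(feature=features))

# Round-trip: parse the serialized record back into dense tensors.
parsed = tf.io.parse_single_example(
    tf_example.SerializeToString(),
    {
        "input_ids": tf.io.FixedLenFeature([4], tf.int64),
        "input_mask": tf.io.FixedLenFeature([4], tf.int64),
        "segment_ids": tf.io.FixedLenFeature([4], tf.int64),
        "label_ids": tf.io.FixedLenFeature([1], tf.int64),
    },
)
assert parsed["label_ids"].numpy().tolist() == [1]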