Refactor text_classifier preprocessor to move away from using classifier_data_lib

PiperOrigin-RevId: 556859900
Author: MediaPipe Team (2023-08-14 11:34:10 -07:00), committed by Copybara-Service
parent 3ac3b03ed5
commit c8ad606e7c


@@ -25,7 +25,6 @@ import tensorflow_hub
 from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
 from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
-from official.nlp.data import classifier_data_lib
 from official.nlp.tools import tokenization
@@ -290,6 +289,23 @@ class BertClassifierPreprocessor:
         ds_cache_files.num_shards,
     )
 
+  def _process_bert_features(self, text: str) -> Mapping[str, Sequence[int]]:
+    tokens = self._tokenizer.tokenize(text)
+    tokens = tokens[0 : (self._seq_len - 2)]  # account for [CLS] and [SEP]
+    tokens.insert(0, "[CLS]")
+    tokens.append("[SEP]")
+    input_ids = self._tokenizer.convert_tokens_to_ids(tokens)
+    input_mask = [1] * len(input_ids)
+    while len(input_ids) < self._seq_len:
+      input_ids.append(0)
+      input_mask.append(0)
+    segment_ids = [0] * self._seq_len
+    return {
+        "input_ids": input_ids,
+        "input_mask": input_mask,
+        "segment_ids": segment_ids,
+    }
+
   def preprocess(
       self, dataset: text_classifier_ds.Dataset
   ) -> text_classifier_ds.Dataset:
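The new _process_bert_features method reproduces the standard BERT single-sentence packing: truncate to seq_len - 2, wrap the tokens in [CLS]/[SEP], map them to ids, and zero-pad the ids and mask out to seq_len. Below is a minimal, self-contained sketch of the same logic; FakeTokenizer, its toy vocab, and the sample values are illustrative stand-ins for the real BERT tokenizer, not part of the commit:

    from typing import Mapping, Sequence


    class FakeTokenizer:
      """Hypothetical stand-in for the BERT tokenizer; splits on whitespace."""

      def __init__(self, vocab):
        self._vocab = vocab

      def tokenize(self, text: str) -> list:
        return text.lower().split()

      def convert_tokens_to_ids(self, tokens: Sequence[str]) -> list:
        return [self._vocab.get(t, 1) for t in tokens]  # 1 stands in for [UNK]


    def process_bert_features(
        tokenizer, seq_len: int, text: str
    ) -> Mapping[str, Sequence[int]]:
      # Same steps as the new method: truncate, add [CLS]/[SEP], pad.
      tokens = tokenizer.tokenize(text)
      tokens = tokens[0 : (seq_len - 2)]  # reserve room for [CLS] and [SEP]
      tokens.insert(0, "[CLS]")
      tokens.append("[SEP]")
      input_ids = tokenizer.convert_tokens_to_ids(tokens)
      input_mask = [1] * len(input_ids)
      while len(input_ids) < seq_len:  # right-pad with zeros up to seq_len
        input_ids.append(0)
        input_mask.append(0)
      return {
          "input_ids": input_ids,
          "input_mask": input_mask,
          "segment_ids": [0] * seq_len,  # single-sentence input: all zeros
      }


    vocab = {"[CLS]": 2, "[SEP]": 3, "hello": 7, "world": 8}
    print(process_bert_features(FakeTokenizer(vocab), 8, "hello world"))
    # {'input_ids': [2, 7, 8, 3, 0, 0, 0, 0],
    #  'input_mask': [1, 1, 1, 1, 0, 0, 0, 0],
    #  'segment_ids': [0, 0, 0, 0, 0, 0, 0, 0]}

Since every input is a single sentence, segment_ids is all zeros; input_mask is what distinguishes real tokens from padding downstream.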
@@ -310,18 +326,7 @@ class BertClassifierPreprocessor:
     size = 0
     for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
       _validate_text_and_label(text, label)
-      example = classifier_data_lib.InputExample(
-          guid=str(index),
-          text_a=text.numpy()[0].decode("utf-8"),
-          text_b=None,
-          # InputExample expects the label name rather than the int ID
-          # label=dataset.label_names[label.numpy()[0]])
-          label=label.numpy()[0],
-      )
-      feature = classifier_data_lib.convert_single_example(
-          index, example, None, self._seq_len, self._tokenizer
-      )
+      feature = self._process_bert_features(text.numpy()[0].decode("utf-8"))
 
       def create_int_feature(values):
         f = tf.train.Feature(
             int64_list=tf.train.Int64List(value=list(values))
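The InputExample/convert_single_example round trip collapses to one call because the only inputs the old path actually used were the decoded text, the sequence length, and the tokenizer. A small sketch of the decoding step the new call site performs, assuming a batch size of 1 as in the loop above (the tensor value is illustrative):

    import tensorflow as tf

    text = tf.constant([b"The movie was great"])  # shape (1,) tf.string tensor
    raw = text.numpy()[0].decode("utf-8")         # -> "The movie was great"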
@@ -329,10 +334,10 @@ class BertClassifierPreprocessor:
         return f
 
       features = collections.OrderedDict()
-      features["input_ids"] = create_int_feature(feature.input_ids)
-      features["input_mask"] = create_int_feature(feature.input_mask)
-      features["segment_ids"] = create_int_feature(feature.segment_ids)
-      features["label_ids"] = create_int_feature([feature.label_id])
+      features["input_ids"] = create_int_feature(feature["input_ids"])
+      features["input_mask"] = create_int_feature(feature["input_mask"])
+      features["segment_ids"] = create_int_feature(feature["segment_ids"])
+      features["label_ids"] = create_int_feature([label.numpy()[0]])
       tf_example = tf.train.Example(
           features=tf.train.Features(feature=features)
       )
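With the label no longer routed through InputExample, the int label id is written straight from the label tensor instead of feature.label_id. A self-contained sketch of the serialization path this hunk feeds into; the feature dict contents, label id, and output path are hypothetical placeholders:

    import collections
    import tensorflow as tf


    def create_int_feature(values):
      # Wrap a list of ints as a tf.train.Feature for the Example proto.
      return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))


    # Shaped like _process_bert_features' output for seq_len = 8.
    feature = {
        "input_ids": [2, 7, 8, 3, 0, 0, 0, 0],
        "input_mask": [1, 1, 1, 1, 0, 0, 0, 0],
        "segment_ids": [0] * 8,
    }
    features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(feature["input_ids"])
    features["input_mask"] = create_int_feature(feature["input_mask"])
    features["segment_ids"] = create_int_feature(feature["segment_ids"])
    features["label_ids"] = create_int_feature([1])  # int label id, used directly

    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    with tf.io.TFRecordWriter("/tmp/example.tfrecord") as writer:
      writer.write(tf_example.SerializeToString())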