diff --git a/mediapipe/model_maker/python/text/text_classifier/preprocessor.py b/mediapipe/model_maker/python/text/text_classifier/preprocessor.py
index 2a31bbd09..68a5df2fd 100644
--- a/mediapipe/model_maker/python/text/text_classifier/preprocessor.py
+++ b/mediapipe/model_maker/python/text/text_classifier/preprocessor.py
@@ -25,7 +25,6 @@ import tensorflow_hub
 
 from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
 from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
-from official.nlp.data import classifier_data_lib
 from official.nlp.tools import tokenization
 
 
@@ -290,6 +289,23 @@ class BertClassifierPreprocessor:
         ds_cache_files.num_shards,
     )
 
+  def _process_bert_features(self, text: str) -> Mapping[str, Sequence[int]]:
+    tokens = self._tokenizer.tokenize(text)
+    tokens = tokens[0 : (self._seq_len - 2)]  # account for [CLS] and [SEP]
+    tokens.insert(0, "[CLS]")
+    tokens.append("[SEP]")
+    input_ids = self._tokenizer.convert_tokens_to_ids(tokens)
+    input_mask = [1] * len(input_ids)
+    while len(input_ids) < self._seq_len:
+      input_ids.append(0)
+      input_mask.append(0)
+    segment_ids = [0] * self._seq_len
+    return {
+        "input_ids": input_ids,
+        "input_mask": input_mask,
+        "segment_ids": segment_ids,
+    }
+
   def preprocess(
       self, dataset: text_classifier_ds.Dataset
   ) -> text_classifier_ds.Dataset:
@@ -310,18 +326,7 @@ class BertClassifierPreprocessor:
       size = 0
       for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
         _validate_text_and_label(text, label)
-        example = classifier_data_lib.InputExample(
-            guid=str(index),
-            text_a=text.numpy()[0].decode("utf-8"),
-            text_b=None,
-            # InputExample expects the label name rather than the int ID
-            # label=dataset.label_names[label.numpy()[0]])
-            label=label.numpy()[0],
-        )
-        feature = classifier_data_lib.convert_single_example(
-            index, example, None, self._seq_len, self._tokenizer
-        )
-
+        feature = self._process_bert_features(text.numpy()[0].decode("utf-8"))
         def create_int_feature(values):
           f = tf.train.Feature(
               int64_list=tf.train.Int64List(value=list(values))
@@ -329,10 +334,10 @@
           return f
 
         features = collections.OrderedDict()
-        features["input_ids"] = create_int_feature(feature.input_ids)
-        features["input_mask"] = create_int_feature(feature.input_mask)
-        features["segment_ids"] = create_int_feature(feature.segment_ids)
-        features["label_ids"] = create_int_feature([feature.label_id])
+        features["input_ids"] = create_int_feature(feature["input_ids"])
+        features["input_mask"] = create_int_feature(feature["input_mask"])
+        features["segment_ids"] = create_int_feature(feature["segment_ids"])
+        features["label_ids"] = create_int_feature([label.numpy()[0]])
         tf_example = tf.train.Example(
             features=tf.train.Features(feature=features)
         )
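
A minimal standalone sketch (not part of the diff) of the truncate/[CLS]/[SEP]/pad logic that the new _process_bert_features helper implements. The WhitespaceTokenizer below is a hypothetical stand-in for the BERT tokenization.FullTokenizer used in preprocessor.py, and process_bert_features is a free-function copy of the method for illustration only; seq_len=8 is an arbitrary example value.

from typing import Mapping, Sequence


class WhitespaceTokenizer:
  """Hypothetical stand-in tokenizer: splits on whitespace, assigns toy ids."""

  def __init__(self):
    self._vocab = {"[CLS]": 101, "[SEP]": 102}

  def tokenize(self, text: str) -> list[str]:
    return text.split()

  def convert_tokens_to_ids(self, tokens: Sequence[str]) -> list[int]:
    # Unknown tokens get ids assigned on first use; a real BERT tokenizer
    # looks them up in a fixed vocabulary instead.
    return [self._vocab.setdefault(t, 1000 + len(self._vocab)) for t in tokens]


def process_bert_features(
    tokenizer: WhitespaceTokenizer, text: str, seq_len: int
) -> Mapping[str, Sequence[int]]:
  tokens = tokenizer.tokenize(text)
  tokens = tokens[0 : (seq_len - 2)]  # leave room for [CLS] and [SEP]
  tokens.insert(0, "[CLS]")
  tokens.append("[SEP]")
  input_ids = tokenizer.convert_tokens_to_ids(tokens)
  input_mask = [1] * len(input_ids)
  while len(input_ids) < seq_len:  # zero-pad up to the fixed sequence length
    input_ids.append(0)
    input_mask.append(0)
  segment_ids = [0] * seq_len  # single-sentence input: one segment
  return {
      "input_ids": input_ids,
      "input_mask": input_mask,
      "segment_ids": segment_ids,
  }


if __name__ == "__main__":
  features = process_bert_features(WhitespaceTokenizer(), "hello world", seq_len=8)
  # input_ids:  [101, 1002, 1003, 102, 0, 0, 0, 0]
  # input_mask: [1, 1, 1, 1, 0, 0, 0, 0]
  print(features)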