No public description
PiperOrigin-RevId: 550954023
This commit is contained in:
  parent 113c9b30c2
  commit 62538a9496
@@ -45,6 +45,8 @@ class TFRecordCacheFiles:
  num_shards: int = 1

  def __post_init__(self):
    if not tf.io.gfile.exists(self.cache_dir):
      tf.io.gfile.makedirs(self.cache_dir)
    if not self.cache_prefix_filename:
      raise ValueError('cache_prefix_filename cannot be empty.')
    if self.num_shards <= 0:
@@ -79,8 +81,6 @@ class TFRecordCacheFiles:
    Returns:
      Array of TFRecordWriter objects
    """
    if not tf.io.gfile.exists(self.cache_dir):
      tf.io.gfile.makedirs(self.cache_dir)
    return [tf.io.TFRecordWriter(path) for path in self.tfrecord_files]

  def save_metadata(self, metadata):
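The cache_files.py hunks above make the cache directory part of TFRecordCacheFiles initialization. A minimal usage sketch of the helper, assuming only the API visible in this diff (is_cached, get_writers, save_metadata, load_metadata, tfrecord_files); the path and metadata values are illustrative, not taken from the change:

import tensorflow as tf
from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib

serialized_examples = [b"example-0", b"example-1"]  # placeholder serialized tf.train.Example protos
cache = cache_files_lib.TFRecordCacheFiles(
    cache_prefix_filename="tfrecord", cache_dir="/tmp/cache", num_shards=2
)
if not cache.is_cached():
  writers = cache.get_writers()  # one TFRecordWriter per shard
  for i, record in enumerate(serialized_examples):
    writers[i % len(writers)].write(record)  # round-robin across shards
  for writer in writers:
    writer.close()
  cache.save_metadata({"size": len(serialized_examples), "label_names": ["neg", "pos"]})
else:
  metadata = cache.load_metadata()
ds = tf.data.TFRecordDataset(cache.tfrecord_files)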
@@ -76,7 +76,10 @@ py_test(
py_library(
    name = "dataset",
    srcs = ["dataset.py"],
    deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"],
    deps = [
        "//mediapipe/model_maker/python/core/data:cache_files",
        "//mediapipe/model_maker/python/core/data:classification_dataset",
    ],
)

py_test(
@@ -88,7 +91,10 @@ py_test(
py_library(
    name = "preprocessor",
    srcs = ["preprocessor.py"],
    deps = [":dataset"],
    deps = [
        ":dataset",
        "//mediapipe/model_maker/python/core/data:cache_files",
    ],
)

py_test(
@@ -99,6 +105,7 @@ py_test(
        ":dataset",
        ":model_spec",
        ":preprocessor",
        "//mediapipe/model_maker/python/core/data:cache_files",
    ],
)
@@ -15,11 +15,15 @@
import csv
import dataclasses
import hashlib
import os
import random
import tempfile
from typing import List, Optional, Sequence

from typing import Optional, Sequence
import tensorflow as tf

from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
from mediapipe.model_maker.python.core.data import classification_dataset
@@ -46,21 +50,49 @@ class CSVParameters:
class Dataset(classification_dataset.ClassificationDataset):
  """Dataset library for text classifier."""

  def __init__(
      self,
      dataset: tf.data.Dataset,
      label_names: List[str],
      tfrecord_cache_files: Optional[cache_files_lib.TFRecordCacheFiles] = None,
      size: Optional[int] = None,
  ):
    super().__init__(dataset, label_names, size)
    if not tfrecord_cache_files:
      tfrecord_cache_files = cache_files_lib.TFRecordCacheFiles(
          cache_prefix_filename="tfrecord", num_shards=1
      )
    self.tfrecord_cache_files = tfrecord_cache_files

  @classmethod
  def from_csv(cls,
  def from_csv(
      cls,
      filename: str,
      csv_params: CSVParameters,
      shuffle: bool = True) -> "Dataset":
      shuffle: bool = True,
      cache_dir: Optional[str] = None,
      num_shards: int = 1,
  ) -> "Dataset":
    """Loads text with labels from a CSV file.

    Args:
      filename: Name of the CSV file.
      csv_params: Parameters used for reading the CSV file.
      shuffle: If True, randomly shuffle the data.
      cache_dir: Optional parameter to specify where to store the preprocessed
        dataset. Only used for BERT models.
      num_shards: Optional parameter for num shards of the preprocessed dataset.
        Note that using more than 1 shard will reorder the dataset. Only used
        for BERT models.

    Returns:
      Dataset containing (text, label) pairs and other related info.
    """
    if cache_dir is None:
      cache_dir = tempfile.mkdtemp()
    # calculate hash for cache based off of files
    hasher = hashlib.md5()
    hasher.update(os.path.basename(filename).encode("utf-8"))
    with tf.io.gfile.GFile(filename, "r") as f:
      reader = csv.DictReader(
          f,
@@ -69,6 +101,9 @@ class Dataset(classification_dataset.ClassificationDataset):
          quotechar=csv_params.quotechar)

      lines = list(reader)
      for line in lines:
        hasher.update(str(line).encode("utf-8"))

    if shuffle:
      random.shuffle(lines)
@@ -81,9 +116,18 @@ class Dataset(classification_dataset.ClassificationDataset):
        index_by_label[line[csv_params.label_column]] for line in lines
    ]
    label_index_ds = tf.data.Dataset.from_tensor_slices(
        tf.cast(label_indices, tf.int64))
        tf.cast(label_indices, tf.int64)
    )
    text_label_ds = tf.data.Dataset.zip((text_ds, label_index_ds))

    return Dataset(
        dataset=text_label_ds, label_names=label_names, size=len(texts)
    hasher.update(str(num_shards).encode("utf-8"))
    cache_prefix_filename = hasher.hexdigest()
    tfrecord_cache_files = cache_files_lib.TFRecordCacheFiles(
        cache_prefix_filename, cache_dir, num_shards
    )
    return Dataset(
        dataset=text_label_ds,
        label_names=label_names,
        tfrecord_cache_files=tfrecord_cache_files,
        size=len(texts),
    )
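For orientation, here is a hypothetical call into the updated from_csv API; the path and cache location are illustrative, and csv_params is assumed to be a CSVParameters instance configured elsewhere:

# Sketch only: exercising the new cache_dir/num_shards arguments of from_csv.
train_data = Dataset.from_csv(
    filename="/tmp/train.csv",               # illustrative path
    csv_params=csv_params,                   # assumed CSVParameters instance
    shuffle=True,
    cache_dir="/tmp/text_classifier_cache",  # new: where cached TFRecords will live
    num_shards=1,                            # new: >1 reorders the dataset
)
# md5 over the csv basename, its rows, and num_shards:
print(train_data.tfrecord_cache_files.cache_prefix_filename)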
@@ -53,7 +53,7 @@ class DatasetTest(tf.test.TestCase):

  def test_split(self):
    ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd'])
    data = dataset.Dataset(ds, ['pos', 'neg'], 4)
    data = dataset.Dataset(ds, ['pos', 'neg'], size=4)
    train_data, test_data = data.split(0.5)
    expected_train_data = [b'good', b'bad']
    expected_test_data = [b'neutral', b'odd']
@@ -15,14 +15,15 @@
"""Preprocessors for text classification."""

import collections
import hashlib
import os
import re
import tempfile
from typing import Mapping, Sequence, Tuple, Union

import tensorflow as tf
import tensorflow_hub

from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
from official.nlp.data import classifier_data_lib
from official.nlp.tools import tokenization
@@ -75,19 +76,20 @@ def _decode_record(
  return bert_features, example["label_ids"]


def _single_file_dataset(
    input_file: str, name_to_features: Mapping[str, tf.io.FixedLenFeature]
def _tfrecord_dataset(
    tfrecord_files: Sequence[str],
    name_to_features: Mapping[str, tf.io.FixedLenFeature],
) -> tf.data.TFRecordDataset:
  """Creates a single-file dataset to be passed for BERT custom training.

  Args:
    input_file: Filepath for the dataset.
    tfrecord_files: Filepaths for the dataset.
    name_to_features: Maps record keys to feature types.

  Returns:
    Dataset containing BERT model input features and labels.
  """
  d = tf.data.TFRecordDataset(input_file)
  d = tf.data.TFRecordDataset(tfrecord_files)
  d = d.map(
      lambda record: _decode_record(record, name_to_features),
      num_parallel_calls=tf.data.AUTOTUNE)
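The renamed helper now accepts a list of shard files rather than a single path. The name_to_features mapping it expects comes from _get_name_to_features, whose body is not part of this hunk; a plausible shape, inferred only from the feature keys written in preprocess() further down, is sketched here (illustrative, not from the diff):

import tensorflow as tf

seq_len = 128  # illustrative value
name_to_features = {
    "input_ids": tf.io.FixedLenFeature([seq_len], tf.int64),
    "input_mask": tf.io.FixedLenFeature([seq_len], tf.int64),
    "segment_ids": tf.io.FixedLenFeature([seq_len], tf.int64),
    "label_ids": tf.io.FixedLenFeature([], tf.int64),
}
# d = _tfrecord_dataset(tfrecord_cache_files.tfrecord_files, name_to_features)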
@@ -221,15 +223,23 @@ class BertClassifierPreprocessor:
    seq_len: Length of the input sequence to the model.
    vocab_file: File containing the BERT vocab.
    tokenizer: BERT tokenizer.
    model_name: Name of the model provided by the model_spec. Used to associate
      cached files with specific Bert model vocab.
  """

  def __init__(self, seq_len: int, do_lower_case: bool, uri: str):
  def __init__(
      self, seq_len: int, do_lower_case: bool, uri: str, model_name: str
  ):
    self._seq_len = seq_len
    # Vocab filepath is tied to the BERT module's URI.
    self._vocab_file = os.path.join(
        tensorflow_hub.resolve(uri), "assets", "vocab.txt")
    self._tokenizer = tokenization.FullTokenizer(self._vocab_file,
                                                 do_lower_case)
        tensorflow_hub.resolve(uri), "assets", "vocab.txt"
    )
    self._do_lower_case = do_lower_case
    self._tokenizer = tokenization.FullTokenizer(
        self._vocab_file, self._do_lower_case
    )
    self._model_name = model_name

  def _get_name_to_features(self):
    """Gets the dictionary mapping record keys to feature types."""
@@ -244,8 +254,45 @@ class BertClassifierPreprocessor:
    """Returns the vocab file of the BertClassifierPreprocessor."""
    return self._vocab_file

  def _get_tfrecord_cache_files(
      self, ds_cache_files
  ) -> cache_files_lib.TFRecordCacheFiles:
    """Helper to regenerate cache prefix filename using preprocessor info.

    We need to update the dataset cache_prefix because the actual cached
    dataset depends on the preprocessor parameters such as model_name, seq_len,
    and do_lower_case, in addition to the raw dataset parameters, which are
    already included in ds_cache_files.cache_prefix_filename.

    Specifically, the new cache_prefix_filename used by the preprocessor will
    be a hash generated from the following:
      1. cache_prefix_filename of the initial raw dataset
      2. model_name
      3. seq_len
      4. do_lower_case

    Args:
      ds_cache_files: TFRecordCacheFiles from the original raw dataset object.

    Returns:
      A new TFRecordCacheFiles object which incorporates the preprocessor
      parameters.
    """
    hasher = hashlib.md5()
    hasher.update(ds_cache_files.cache_prefix_filename.encode("utf-8"))
    hasher.update(self._model_name.encode("utf-8"))
    hasher.update(str(self._seq_len).encode("utf-8"))
    hasher.update(str(self._do_lower_case).encode("utf-8"))
    cache_prefix_filename = hasher.hexdigest()
    return cache_files_lib.TFRecordCacheFiles(
        cache_prefix_filename,
        ds_cache_files.cache_dir,
        ds_cache_files.num_shards,
    )

  def preprocess(
      self, dataset: text_classifier_ds.Dataset) -> text_classifier_ds.Dataset:
      self, dataset: text_classifier_ds.Dataset
  ) -> text_classifier_ds.Dataset:
    """Preprocesses data into input for a BERT-based classifier.

    Args:
@@ -254,32 +301,65 @@ class BertClassifierPreprocessor:
    Returns:
      Dataset containing (bert_features, label) data.
    """
    examples = []
    ds_cache_files = dataset.tfrecord_cache_files
    # Get new tfrecord_cache_files by including preprocessor information.
    tfrecord_cache_files = self._get_tfrecord_cache_files(ds_cache_files)
    if not tfrecord_cache_files.is_cached():
      print(f"Writing new cache files to {tfrecord_cache_files.cache_prefix}")
      writers = tfrecord_cache_files.get_writers()
      size = 0
      for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
        _validate_text_and_label(text, label)
        examples.append(
            classifier_data_lib.InputExample(
        example = classifier_data_lib.InputExample(
            guid=str(index),
            text_a=text.numpy()[0].decode("utf-8"),
            text_b=None,
            # InputExample expects the label name rather than the int ID
            label=dataset.label_names[label.numpy()[0]]))
            # label=dataset.label_names[label.numpy()[0]])
            label=label.numpy()[0],
        )
        feature = classifier_data_lib.convert_single_example(
            index, example, None, self._seq_len, self._tokenizer
        )

    tfrecord_file = os.path.join(tempfile.mkdtemp(), "bert_features.tfrecord")
    classifier_data_lib.file_based_convert_examples_to_features(
        examples=examples,
        label_list=dataset.label_names,
        max_seq_length=self._seq_len,
        tokenizer=self._tokenizer,
        output_file=tfrecord_file)
    preprocessed_ds = _single_file_dataset(tfrecord_file,
                                           self._get_name_to_features())
        def create_int_feature(values):
          f = tf.train.Feature(
              int64_list=tf.train.Int64List(value=list(values))
          )
          return f

        features = collections.OrderedDict()
        features["input_ids"] = create_int_feature(feature.input_ids)
        features["input_mask"] = create_int_feature(feature.input_mask)
        features["segment_ids"] = create_int_feature(feature.segment_ids)
        features["label_ids"] = create_int_feature([feature.label_id])
        tf_example = tf.train.Example(
            features=tf.train.Features(feature=features)
        )
        writers[index % len(writers)].write(tf_example.SerializeToString())
        size = index + 1
      for writer in writers:
        writer.close()
      metadata = {"size": size, "label_names": dataset.label_names}
      tfrecord_cache_files.save_metadata(metadata)
    else:
      print(
          f"Using existing cache files at {tfrecord_cache_files.cache_prefix}"
      )
      metadata = tfrecord_cache_files.load_metadata()
      size = metadata["size"]
      label_names = metadata["label_names"]
    preprocessed_ds = _tfrecord_dataset(
        tfrecord_cache_files.tfrecord_files, self._get_name_to_features()
    )
    return text_classifier_ds.Dataset(
        dataset=preprocessed_ds,
        size=dataset.size,
        label_names=dataset.label_names)
        size=size,
        label_names=label_names,
        tfrecord_cache_files=tfrecord_cache_files,
    )


TextClassifierPreprocessor = (
    Union[BertClassifierPreprocessor,
          AverageWordEmbeddingClassifierPreprocessor])
TextClassifierPreprocessor = Union[
    BertClassifierPreprocessor, AverageWordEmbeddingClassifierPreprocessor
]
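The preprocess() hunk above keys the on-disk TFRecords to both the raw dataset and the preprocessor configuration. As a rough worked example of the key derivation (input values are illustrative, not taken from the diff), the new prefix is an MD5 over the same components hashed in _get_tfrecord_cache_files:

import hashlib

hasher = hashlib.md5()
hasher.update("0123abcd".encode("utf-8"))  # hypothetical dataset cache_prefix_filename
hasher.update("exbert".encode("utf-8"))    # hypothetical model_name
hasher.update(str(128).encode("utf-8"))    # seq_len
hasher.update(str(True).encode("utf-8"))   # do_lower_case
# Becomes the new cache_prefix_filename; changing any input changes the prefix.
print(hasher.hexdigest())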
@@ -13,14 +13,17 @@
# limitations under the License.

import csv
import io
import os
import tempfile
from unittest import mock as unittest_mock

import mock
import numpy as np
import numpy.testing as npt
import tensorflow as tf

from mediapipe.model_maker.python.core.data import cache_files
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
from mediapipe.model_maker.python.text.text_classifier import model_spec
from mediapipe.model_maker.python.text.text_classifier import preprocessor
@@ -84,11 +87,12 @@ class PreprocessorTest(tf.test.TestCase):
    csv_file = self._get_csv_file()
    dataset = text_classifier_ds.Dataset.from_csv(
        filename=csv_file, csv_params=self.CSV_PARAMS_)
    bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
    bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
    bert_preprocessor = preprocessor.BertClassifierPreprocessor(
        seq_len=5,
        do_lower_case=bert_spec.do_lower_case,
        uri=bert_spec.downloaded_files.get_path(),
        model_name=bert_spec.name,
    )
    preprocessed_dataset = bert_preprocessor.preprocess(dataset)
    labels = []
@@ -97,18 +101,91 @@ class PreprocessorTest(tf.test.TestCase):
      self.assertEqual(label.shape, [1])
      labels.append(label.numpy()[0])
      self.assertSameElements(
          features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids'])
          features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids']
      )
      for feature in features.values():
        self.assertEqual(feature.shape, [1, 5])
      input_masks.append(features['input_mask'].numpy()[0])
      npt.assert_array_equal(features['input_type_ids'].numpy()[0],
                             [0, 0, 0, 0, 0])
      npt.assert_array_equal(
          np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]))
          features['input_type_ids'].numpy()[0], [0, 0, 0, 0, 0]
      )
    npt.assert_array_equal(
        np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])
    )
    self.assertEqual(labels, [1, 0])

  def test_bert_preprocessor_cache(self):
    csv_file = self._get_csv_file()
    dataset = text_classifier_ds.Dataset.from_csv(
        filename=csv_file,
        csv_params=self.CSV_PARAMS_,
        cache_dir=self.get_temp_dir(),
    )
    bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
    bert_preprocessor = preprocessor.BertClassifierPreprocessor(
        seq_len=5,
        do_lower_case=bert_spec.do_lower_case,
        uri=bert_spec.downloaded_files.get_path(),
        model_name=bert_spec.name,
    )
    ds_cache_files = dataset.tfrecord_cache_files
    preprocessed_cache_files = bert_preprocessor._get_tfrecord_cache_files(
        ds_cache_files
    )
    self.assertFalse(preprocessed_cache_files.is_cached())
    preprocessed_dataset = bert_preprocessor.preprocess(dataset)
    self.assertTrue(preprocessed_cache_files.is_cached())
    self.assertEqual(
        preprocessed_dataset.tfrecord_cache_files, preprocessed_cache_files
    )

    # The second time running preprocessor, it should load from cache directly.
    mock_stdout = io.StringIO()
    with mock.patch('sys.stdout', mock_stdout):
      _ = bert_preprocessor.preprocess(dataset)
    self.assertEqual(
        mock_stdout.getvalue(),
        'Using existing cache files at'
        f' {preprocessed_cache_files.cache_prefix}\n',
    )

  def _get_new_prefix(self, cf, bert_spec, seq_len, do_lower_case):
    bert_preprocessor = preprocessor.BertClassifierPreprocessor(
        seq_len=seq_len,
        do_lower_case=do_lower_case,
        uri=bert_spec.downloaded_files.get_path(),
        model_name=bert_spec.name,
    )
    new_cf = bert_preprocessor._get_tfrecord_cache_files(cf)
    return new_cf.cache_prefix_filename

  def test_bert_get_tfrecord_cache_files(self):
    # Test to ensure regenerated cache_files have different prefixes.
    all_cf_prefixes = set()
    cf = cache_files.TFRecordCacheFiles(
        cache_prefix_filename='cache_prefix',
        cache_dir=self.get_temp_dir(),
        num_shards=1,
    )
    exbert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
    all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, True))
    all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 10, True))
    all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, False))
    mobilebert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
    all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, True))
    all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 10, True))
    all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, False))
    new_cf = cache_files.TFRecordCacheFiles(
        cache_prefix_filename='new_cache_prefix',
        cache_dir=self.get_temp_dir(),
        num_shards=1,
    )
    all_cf_prefixes.add(self._get_new_prefix(new_cf, exbert_spec, 5, True))

    # Each item of all_cf_prefixes should be unique, so 7 total.
    self.assertLen(all_cf_prefixes, 7)


if __name__ == '__main__':
  # Load compressed models from tensorflow_hub.
  os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
  tf.test.main()
@@ -435,6 +435,7 @@ class _BertClassifier(TextClassifier):
        seq_len=self._model_options.seq_len,
        do_lower_case=self._model_spec.do_lower_case,
        uri=self._model_spec.downloaded_files.get_path(),
        model_name=self._model_spec.name,
    )
    return (self._text_preprocessor.preprocess(train_data),
            self._text_preprocessor.preprocess(validation_data))