No public description

PiperOrigin-RevId: 550954023
This commit is contained in:
MediaPipe Team 2023-07-25 11:55:07 -07:00 committed by Copybara-Service
parent 113c9b30c2
commit 62538a9496
7 changed files with 262 additions and 53 deletions

View File

@@ -45,6 +45,8 @@ class TFRecordCacheFiles:
num_shards: int = 1
def __post_init__(self):
if not tf.io.gfile.exists(self.cache_dir):
tf.io.gfile.makedirs(self.cache_dir)
if not self.cache_prefix_filename:
raise ValueError('cache_prefix_filename cannot be empty.')
if self.num_shards <= 0:
@@ -79,8 +81,6 @@ class TFRecordCacheFiles:
Returns:
Array of TFRecordWriter objects
"""
if not tf.io.gfile.exists(self.cache_dir):
tf.io.gfile.makedirs(self.cache_dir)
return [tf.io.TFRecordWriter(path) for path in self.tfrecord_files]
def save_metadata(self, metadata):
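A brief, hedged sketch of how the reworked TFRecordCacheFiles behaves after this change: the field and method names come from the hunks above, but exact signatures outside this diff are assumptions, and the prefix and directory below are hypothetical.
import tensorflow as tf
from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib

# Hypothetical prefix and directory; num_shards defaults to 1.
cache = cache_files_lib.TFRecordCacheFiles(
    cache_prefix_filename="example_prefix",
    cache_dir="/tmp/model_maker_cache",
    num_shards=2,
)
# With makedirs moved into __post_init__, the directory exists right after
# construction instead of only after get_writers() is called.
assert tf.io.gfile.exists(cache.cache_dir)

writers = cache.get_writers()  # one TFRecordWriter per shard
for writer in writers:
  writer.close()
# Metadata is saved alongside the shards; afterwards is_cached() should
# report True (mirroring the preprocessor test further down).
cache.save_metadata({"size": 0, "label_names": []})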

View File

@@ -76,7 +76,10 @@ py_test(
py_library(
name = "dataset",
srcs = ["dataset.py"],
deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"],
deps = [
"//mediapipe/model_maker/python/core/data:cache_files",
"//mediapipe/model_maker/python/core/data:classification_dataset",
],
)
py_test(
@@ -88,7 +91,10 @@ py_test(
py_library(
name = "preprocessor",
srcs = ["preprocessor.py"],
deps = [":dataset"],
deps = [
":dataset",
"//mediapipe/model_maker/python/core/data:cache_files",
],
)
py_test(
@@ -99,6 +105,7 @@ py_test(
":dataset",
":model_spec",
":preprocessor",
"//mediapipe/model_maker/python/core/data:cache_files",
],
)

View File

@@ -15,11 +15,15 @@
import csv
import dataclasses
import hashlib
import os
import random
import tempfile
from typing import List, Optional, Sequence
from typing import Optional, Sequence
import tensorflow as tf
from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
from mediapipe.model_maker.python.core.data import classification_dataset
@@ -46,21 +50,49 @@ class CSVParameters:
class Dataset(classification_dataset.ClassificationDataset):
"""Dataset library for text classifier."""
def __init__(
self,
dataset: tf.data.Dataset,
label_names: List[str],
tfrecord_cache_files: Optional[cache_files_lib.TFRecordCacheFiles] = None,
size: Optional[int] = None,
):
super().__init__(dataset, label_names, size)
if not tfrecord_cache_files:
tfrecord_cache_files = cache_files_lib.TFRecordCacheFiles(
cache_prefix_filename="tfrecord", num_shards=1
)
self.tfrecord_cache_files = tfrecord_cache_files
@classmethod
def from_csv(cls,
def from_csv(
cls,
filename: str,
csv_params: CSVParameters,
shuffle: bool = True) -> "Dataset":
shuffle: bool = True,
cache_dir: Optional[str] = None,
num_shards: int = 1,
) -> "Dataset":
"""Loads text with labels from a CSV file.
Args:
filename: Name of the CSV file.
csv_params: Parameters used for reading the CSV file.
shuffle: If True, randomly shuffle the data.
cache_dir: Optional parameter to specify where to store the preprocessed
dataset. Only used for BERT models.
num_shards: Optional parameter for the number of shards of the preprocessed dataset.
Note that using more than 1 shard will reorder the dataset. Only used
for BERT models.
Returns:
Dataset containing (text, label) pairs and other related info.
"""
if cache_dir is None:
cache_dir = tempfile.mkdtemp()
# Calculate a hash for the cache based on the CSV filename and its rows.
hasher = hashlib.md5()
hasher.update(os.path.basename(filename).encode("utf-8"))
with tf.io.gfile.GFile(filename, "r") as f:
reader = csv.DictReader(
f,
@@ -69,6 +101,9 @@ class Dataset(classification_dataset.ClassificationDataset):
quotechar=csv_params.quotechar)
lines = list(reader)
for line in lines:
hasher.update(str(line).encode("utf-8"))
if shuffle:
random.shuffle(lines)
@@ -81,9 +116,18 @@ class Dataset(classification_dataset.ClassificationDataset):
index_by_label[line[csv_params.label_column]] for line in lines
]
label_index_ds = tf.data.Dataset.from_tensor_slices(
tf.cast(label_indices, tf.int64))
tf.cast(label_indices, tf.int64)
)
text_label_ds = tf.data.Dataset.zip((text_ds, label_index_ds))
return Dataset(
dataset=text_label_ds, label_names=label_names, size=len(texts)
hasher.update(str(num_shards).encode("utf-8"))
cache_prefix_filename = hasher.hexdigest()
tfrecord_cache_files = cache_files_lib.TFRecordCacheFiles(
cache_prefix_filename, cache_dir, num_shards
)
return Dataset(
dataset=text_label_ds,
label_names=label_names,
tfrecord_cache_files=tfrecord_cache_files,
size=len(texts),
)
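A short, hedged usage sketch for the new from_csv caching parameters introduced above; the CSV path and column names are hypothetical, and the CSVParameters field names beyond label_column are assumptions.
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds

# Hypothetical CSV with "text" and "label" columns.
csv_params = text_classifier_ds.CSVParameters(
    text_column="text", label_column="label"
)
data = text_classifier_ds.Dataset.from_csv(
    filename="/tmp/reviews.csv",
    csv_params=csv_params,
    shuffle=True,
    cache_dir="/tmp/text_cache",  # only used by the BERT preprocessing path
    num_shards=2,                 # more than 1 shard reorders the dataset
)
# The cache prefix is an md5 digest over the CSV basename, its rows, and
# num_shards, so identical inputs map to the same cache files.
print(data.tfrecord_cache_files.cache_prefix_filename)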

View File

@@ -53,7 +53,7 @@ class DatasetTest(tf.test.TestCase):
def test_split(self):
ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd'])
data = dataset.Dataset(ds, ['pos', 'neg'], 4)
data = dataset.Dataset(ds, ['pos', 'neg'], size=4)
train_data, test_data = data.split(0.5)
expected_train_data = [b'good', b'bad']
expected_test_data = [b'neutral', b'odd']

View File

@@ -15,14 +15,15 @@
"""Preprocessors for text classification."""
import collections
import hashlib
import os
import re
import tempfile
from typing import Mapping, Sequence, Tuple, Union
import tensorflow as tf
import tensorflow_hub
from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
from official.nlp.data import classifier_data_lib
from official.nlp.tools import tokenization
@@ -75,19 +76,20 @@ def _decode_record(
return bert_features, example["label_ids"]
def _single_file_dataset(
input_file: str, name_to_features: Mapping[str, tf.io.FixedLenFeature]
def _tfrecord_dataset(
tfrecord_files: Sequence[str],
name_to_features: Mapping[str, tf.io.FixedLenFeature],
) -> tf.data.TFRecordDataset:
"""Creates a single-file dataset to be passed for BERT custom training.
Args:
input_file: Filepath for the dataset.
tfrecord_files: Filepaths for the dataset.
name_to_features: Maps record keys to feature types.
Returns:
Dataset containing BERT model input features and labels.
"""
d = tf.data.TFRecordDataset(input_file)
d = tf.data.TFRecordDataset(tfrecord_files)
d = d.map(
lambda record: _decode_record(record, name_to_features),
num_parallel_calls=tf.data.AUTOTUNE)
@@ -221,15 +223,23 @@ class BertClassifierPreprocessor:
seq_len: Length of the input sequence to the model.
vocab_file: File containing the BERT vocab.
tokenizer: BERT tokenizer.
model_name: Name of the model provided by the model_spec. Used to associate
cached files with a specific BERT model vocab.
"""
def __init__(self, seq_len: int, do_lower_case: bool, uri: str):
def __init__(
self, seq_len: int, do_lower_case: bool, uri: str, model_name: str
):
self._seq_len = seq_len
# Vocab filepath is tied to the BERT module's URI.
self._vocab_file = os.path.join(
tensorflow_hub.resolve(uri), "assets", "vocab.txt")
self._tokenizer = tokenization.FullTokenizer(self._vocab_file,
do_lower_case)
tensorflow_hub.resolve(uri), "assets", "vocab.txt"
)
self._do_lower_case = do_lower_case
self._tokenizer = tokenization.FullTokenizer(
self._vocab_file, self._do_lower_case
)
self._model_name = model_name
def _get_name_to_features(self):
"""Gets the dictionary mapping record keys to feature types."""
@@ -244,8 +254,45 @@ class BertClassifierPreprocessor:
"""Returns the vocab file of the BertClassifierPreprocessor."""
return self._vocab_file
def _get_tfrecord_cache_files(
self, ds_cache_files
) -> cache_files_lib.TFRecordCacheFiles:
"""Helper to regenerate cache prefix filename using preprocessor info.
We need to update the dataset's cache_prefix_filename because the actual cached
dataset depends on preprocessor parameters such as model_name, seq_len, and
do_lower_case, in addition to the raw dataset parameters, which are already
included in ds_cache_files.cache_prefix_filename.
Specifically, the new cache_prefix_filename used by the preprocessor will
be a hash generated from the following:
1. cache_prefix_filename of the initial raw dataset
2. model_name
3. seq_len
4. do_lower_case
Args:
ds_cache_files: TFRecordCacheFiles from the original raw dataset object
Returns:
A new TFRecordCacheFiles object which incorporates the preprocessor
parameters.
"""
hasher = hashlib.md5()
hasher.update(ds_cache_files.cache_prefix_filename.encode("utf-8"))
hasher.update(self._model_name.encode("utf-8"))
hasher.update(str(self._seq_len).encode("utf-8"))
hasher.update(str(self._do_lower_case).encode("utf-8"))
cache_prefix_filename = hasher.hexdigest()
return cache_files_lib.TFRecordCacheFiles(
cache_prefix_filename,
ds_cache_files.cache_dir,
ds_cache_files.num_shards,
)
def preprocess(
self, dataset: text_classifier_ds.Dataset) -> text_classifier_ds.Dataset:
self, dataset: text_classifier_ds.Dataset
) -> text_classifier_ds.Dataset:
"""Preprocesses data into input for a BERT-based classifier.
Args:
@@ -254,32 +301,65 @@ class BertClassifierPreprocessor:
Returns:
Dataset containing (bert_features, label) data.
"""
examples = []
ds_cache_files = dataset.tfrecord_cache_files
# Get new tfrecord_cache_files by including preprocessor information.
tfrecord_cache_files = self._get_tfrecord_cache_files(ds_cache_files)
if not tfrecord_cache_files.is_cached():
print(f"Writing new cache files to {tfrecord_cache_files.cache_prefix}")
writers = tfrecord_cache_files.get_writers()
size = 0
for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
_validate_text_and_label(text, label)
examples.append(
classifier_data_lib.InputExample(
example = classifier_data_lib.InputExample(
guid=str(index),
text_a=text.numpy()[0].decode("utf-8"),
text_b=None,
# InputExample expects the label name rather than the int ID
label=dataset.label_names[label.numpy()[0]]))
# label=dataset.label_names[label.numpy()[0]])
label=label.numpy()[0],
)
feature = classifier_data_lib.convert_single_example(
index, example, None, self._seq_len, self._tokenizer
)
tfrecord_file = os.path.join(tempfile.mkdtemp(), "bert_features.tfrecord")
classifier_data_lib.file_based_convert_examples_to_features(
examples=examples,
label_list=dataset.label_names,
max_seq_length=self._seq_len,
tokenizer=self._tokenizer,
output_file=tfrecord_file)
preprocessed_ds = _single_file_dataset(tfrecord_file,
self._get_name_to_features())
def create_int_feature(values):
f = tf.train.Feature(
int64_list=tf.train.Int64List(value=list(values))
)
return f
features = collections.OrderedDict()
features["input_ids"] = create_int_feature(feature.input_ids)
features["input_mask"] = create_int_feature(feature.input_mask)
features["segment_ids"] = create_int_feature(feature.segment_ids)
features["label_ids"] = create_int_feature([feature.label_id])
tf_example = tf.train.Example(
features=tf.train.Features(feature=features)
)
writers[index % len(writers)].write(tf_example.SerializeToString())
size = index + 1
for writer in writers:
writer.close()
metadata = {"size": size, "label_names": dataset.label_names}
tfrecord_cache_files.save_metadata(metadata)
else:
print(
f"Using existing cache files at {tfrecord_cache_files.cache_prefix}"
)
metadata = tfrecord_cache_files.load_metadata()
size = metadata["size"]
label_names = metadata["label_names"]
preprocessed_ds = _tfrecord_dataset(
tfrecord_cache_files.tfrecord_files, self._get_name_to_features()
)
return text_classifier_ds.Dataset(
dataset=preprocessed_ds,
size=dataset.size,
label_names=dataset.label_names)
size=size,
label_names=label_names,
tfrecord_cache_files=tfrecord_cache_files,
)
TextClassifierPreprocessor = (
Union[BertClassifierPreprocessor,
AverageWordEmbeddingClassifierPreprocessor])
TextClassifierPreprocessor = Union[
BertClassifierPreprocessor, AverageWordEmbeddingClassifierPreprocessor
]
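To summarize the preprocessor changes, a hedged end-to-end sketch of the intended caching flow; the model spec fields and constructor arguments mirror this diff and the tests below, while the CSV setup and paths are hypothetical.
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
from mediapipe.model_maker.python.text.text_classifier import model_spec
from mediapipe.model_maker.python.text.text_classifier import preprocessor

bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
    seq_len=128,
    do_lower_case=bert_spec.do_lower_case,
    uri=bert_spec.downloaded_files.get_path(),
    model_name=bert_spec.name,  # new: ties the cache to this particular BERT
)

data = text_classifier_ds.Dataset.from_csv(
    filename="/tmp/reviews.csv",  # hypothetical CSV file
    csv_params=text_classifier_ds.CSVParameters(
        text_column="text", label_column="label"  # assumed field names
    ),
    cache_dir="/tmp/text_cache",
)

# First run: derives a new cache prefix from the dataset prefix plus
# model_name/seq_len/do_lower_case, writes sharded TFRecords and metadata.
preprocessed = bert_preprocessor.preprocess(data)
# Second run: the cache is detected and the TFRecords are read back directly.
preprocessed_again = bert_preprocessor.preprocess(data)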

View File

@@ -13,14 +13,17 @@
# limitations under the License.
import csv
import io
import os
import tempfile
from unittest import mock as unittest_mock
import mock
import numpy as np
import numpy.testing as npt
import tensorflow as tf
from mediapipe.model_maker.python.core.data import cache_files
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
from mediapipe.model_maker.python.text.text_classifier import model_spec
from mediapipe.model_maker.python.text.text_classifier import preprocessor
@@ -84,11 +87,12 @@ class PreprocessorTest(tf.test.TestCase):
csv_file = self._get_csv_file()
dataset = text_classifier_ds.Dataset.from_csv(
filename=csv_file, csv_params=self.CSV_PARAMS_)
bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=5,
do_lower_case=bert_spec.do_lower_case,
uri=bert_spec.downloaded_files.get_path(),
model_name=bert_spec.name,
)
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
labels = []
@@ -97,18 +101,91 @@ class PreprocessorTest(tf.test.TestCase):
self.assertEqual(label.shape, [1])
labels.append(label.numpy()[0])
self.assertSameElements(
features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids'])
features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids']
)
for feature in features.values():
self.assertEqual(feature.shape, [1, 5])
input_masks.append(features['input_mask'].numpy()[0])
npt.assert_array_equal(features['input_type_ids'].numpy()[0],
[0, 0, 0, 0, 0])
npt.assert_array_equal(
np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]))
features['input_type_ids'].numpy()[0], [0, 0, 0, 0, 0]
)
npt.assert_array_equal(
np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])
)
self.assertEqual(labels, [1, 0])
def test_bert_preprocessor_cache(self):
csv_file = self._get_csv_file()
dataset = text_classifier_ds.Dataset.from_csv(
filename=csv_file,
csv_params=self.CSV_PARAMS_,
cache_dir=self.get_temp_dir(),
)
bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=5,
do_lower_case=bert_spec.do_lower_case,
uri=bert_spec.downloaded_files.get_path(),
model_name=bert_spec.name,
)
ds_cache_files = dataset.tfrecord_cache_files
preprocessed_cache_files = bert_preprocessor._get_tfrecord_cache_files(
ds_cache_files
)
self.assertFalse(preprocessed_cache_files.is_cached())
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
self.assertTrue(preprocessed_cache_files.is_cached())
self.assertEqual(
preprocessed_dataset.tfrecord_cache_files, preprocessed_cache_files
)
# The second time the preprocessor runs, it should load directly from the cache
mock_stdout = io.StringIO()
with mock.patch('sys.stdout', mock_stdout):
_ = bert_preprocessor.preprocess(dataset)
self.assertEqual(
mock_stdout.getvalue(),
'Using existing cache files at'
f' {preprocessed_cache_files.cache_prefix}\n',
)
def _get_new_prefix(self, cf, bert_spec, seq_len, do_lower_case):
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=seq_len,
do_lower_case=do_lower_case,
uri=bert_spec.downloaded_files.get_path(),
model_name=bert_spec.name,
)
new_cf = bert_preprocessor._get_tfrecord_cache_files(cf)
return new_cf.cache_prefix_filename
def test_bert_get_tfrecord_cache_files(self):
# Test to ensure regenerated cache_files have different prefixes
all_cf_prefixes = set()
cf = cache_files.TFRecordCacheFiles(
cache_prefix_filename='cache_prefix',
cache_dir=self.get_temp_dir(),
num_shards=1,
)
exbert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, True))
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 10, True))
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, False))
mobilebert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, True))
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 10, True))
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, False))
new_cf = cache_files.TFRecordCacheFiles(
cache_prefix_filename='new_cache_prefix',
cache_dir=self.get_temp_dir(),
num_shards=1,
)
all_cf_prefixes.add(self._get_new_prefix(new_cf, exbert_spec, 5, True))
# Each item of all_cf_prefixes should be unique, so 7 total.
self.assertLen(all_cf_prefixes, 7)
if __name__ == '__main__':
# Load compressed models from tensorflow_hub
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
tf.test.main()

View File

@@ -435,6 +435,7 @@ class _BertClassifier(TextClassifier):
seq_len=self._model_options.seq_len,
do_lower_case=self._model_spec.do_lower_case,
uri=self._model_spec.downloaded_files.get_path(),
model_name=self._model_spec.name,
)
return (self._text_preprocessor.preprocess(train_data),
self._text_preprocessor.preprocess(validation_data))