No public description

PiperOrigin-RevId: 550954023
Author: MediaPipe Team
Date: 2023-07-25 11:55:07 -07:00
Committed by: Copybara-Service
Parent: 113c9b30c2
Commit: 62538a9496
7 changed files with 262 additions and 53 deletions


@@ -45,6 +45,8 @@ class TFRecordCacheFiles:
   num_shards: int = 1
 
   def __post_init__(self):
+    if not tf.io.gfile.exists(self.cache_dir):
+      tf.io.gfile.makedirs(self.cache_dir)
     if not self.cache_prefix_filename:
       raise ValueError('cache_prefix_filename cannot be empty.')
     if self.num_shards <= 0:
@@ -79,8 +81,6 @@ class TFRecordCacheFiles:
     Returns:
       Array of TFRecordWriter objects
     """
-    if not tf.io.gfile.exists(self.cache_dir):
-      tf.io.gfile.makedirs(self.cache_dir)
     return [tf.io.TFRecordWriter(path) for path in self.tfrecord_files]
 
   def save_metadata(self, metadata):
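
For orientation, a short usage sketch of the behavior these two hunks move around: directory creation now happens in __post_init__ rather than lazily in get_writers(). The field and method names come from the diff; the literal cache path is a made-up example.

# Sketch only: constructing the dataclass now creates cache_dir up front,
# so get_writers() no longer has to create it.
from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib

cf = cache_files_lib.TFRecordCacheFiles(
    cache_prefix_filename="tfrecord",
    cache_dir="/tmp/example_cache",  # hypothetical path for illustration
    num_shards=1,
)
# The directory already exists at this point (created by __post_init__).
writers = cf.get_writers()
for writer in writers:
  writer.close()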


@@ -76,7 +76,10 @@ py_test(
 py_library(
     name = "dataset",
     srcs = ["dataset.py"],
-    deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"],
+    deps = [
+        "//mediapipe/model_maker/python/core/data:cache_files",
+        "//mediapipe/model_maker/python/core/data:classification_dataset",
+    ],
 )
 
 py_test(
@@ -88,7 +91,10 @@ py_test(
 py_library(
     name = "preprocessor",
     srcs = ["preprocessor.py"],
-    deps = [":dataset"],
+    deps = [
+        ":dataset",
+        "//mediapipe/model_maker/python/core/data:cache_files",
+    ],
 )
 
 py_test(
@@ -99,6 +105,7 @@ py_test(
         ":dataset",
         ":model_spec",
         ":preprocessor",
+        "//mediapipe/model_maker/python/core/data:cache_files",
     ],
 )
 


@@ -15,11 +15,15 @@
 import csv
 import dataclasses
+import hashlib
+import os
 import random
+import tempfile
-from typing import Optional, Sequence
+from typing import List, Optional, Sequence
 
 import tensorflow as tf
 
+from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
 from mediapipe.model_maker.python.core.data import classification_dataset
@@ -46,21 +50,49 @@ class CSVParameters:
 class Dataset(classification_dataset.ClassificationDataset):
   """Dataset library for text classifier."""
 
+  def __init__(
+      self,
+      dataset: tf.data.Dataset,
+      label_names: List[str],
+      tfrecord_cache_files: Optional[cache_files_lib.TFRecordCacheFiles] = None,
+      size: Optional[int] = None,
+  ):
+    super().__init__(dataset, label_names, size)
+    if not tfrecord_cache_files:
+      tfrecord_cache_files = cache_files_lib.TFRecordCacheFiles(
+          cache_prefix_filename="tfrecord", num_shards=1
+      )
+    self.tfrecord_cache_files = tfrecord_cache_files
+
   @classmethod
-  def from_csv(cls,
-               filename: str,
-               csv_params: CSVParameters,
-               shuffle: bool = True) -> "Dataset":
+  def from_csv(
+      cls,
+      filename: str,
+      csv_params: CSVParameters,
+      shuffle: bool = True,
+      cache_dir: Optional[str] = None,
+      num_shards: int = 1,
+  ) -> "Dataset":
     """Loads text with labels from a CSV file.
 
     Args:
       filename: Name of the CSV file.
       csv_params: Parameters used for reading the CSV file.
       shuffle: If True, randomly shuffle the data.
+      cache_dir: Optional parameter to specify where to store the preprocessed
+        dataset. Only used for BERT models.
+      num_shards: Optional parameter for num shards of the preprocessed dataset.
+        Note that using more than 1 shard will reorder the dataset. Only used
+        for BERT models.
 
     Returns:
       Dataset containing (text, label) pairs and other related info.
     """
+    if cache_dir is None:
+      cache_dir = tempfile.mkdtemp()
+    # Calculate hash for cache based off of files.
+    hasher = hashlib.md5()
+    hasher.update(os.path.basename(filename).encode("utf-8"))
     with tf.io.gfile.GFile(filename, "r") as f:
       reader = csv.DictReader(
           f,
@@ -69,6 +101,9 @@ class Dataset(classification_dataset.ClassificationDataset):
           quotechar=csv_params.quotechar)
       lines = list(reader)
+    for line in lines:
+      hasher.update(str(line).encode("utf-8"))
+
     if shuffle:
       random.shuffle(lines)
@@ -81,9 +116,18 @@ class Dataset(classification_dataset.ClassificationDataset):
         index_by_label[line[csv_params.label_column]] for line in lines
     ]
     label_index_ds = tf.data.Dataset.from_tensor_slices(
-        tf.cast(label_indices, tf.int64))
+        tf.cast(label_indices, tf.int64)
+    )
     text_label_ds = tf.data.Dataset.zip((text_ds, label_index_ds))
 
-    return Dataset(
-        dataset=text_label_ds, label_names=label_names, size=len(texts)
+    hasher.update(str(num_shards).encode("utf-8"))
+    cache_prefix_filename = hasher.hexdigest()
+    tfrecord_cache_files = cache_files_lib.TFRecordCacheFiles(
+        cache_prefix_filename, cache_dir, num_shards
+    )
+    return Dataset(
+        dataset=text_label_ds,
+        label_names=label_names,
+        tfrecord_cache_files=tfrecord_cache_files,
+        size=len(texts),
     )
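
A hedged usage sketch of the new caching parameters on from_csv. The signature, cache_dir/num_shards semantics, and the hash-based cache prefix come from the hunks above; the CSVParameters keyword arguments and the file path are assumptions made for illustration only.

# Sketch only: persist the preprocessed TFRecords across runs.
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds

csv_params = text_classifier_ds.CSVParameters(
    text_column="text",   # assumed field name, not shown in this diff
    label_column="label",  # label_column is referenced in the hunk above
)
data = text_classifier_ds.Dataset.from_csv(
    filename="reviews.csv",  # hypothetical input file
    csv_params=csv_params,
    shuffle=True,
    cache_dir="/tmp/text_classifier_cache",  # only used for BERT models
    num_shards=1,  # >1 shards will reorder the dataset
)
# The cache prefix is an md5 of the file name, its rows, and num_shards.
print(data.tfrecord_cache_files.cache_prefix_filename)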


@@ -53,7 +53,7 @@ class DatasetTest(tf.test.TestCase):
   def test_split(self):
     ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd'])
-    data = dataset.Dataset(ds, ['pos', 'neg'], 4)
+    data = dataset.Dataset(ds, ['pos', 'neg'], size=4)
     train_data, test_data = data.split(0.5)
     expected_train_data = [b'good', b'bad']
     expected_test_data = [b'neutral', b'odd']


@@ -15,14 +15,15 @@
 """Preprocessors for text classification."""
 
 import collections
+import hashlib
 import os
 import re
-import tempfile
 from typing import Mapping, Sequence, Tuple, Union
 
 import tensorflow as tf
 import tensorflow_hub
 
+from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
 from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
 from official.nlp.data import classifier_data_lib
 from official.nlp.tools import tokenization
@@ -75,19 +76,20 @@
   return bert_features, example["label_ids"]
 
 
-def _single_file_dataset(
-    input_file: str, name_to_features: Mapping[str, tf.io.FixedLenFeature]
+def _tfrecord_dataset(
+    tfrecord_files: Sequence[str],
+    name_to_features: Mapping[str, tf.io.FixedLenFeature],
 ) -> tf.data.TFRecordDataset:
   """Creates a single-file dataset to be passed for BERT custom training.
 
   Args:
-    input_file: Filepath for the dataset.
+    tfrecord_files: Filepaths for the dataset.
     name_to_features: Maps record keys to feature types.
 
   Returns:
     Dataset containing BERT model input features and labels.
   """
-  d = tf.data.TFRecordDataset(input_file)
+  d = tf.data.TFRecordDataset(tfrecord_files)
   d = d.map(
       lambda record: _decode_record(record, name_to_features),
       num_parallel_calls=tf.data.AUTOTUNE)
@@ -221,15 +223,23 @@ class BertClassifierPreprocessor:
     seq_len: Length of the input sequence to the model.
     vocab_file: File containing the BERT vocab.
     tokenizer: BERT tokenizer.
+    model_name: Name of the model provided by the model_spec. Used to associate
+      cached files with specific Bert model vocab.
   """
 
-  def __init__(self, seq_len: int, do_lower_case: bool, uri: str):
+  def __init__(
+      self, seq_len: int, do_lower_case: bool, uri: str, model_name: str
+  ):
     self._seq_len = seq_len
     # Vocab filepath is tied to the BERT module's URI.
     self._vocab_file = os.path.join(
-        tensorflow_hub.resolve(uri), "assets", "vocab.txt")
-    self._tokenizer = tokenization.FullTokenizer(self._vocab_file,
-                                                 do_lower_case)
+        tensorflow_hub.resolve(uri), "assets", "vocab.txt"
+    )
+    self._do_lower_case = do_lower_case
+    self._tokenizer = tokenization.FullTokenizer(
+        self._vocab_file, self._do_lower_case
+    )
+    self._model_name = model_name
 
   def _get_name_to_features(self):
     """Gets the dictionary mapping record keys to feature types."""
@@ -244,8 +254,45 @@ class BertClassifierPreprocessor:
     """Returns the vocab file of the BertClassifierPreprocessor."""
     return self._vocab_file
 
+  def _get_tfrecord_cache_files(
+      self, ds_cache_files
+  ) -> cache_files_lib.TFRecordCacheFiles:
+    """Helper to regenerate cache prefix filename using preprocessor info.
+
+    We need to update the dataset cache_prefix because the actual cached
+    dataset depends on the preprocessor parameters such as model_name, seq_len,
+    and do_lower_case, in addition to the raw dataset parameters, which are
+    already included in ds_cache_files.cache_prefix_filename.
+
+    Specifically, the new cache_prefix_filename used by the preprocessor will
+    be a hash generated from the following:
+      1. cache_prefix_filename of the initial raw dataset
+      2. model_name
+      3. seq_len
+      4. do_lower_case
+
+    Args:
+      ds_cache_files: TFRecordCacheFiles from the original raw dataset object
+
+    Returns:
+      A new TFRecordCacheFiles object which incorporates the preprocessor
+      parameters.
+    """
+    hasher = hashlib.md5()
+    hasher.update(ds_cache_files.cache_prefix_filename.encode("utf-8"))
+    hasher.update(self._model_name.encode("utf-8"))
+    hasher.update(str(self._seq_len).encode("utf-8"))
+    hasher.update(str(self._do_lower_case).encode("utf-8"))
+    cache_prefix_filename = hasher.hexdigest()
+    return cache_files_lib.TFRecordCacheFiles(
+        cache_prefix_filename,
+        ds_cache_files.cache_dir,
+        ds_cache_files.num_shards,
+    )
+
   def preprocess(
-      self, dataset: text_classifier_ds.Dataset) -> text_classifier_ds.Dataset:
+      self, dataset: text_classifier_ds.Dataset
+  ) -> text_classifier_ds.Dataset:
     """Preprocesses data into input for a BERT-based classifier.
 
     Args:
@@ -254,32 +301,65 @@ class BertClassifierPreprocessor:
     Returns:
       Dataset containing (bert_features, label) data.
     """
-    examples = []
+    ds_cache_files = dataset.tfrecord_cache_files
+    # Get new tfrecord_cache_files by including preprocessor information.
+    tfrecord_cache_files = self._get_tfrecord_cache_files(ds_cache_files)
+    if not tfrecord_cache_files.is_cached():
+      print(f"Writing new cache files to {tfrecord_cache_files.cache_prefix}")
+      writers = tfrecord_cache_files.get_writers()
+      size = 0
       for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
         _validate_text_and_label(text, label)
-        examples.append(
-            classifier_data_lib.InputExample(
+        example = classifier_data_lib.InputExample(
             guid=str(index),
             text_a=text.numpy()[0].decode("utf-8"),
             text_b=None,
             # InputExample expects the label name rather than the int ID
-            label=dataset.label_names[label.numpy()[0]]))
+            # label=dataset.label_names[label.numpy()[0]])
+            label=label.numpy()[0],
+        )
+        feature = classifier_data_lib.convert_single_example(
+            index, example, None, self._seq_len, self._tokenizer
+        )
 
-    tfrecord_file = os.path.join(tempfile.mkdtemp(), "bert_features.tfrecord")
-    classifier_data_lib.file_based_convert_examples_to_features(
-        examples=examples,
-        label_list=dataset.label_names,
-        max_seq_length=self._seq_len,
-        tokenizer=self._tokenizer,
-        output_file=tfrecord_file)
-    preprocessed_ds = _single_file_dataset(tfrecord_file,
-                                           self._get_name_to_features())
+        def create_int_feature(values):
+          f = tf.train.Feature(
+              int64_list=tf.train.Int64List(value=list(values))
+          )
+          return f
+
+        features = collections.OrderedDict()
+        features["input_ids"] = create_int_feature(feature.input_ids)
+        features["input_mask"] = create_int_feature(feature.input_mask)
+        features["segment_ids"] = create_int_feature(feature.segment_ids)
+        features["label_ids"] = create_int_feature([feature.label_id])
+        tf_example = tf.train.Example(
+            features=tf.train.Features(feature=features)
+        )
+        writers[index % len(writers)].write(tf_example.SerializeToString())
+        size = index + 1
+      for writer in writers:
+        writer.close()
+      metadata = {"size": size, "label_names": dataset.label_names}
+      tfrecord_cache_files.save_metadata(metadata)
+    else:
+      print(
+          f"Using existing cache files at {tfrecord_cache_files.cache_prefix}"
+      )
+      metadata = tfrecord_cache_files.load_metadata()
+    size = metadata["size"]
+    label_names = metadata["label_names"]
+    preprocessed_ds = _tfrecord_dataset(
+        tfrecord_cache_files.tfrecord_files, self._get_name_to_features()
+    )
     return text_classifier_ds.Dataset(
         dataset=preprocessed_ds,
-        size=dataset.size,
-        label_names=dataset.label_names)
+        size=size,
+        label_names=label_names,
+        tfrecord_cache_files=tfrecord_cache_files,
+    )
 
 
-TextClassifierPreprocessor = (
-    Union[BertClassifierPreprocessor,
-          AverageWordEmbeddingClassifierPreprocessor])
+TextClassifierPreprocessor = Union[
+    BertClassifierPreprocessor, AverageWordEmbeddingClassifierPreprocessor
+]
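
To make the cache-key derivation concrete, here is a small standalone sketch of the same md5 scheme implemented by _get_tfrecord_cache_files above. The helper name derive_cache_prefix is illustrative, not part of the change.

import hashlib


def derive_cache_prefix(
    ds_prefix: str, model_name: str, seq_len: int, do_lower_case: bool
) -> str:
  """Standalone sketch of the hashing done by _get_tfrecord_cache_files."""
  hasher = hashlib.md5()
  hasher.update(ds_prefix.encode("utf-8"))       # raw-dataset prefix (already hashes the CSV)
  hasher.update(model_name.encode("utf-8"))      # ties the cache to a specific BERT vocab
  hasher.update(str(seq_len).encode("utf-8"))
  hasher.update(str(do_lower_case).encode("utf-8"))
  return hasher.hexdigest()


# Any change to one input yields a different prefix, hence a cache miss:
assert derive_cache_prefix("abc", "exbert", 5, True) != derive_cache_prefix(
    "abc", "exbert", 10, True
)

Because the dataset-level prefix already hashes the CSV rows and shard count, a change anywhere in the chain (input data, shard count, model, sequence length, or casing) lands in a fresh cache prefix instead of silently reusing stale TFRecords.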


@@ -13,14 +13,17 @@
 # limitations under the License.
 
 import csv
+import io
 import os
 import tempfile
 from unittest import mock as unittest_mock
 
+import mock
 import numpy as np
 import numpy.testing as npt
 import tensorflow as tf
 
+from mediapipe.model_maker.python.core.data import cache_files
 from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
 from mediapipe.model_maker.python.text.text_classifier import model_spec
 from mediapipe.model_maker.python.text.text_classifier import preprocessor
@@ -84,11 +87,12 @@ class PreprocessorTest(tf.test.TestCase):
     csv_file = self._get_csv_file()
     dataset = text_classifier_ds.Dataset.from_csv(
         filename=csv_file, csv_params=self.CSV_PARAMS_)
-    bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
+    bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
     bert_preprocessor = preprocessor.BertClassifierPreprocessor(
         seq_len=5,
         do_lower_case=bert_spec.do_lower_case,
         uri=bert_spec.downloaded_files.get_path(),
+        model_name=bert_spec.name,
     )
     preprocessed_dataset = bert_preprocessor.preprocess(dataset)
     labels = []
@@ -97,18 +101,91 @@ class PreprocessorTest(tf.test.TestCase):
       self.assertEqual(label.shape, [1])
       labels.append(label.numpy()[0])
       self.assertSameElements(
-          features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids'])
+          features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids']
+      )
       for feature in features.values():
         self.assertEqual(feature.shape, [1, 5])
       input_masks.append(features['input_mask'].numpy()[0])
-      npt.assert_array_equal(features['input_type_ids'].numpy()[0],
-                             [0, 0, 0, 0, 0])
-    npt.assert_array_equal(
-        np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]))
+      npt.assert_array_equal(
+          features['input_type_ids'].numpy()[0], [0, 0, 0, 0, 0]
+      )
+    npt.assert_array_equal(
+        np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])
+    )
     self.assertEqual(labels, [1, 0])
 
+  def test_bert_preprocessor_cache(self):
+    csv_file = self._get_csv_file()
+    dataset = text_classifier_ds.Dataset.from_csv(
+        filename=csv_file,
+        csv_params=self.CSV_PARAMS_,
+        cache_dir=self.get_temp_dir(),
+    )
+    bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
+    bert_preprocessor = preprocessor.BertClassifierPreprocessor(
+        seq_len=5,
+        do_lower_case=bert_spec.do_lower_case,
+        uri=bert_spec.downloaded_files.get_path(),
+        model_name=bert_spec.name,
+    )
+    ds_cache_files = dataset.tfrecord_cache_files
+    preprocessed_cache_files = bert_preprocessor._get_tfrecord_cache_files(
+        ds_cache_files
+    )
+    self.assertFalse(preprocessed_cache_files.is_cached())
+    preprocessed_dataset = bert_preprocessor.preprocess(dataset)
+    self.assertTrue(preprocessed_cache_files.is_cached())
+    self.assertEqual(
+        preprocessed_dataset.tfrecord_cache_files, preprocessed_cache_files
+    )
+
+    # The second time running preprocessor, it should load from cache directly
+    mock_stdout = io.StringIO()
+    with mock.patch('sys.stdout', mock_stdout):
+      _ = bert_preprocessor.preprocess(dataset)
+    self.assertEqual(
+        mock_stdout.getvalue(),
+        'Using existing cache files at'
+        f' {preprocessed_cache_files.cache_prefix}\n',
+    )
+
+  def _get_new_prefix(self, cf, bert_spec, seq_len, do_lower_case):
+    bert_preprocessor = preprocessor.BertClassifierPreprocessor(
+        seq_len=seq_len,
+        do_lower_case=do_lower_case,
+        uri=bert_spec.downloaded_files.get_path(),
+        model_name=bert_spec.name,
+    )
+    new_cf = bert_preprocessor._get_tfrecord_cache_files(cf)
+    return new_cf.cache_prefix_filename
+
+  def test_bert_get_tfrecord_cache_files(self):
+    # Test to ensure regenerated cache_files have different prefixes
+    all_cf_prefixes = set()
+    cf = cache_files.TFRecordCacheFiles(
+        cache_prefix_filename='cache_prefix',
+        cache_dir=self.get_temp_dir(),
+        num_shards=1,
+    )
+    exbert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
+    all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, True))
+    all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 10, True))
+    all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, False))
+    mobilebert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
+    all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, True))
+    all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 10, True))
+    all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, False))
+    new_cf = cache_files.TFRecordCacheFiles(
+        cache_prefix_filename='new_cache_prefix',
+        cache_dir=self.get_temp_dir(),
+        num_shards=1,
+    )
+    all_cf_prefixes.add(self._get_new_prefix(new_cf, exbert_spec, 5, True))
+
+    # Each item of all_cf_prefixes should be unique, so 7 total.
+    self.assertLen(all_cf_prefixes, 7)
+
 
 if __name__ == '__main__':
   # Load compressed models from tensorflow_hub
   os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
   tf.test.main()


@@ -435,6 +435,7 @@ class _BertClassifier(TextClassifier):
         seq_len=self._model_options.seq_len,
         do_lower_case=self._model_spec.do_lower_case,
         uri=self._model_spec.downloaded_files.get_path(),
+        model_name=self._model_spec.name,
     )
     return (self._text_preprocessor.preprocess(train_data),
             self._text_preprocessor.preprocess(validation_data))