No public description
PiperOrigin-RevId: 550954023
This commit is contained in:
parent
113c9b30c2
commit
62538a9496
|
@ -45,6 +45,8 @@ class TFRecordCacheFiles:
|
||||||
num_shards: int = 1
|
num_shards: int = 1
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
if not tf.io.gfile.exists(self.cache_dir):
|
||||||
|
tf.io.gfile.makedirs(self.cache_dir)
|
||||||
if not self.cache_prefix_filename:
|
if not self.cache_prefix_filename:
|
||||||
raise ValueError('cache_prefix_filename cannot be empty.')
|
raise ValueError('cache_prefix_filename cannot be empty.')
|
||||||
if self.num_shards <= 0:
|
if self.num_shards <= 0:
|
||||||
|
@ -79,8 +81,6 @@ class TFRecordCacheFiles:
|
||||||
Returns:
|
Returns:
|
||||||
Array of TFRecordWriter objects
|
Array of TFRecordWriter objects
|
||||||
"""
|
"""
|
||||||
if not tf.io.gfile.exists(self.cache_dir):
|
|
||||||
tf.io.gfile.makedirs(self.cache_dir)
|
|
||||||
return [tf.io.TFRecordWriter(path) for path in self.tfrecord_files]
|
return [tf.io.TFRecordWriter(path) for path in self.tfrecord_files]
|
||||||
|
|
||||||
def save_metadata(self, metadata):
|
def save_metadata(self, metadata):
|
||||||
|
|
|
@ -76,7 +76,10 @@ py_test(
|
||||||
py_library(
|
py_library(
|
||||||
name = "dataset",
|
name = "dataset",
|
||||||
srcs = ["dataset.py"],
|
srcs = ["dataset.py"],
|
||||||
deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"],
|
deps = [
|
||||||
|
"//mediapipe/model_maker/python/core/data:cache_files",
|
||||||
|
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
|
@ -88,7 +91,10 @@ py_test(
|
||||||
py_library(
|
py_library(
|
||||||
name = "preprocessor",
|
name = "preprocessor",
|
||||||
srcs = ["preprocessor.py"],
|
srcs = ["preprocessor.py"],
|
||||||
deps = [":dataset"],
|
deps = [
|
||||||
|
":dataset",
|
||||||
|
"//mediapipe/model_maker/python/core/data:cache_files",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
|
@ -99,6 +105,7 @@ py_test(
|
||||||
":dataset",
|
":dataset",
|
||||||
":model_spec",
|
":model_spec",
|
||||||
":preprocessor",
|
":preprocessor",
|
||||||
|
"//mediapipe/model_maker/python/core/data:cache_files",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -15,11 +15,15 @@
|
||||||
|
|
||||||
import csv
|
import csv
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
|
import tempfile
|
||||||
|
from typing import List, Optional, Sequence
|
||||||
|
|
||||||
from typing import Optional, Sequence
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
|
||||||
from mediapipe.model_maker.python.core.data import classification_dataset
|
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||||
|
|
||||||
|
|
||||||
|
@ -46,21 +50,49 @@ class CSVParameters:
|
||||||
class Dataset(classification_dataset.ClassificationDataset):
|
class Dataset(classification_dataset.ClassificationDataset):
|
||||||
"""Dataset library for text classifier."""
|
"""Dataset library for text classifier."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset: tf.data.Dataset,
|
||||||
|
label_names: List[str],
|
||||||
|
tfrecord_cache_files: Optional[cache_files_lib.TFRecordCacheFiles] = None,
|
||||||
|
size: Optional[int] = None,
|
||||||
|
):
|
||||||
|
super().__init__(dataset, label_names, size)
|
||||||
|
if not tfrecord_cache_files:
|
||||||
|
tfrecord_cache_files = cache_files_lib.TFRecordCacheFiles(
|
||||||
|
cache_prefix_filename="tfrecord", num_shards=1
|
||||||
|
)
|
||||||
|
self.tfrecord_cache_files = tfrecord_cache_files
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_csv(cls,
|
def from_csv(
|
||||||
filename: str,
|
cls,
|
||||||
csv_params: CSVParameters,
|
filename: str,
|
||||||
shuffle: bool = True) -> "Dataset":
|
csv_params: CSVParameters,
|
||||||
|
shuffle: bool = True,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
num_shards: int = 1,
|
||||||
|
) -> "Dataset":
|
||||||
"""Loads text with labels from a CSV file.
|
"""Loads text with labels from a CSV file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filename: Name of the CSV file.
|
filename: Name of the CSV file.
|
||||||
csv_params: Parameters used for reading the CSV file.
|
csv_params: Parameters used for reading the CSV file.
|
||||||
shuffle: If True, randomly shuffle the data.
|
shuffle: If True, randomly shuffle the data.
|
||||||
|
cache_dir: Optional parameter to specify where to store the preprocessed
|
||||||
|
dataset. Only used for BERT models.
|
||||||
|
num_shards: Optional parameter for num shards of the preprocessed dataset.
|
||||||
|
Note that using more than 1 shard will reorder the dataset. Only used
|
||||||
|
for BERT models.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dataset containing (text, label) pairs and other related info.
|
Dataset containing (text, label) pairs and other related info.
|
||||||
"""
|
"""
|
||||||
|
if cache_dir is None:
|
||||||
|
cache_dir = tempfile.mkdtemp()
|
||||||
|
# calculate hash for cache based off of files
|
||||||
|
hasher = hashlib.md5()
|
||||||
|
hasher.update(os.path.basename(filename).encode("utf-8"))
|
||||||
with tf.io.gfile.GFile(filename, "r") as f:
|
with tf.io.gfile.GFile(filename, "r") as f:
|
||||||
reader = csv.DictReader(
|
reader = csv.DictReader(
|
||||||
f,
|
f,
|
||||||
|
@ -69,6 +101,9 @@ class Dataset(classification_dataset.ClassificationDataset):
|
||||||
quotechar=csv_params.quotechar)
|
quotechar=csv_params.quotechar)
|
||||||
|
|
||||||
lines = list(reader)
|
lines = list(reader)
|
||||||
|
for line in lines:
|
||||||
|
hasher.update(str(line).encode("utf-8"))
|
||||||
|
|
||||||
if shuffle:
|
if shuffle:
|
||||||
random.shuffle(lines)
|
random.shuffle(lines)
|
||||||
|
|
||||||
|
@ -81,9 +116,18 @@ class Dataset(classification_dataset.ClassificationDataset):
|
||||||
index_by_label[line[csv_params.label_column]] for line in lines
|
index_by_label[line[csv_params.label_column]] for line in lines
|
||||||
]
|
]
|
||||||
label_index_ds = tf.data.Dataset.from_tensor_slices(
|
label_index_ds = tf.data.Dataset.from_tensor_slices(
|
||||||
tf.cast(label_indices, tf.int64))
|
tf.cast(label_indices, tf.int64)
|
||||||
|
)
|
||||||
text_label_ds = tf.data.Dataset.zip((text_ds, label_index_ds))
|
text_label_ds = tf.data.Dataset.zip((text_ds, label_index_ds))
|
||||||
|
|
||||||
return Dataset(
|
hasher.update(str(num_shards).encode("utf-8"))
|
||||||
dataset=text_label_ds, label_names=label_names, size=len(texts)
|
cache_prefix_filename = hasher.hexdigest()
|
||||||
|
tfrecord_cache_files = cache_files_lib.TFRecordCacheFiles(
|
||||||
|
cache_prefix_filename, cache_dir, num_shards
|
||||||
|
)
|
||||||
|
return Dataset(
|
||||||
|
dataset=text_label_ds,
|
||||||
|
label_names=label_names,
|
||||||
|
tfrecord_cache_files=tfrecord_cache_files,
|
||||||
|
size=len(texts),
|
||||||
)
|
)
|
||||||
|
|
|
@ -53,7 +53,7 @@ class DatasetTest(tf.test.TestCase):
|
||||||
|
|
||||||
def test_split(self):
|
def test_split(self):
|
||||||
ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd'])
|
ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd'])
|
||||||
data = dataset.Dataset(ds, ['pos', 'neg'], 4)
|
data = dataset.Dataset(ds, ['pos', 'neg'], size=4)
|
||||||
train_data, test_data = data.split(0.5)
|
train_data, test_data = data.split(0.5)
|
||||||
expected_train_data = [b'good', b'bad']
|
expected_train_data = [b'good', b'bad']
|
||||||
expected_test_data = [b'neutral', b'odd']
|
expected_test_data = [b'neutral', b'odd']
|
||||||
|
|
|
@ -15,14 +15,15 @@
|
||||||
"""Preprocessors for text classification."""
|
"""Preprocessors for text classification."""
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import tempfile
|
|
||||||
from typing import Mapping, Sequence, Tuple, Union
|
from typing import Mapping, Sequence, Tuple, Union
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import tensorflow_hub
|
import tensorflow_hub
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
|
||||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
|
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
|
||||||
from official.nlp.data import classifier_data_lib
|
from official.nlp.data import classifier_data_lib
|
||||||
from official.nlp.tools import tokenization
|
from official.nlp.tools import tokenization
|
||||||
|
@ -75,19 +76,20 @@ def _decode_record(
|
||||||
return bert_features, example["label_ids"]
|
return bert_features, example["label_ids"]
|
||||||
|
|
||||||
|
|
||||||
def _single_file_dataset(
|
def _tfrecord_dataset(
|
||||||
input_file: str, name_to_features: Mapping[str, tf.io.FixedLenFeature]
|
tfrecord_files: Sequence[str],
|
||||||
|
name_to_features: Mapping[str, tf.io.FixedLenFeature],
|
||||||
) -> tf.data.TFRecordDataset:
|
) -> tf.data.TFRecordDataset:
|
||||||
"""Creates a single-file dataset to be passed for BERT custom training.
|
"""Creates a single-file dataset to be passed for BERT custom training.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_file: Filepath for the dataset.
|
tfrecord_files: Filepaths for the dataset.
|
||||||
name_to_features: Maps record keys to feature types.
|
name_to_features: Maps record keys to feature types.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dataset containing BERT model input features and labels.
|
Dataset containing BERT model input features and labels.
|
||||||
"""
|
"""
|
||||||
d = tf.data.TFRecordDataset(input_file)
|
d = tf.data.TFRecordDataset(tfrecord_files)
|
||||||
d = d.map(
|
d = d.map(
|
||||||
lambda record: _decode_record(record, name_to_features),
|
lambda record: _decode_record(record, name_to_features),
|
||||||
num_parallel_calls=tf.data.AUTOTUNE)
|
num_parallel_calls=tf.data.AUTOTUNE)
|
||||||
|
@ -221,15 +223,23 @@ class BertClassifierPreprocessor:
|
||||||
seq_len: Length of the input sequence to the model.
|
seq_len: Length of the input sequence to the model.
|
||||||
vocab_file: File containing the BERT vocab.
|
vocab_file: File containing the BERT vocab.
|
||||||
tokenizer: BERT tokenizer.
|
tokenizer: BERT tokenizer.
|
||||||
|
model_name: Name of the model provided by the model_spec. Used to associate
|
||||||
|
cached files with specific Bert model vocab.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, seq_len: int, do_lower_case: bool, uri: str):
|
def __init__(
|
||||||
|
self, seq_len: int, do_lower_case: bool, uri: str, model_name: str
|
||||||
|
):
|
||||||
self._seq_len = seq_len
|
self._seq_len = seq_len
|
||||||
# Vocab filepath is tied to the BERT module's URI.
|
# Vocab filepath is tied to the BERT module's URI.
|
||||||
self._vocab_file = os.path.join(
|
self._vocab_file = os.path.join(
|
||||||
tensorflow_hub.resolve(uri), "assets", "vocab.txt")
|
tensorflow_hub.resolve(uri), "assets", "vocab.txt"
|
||||||
self._tokenizer = tokenization.FullTokenizer(self._vocab_file,
|
)
|
||||||
do_lower_case)
|
self._do_lower_case = do_lower_case
|
||||||
|
self._tokenizer = tokenization.FullTokenizer(
|
||||||
|
self._vocab_file, self._do_lower_case
|
||||||
|
)
|
||||||
|
self._model_name = model_name
|
||||||
|
|
||||||
def _get_name_to_features(self):
|
def _get_name_to_features(self):
|
||||||
"""Gets the dictionary mapping record keys to feature types."""
|
"""Gets the dictionary mapping record keys to feature types."""
|
||||||
|
@ -244,8 +254,45 @@ class BertClassifierPreprocessor:
|
||||||
"""Returns the vocab file of the BertClassifierPreprocessor."""
|
"""Returns the vocab file of the BertClassifierPreprocessor."""
|
||||||
return self._vocab_file
|
return self._vocab_file
|
||||||
|
|
||||||
|
def _get_tfrecord_cache_files(
|
||||||
|
self, ds_cache_files
|
||||||
|
) -> cache_files_lib.TFRecordCacheFiles:
|
||||||
|
"""Helper to regenerate cache prefix filename using preprocessor info.
|
||||||
|
|
||||||
|
We need to update the dataset cache_prefix cache because the actual cached
|
||||||
|
dataset depends on the preprocessor parameters such as model_name, seq_len,
|
||||||
|
and do_lower_case in addition to the raw dataset parameters which is already
|
||||||
|
included in the ds_cache_files.cache_prefix_filename
|
||||||
|
|
||||||
|
Specifically, the new cache_prefix_filename used by the preprocessor will
|
||||||
|
be a hash generated from the following:
|
||||||
|
1. cache_prefix_filename of the initial raw dataset
|
||||||
|
2. model_name
|
||||||
|
3. seq_len
|
||||||
|
4. do_lower_case
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ds_cache_files: TFRecordCacheFiles from the original raw dataset object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new TFRecordCacheFiles object which incorporates the preprocessor
|
||||||
|
parameters.
|
||||||
|
"""
|
||||||
|
hasher = hashlib.md5()
|
||||||
|
hasher.update(ds_cache_files.cache_prefix_filename.encode("utf-8"))
|
||||||
|
hasher.update(self._model_name.encode("utf-8"))
|
||||||
|
hasher.update(str(self._seq_len).encode("utf-8"))
|
||||||
|
hasher.update(str(self._do_lower_case).encode("utf-8"))
|
||||||
|
cache_prefix_filename = hasher.hexdigest()
|
||||||
|
return cache_files_lib.TFRecordCacheFiles(
|
||||||
|
cache_prefix_filename,
|
||||||
|
ds_cache_files.cache_dir,
|
||||||
|
ds_cache_files.num_shards,
|
||||||
|
)
|
||||||
|
|
||||||
def preprocess(
|
def preprocess(
|
||||||
self, dataset: text_classifier_ds.Dataset) -> text_classifier_ds.Dataset:
|
self, dataset: text_classifier_ds.Dataset
|
||||||
|
) -> text_classifier_ds.Dataset:
|
||||||
"""Preprocesses data into input for a BERT-based classifier.
|
"""Preprocesses data into input for a BERT-based classifier.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -254,32 +301,65 @@ class BertClassifierPreprocessor:
|
||||||
Returns:
|
Returns:
|
||||||
Dataset containing (bert_features, label) data.
|
Dataset containing (bert_features, label) data.
|
||||||
"""
|
"""
|
||||||
examples = []
|
ds_cache_files = dataset.tfrecord_cache_files
|
||||||
for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
|
# Get new tfrecord_cache_files by including preprocessor information.
|
||||||
_validate_text_and_label(text, label)
|
tfrecord_cache_files = self._get_tfrecord_cache_files(ds_cache_files)
|
||||||
examples.append(
|
if not tfrecord_cache_files.is_cached():
|
||||||
classifier_data_lib.InputExample(
|
print(f"Writing new cache files to {tfrecord_cache_files.cache_prefix}")
|
||||||
guid=str(index),
|
writers = tfrecord_cache_files.get_writers()
|
||||||
text_a=text.numpy()[0].decode("utf-8"),
|
size = 0
|
||||||
text_b=None,
|
for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
|
||||||
# InputExample expects the label name rather than the int ID
|
_validate_text_and_label(text, label)
|
||||||
label=dataset.label_names[label.numpy()[0]]))
|
example = classifier_data_lib.InputExample(
|
||||||
|
guid=str(index),
|
||||||
|
text_a=text.numpy()[0].decode("utf-8"),
|
||||||
|
text_b=None,
|
||||||
|
# InputExample expects the label name rather than the int ID
|
||||||
|
# label=dataset.label_names[label.numpy()[0]])
|
||||||
|
label=label.numpy()[0],
|
||||||
|
)
|
||||||
|
feature = classifier_data_lib.convert_single_example(
|
||||||
|
index, example, None, self._seq_len, self._tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
tfrecord_file = os.path.join(tempfile.mkdtemp(), "bert_features.tfrecord")
|
def create_int_feature(values):
|
||||||
classifier_data_lib.file_based_convert_examples_to_features(
|
f = tf.train.Feature(
|
||||||
examples=examples,
|
int64_list=tf.train.Int64List(value=list(values))
|
||||||
label_list=dataset.label_names,
|
)
|
||||||
max_seq_length=self._seq_len,
|
return f
|
||||||
tokenizer=self._tokenizer,
|
|
||||||
output_file=tfrecord_file)
|
features = collections.OrderedDict()
|
||||||
preprocessed_ds = _single_file_dataset(tfrecord_file,
|
features["input_ids"] = create_int_feature(feature.input_ids)
|
||||||
self._get_name_to_features())
|
features["input_mask"] = create_int_feature(feature.input_mask)
|
||||||
|
features["segment_ids"] = create_int_feature(feature.segment_ids)
|
||||||
|
features["label_ids"] = create_int_feature([feature.label_id])
|
||||||
|
tf_example = tf.train.Example(
|
||||||
|
features=tf.train.Features(feature=features)
|
||||||
|
)
|
||||||
|
writers[index % len(writers)].write(tf_example.SerializeToString())
|
||||||
|
size = index + 1
|
||||||
|
for writer in writers:
|
||||||
|
writer.close()
|
||||||
|
metadata = {"size": size, "label_names": dataset.label_names}
|
||||||
|
tfrecord_cache_files.save_metadata(metadata)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Using existing cache files at {tfrecord_cache_files.cache_prefix}"
|
||||||
|
)
|
||||||
|
metadata = tfrecord_cache_files.load_metadata()
|
||||||
|
size = metadata["size"]
|
||||||
|
label_names = metadata["label_names"]
|
||||||
|
preprocessed_ds = _tfrecord_dataset(
|
||||||
|
tfrecord_cache_files.tfrecord_files, self._get_name_to_features()
|
||||||
|
)
|
||||||
return text_classifier_ds.Dataset(
|
return text_classifier_ds.Dataset(
|
||||||
dataset=preprocessed_ds,
|
dataset=preprocessed_ds,
|
||||||
size=dataset.size,
|
size=size,
|
||||||
label_names=dataset.label_names)
|
label_names=label_names,
|
||||||
|
tfrecord_cache_files=tfrecord_cache_files,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
TextClassifierPreprocessor = (
|
TextClassifierPreprocessor = Union[
|
||||||
Union[BertClassifierPreprocessor,
|
BertClassifierPreprocessor, AverageWordEmbeddingClassifierPreprocessor
|
||||||
AverageWordEmbeddingClassifierPreprocessor])
|
]
|
||||||
|
|
|
@ -13,14 +13,17 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import csv
|
import csv
|
||||||
|
import io
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from unittest import mock as unittest_mock
|
from unittest import mock as unittest_mock
|
||||||
|
|
||||||
|
import mock
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.testing as npt
|
import numpy.testing as npt
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import cache_files
|
||||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
|
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
||||||
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
||||||
|
@ -84,11 +87,12 @@ class PreprocessorTest(tf.test.TestCase):
|
||||||
csv_file = self._get_csv_file()
|
csv_file = self._get_csv_file()
|
||||||
dataset = text_classifier_ds.Dataset.from_csv(
|
dataset = text_classifier_ds.Dataset.from_csv(
|
||||||
filename=csv_file, csv_params=self.CSV_PARAMS_)
|
filename=csv_file, csv_params=self.CSV_PARAMS_)
|
||||||
bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
|
||||||
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
|
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
|
||||||
seq_len=5,
|
seq_len=5,
|
||||||
do_lower_case=bert_spec.do_lower_case,
|
do_lower_case=bert_spec.do_lower_case,
|
||||||
uri=bert_spec.downloaded_files.get_path(),
|
uri=bert_spec.downloaded_files.get_path(),
|
||||||
|
model_name=bert_spec.name,
|
||||||
)
|
)
|
||||||
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
|
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
|
||||||
labels = []
|
labels = []
|
||||||
|
@ -97,18 +101,91 @@ class PreprocessorTest(tf.test.TestCase):
|
||||||
self.assertEqual(label.shape, [1])
|
self.assertEqual(label.shape, [1])
|
||||||
labels.append(label.numpy()[0])
|
labels.append(label.numpy()[0])
|
||||||
self.assertSameElements(
|
self.assertSameElements(
|
||||||
features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids'])
|
features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids']
|
||||||
|
)
|
||||||
for feature in features.values():
|
for feature in features.values():
|
||||||
self.assertEqual(feature.shape, [1, 5])
|
self.assertEqual(feature.shape, [1, 5])
|
||||||
input_masks.append(features['input_mask'].numpy()[0])
|
input_masks.append(features['input_mask'].numpy()[0])
|
||||||
npt.assert_array_equal(features['input_type_ids'].numpy()[0],
|
npt.assert_array_equal(
|
||||||
[0, 0, 0, 0, 0])
|
features['input_type_ids'].numpy()[0], [0, 0, 0, 0, 0]
|
||||||
|
)
|
||||||
npt.assert_array_equal(
|
npt.assert_array_equal(
|
||||||
np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]))
|
np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])
|
||||||
|
)
|
||||||
self.assertEqual(labels, [1, 0])
|
self.assertEqual(labels, [1, 0])
|
||||||
|
|
||||||
|
def test_bert_preprocessor_cache(self):
|
||||||
|
csv_file = self._get_csv_file()
|
||||||
|
dataset = text_classifier_ds.Dataset.from_csv(
|
||||||
|
filename=csv_file,
|
||||||
|
csv_params=self.CSV_PARAMS_,
|
||||||
|
cache_dir=self.get_temp_dir(),
|
||||||
|
)
|
||||||
|
bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
|
||||||
|
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
|
||||||
|
seq_len=5,
|
||||||
|
do_lower_case=bert_spec.do_lower_case,
|
||||||
|
uri=bert_spec.downloaded_files.get_path(),
|
||||||
|
model_name=bert_spec.name,
|
||||||
|
)
|
||||||
|
ds_cache_files = dataset.tfrecord_cache_files
|
||||||
|
preprocessed_cache_files = bert_preprocessor._get_tfrecord_cache_files(
|
||||||
|
ds_cache_files
|
||||||
|
)
|
||||||
|
self.assertFalse(preprocessed_cache_files.is_cached())
|
||||||
|
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
|
||||||
|
self.assertTrue(preprocessed_cache_files.is_cached())
|
||||||
|
self.assertEqual(
|
||||||
|
preprocessed_dataset.tfrecord_cache_files, preprocessed_cache_files
|
||||||
|
)
|
||||||
|
|
||||||
|
# The second time running preprocessor, it should load from cache directly
|
||||||
|
mock_stdout = io.StringIO()
|
||||||
|
with mock.patch('sys.stdout', mock_stdout):
|
||||||
|
_ = bert_preprocessor.preprocess(dataset)
|
||||||
|
self.assertEqual(
|
||||||
|
mock_stdout.getvalue(),
|
||||||
|
'Using existing cache files at'
|
||||||
|
f' {preprocessed_cache_files.cache_prefix}\n',
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_new_prefix(self, cf, bert_spec, seq_len, do_lower_case):
|
||||||
|
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
|
||||||
|
seq_len=seq_len,
|
||||||
|
do_lower_case=do_lower_case,
|
||||||
|
uri=bert_spec.downloaded_files.get_path(),
|
||||||
|
model_name=bert_spec.name,
|
||||||
|
)
|
||||||
|
new_cf = bert_preprocessor._get_tfrecord_cache_files(cf)
|
||||||
|
return new_cf.cache_prefix_filename
|
||||||
|
|
||||||
|
def test_bert_get_tfrecord_cache_files(self):
|
||||||
|
# Test to ensure regenerated cache_files have different prefixes
|
||||||
|
all_cf_prefixes = set()
|
||||||
|
cf = cache_files.TFRecordCacheFiles(
|
||||||
|
cache_prefix_filename='cache_prefix',
|
||||||
|
cache_dir=self.get_temp_dir(),
|
||||||
|
num_shards=1,
|
||||||
|
)
|
||||||
|
exbert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
|
||||||
|
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, True))
|
||||||
|
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 10, True))
|
||||||
|
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, False))
|
||||||
|
mobilebert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||||
|
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, True))
|
||||||
|
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 10, True))
|
||||||
|
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, False))
|
||||||
|
new_cf = cache_files.TFRecordCacheFiles(
|
||||||
|
cache_prefix_filename='new_cache_prefix',
|
||||||
|
cache_dir=self.get_temp_dir(),
|
||||||
|
num_shards=1,
|
||||||
|
)
|
||||||
|
all_cf_prefixes.add(self._get_new_prefix(new_cf, exbert_spec, 5, True))
|
||||||
|
|
||||||
|
# Each item of all_cf_prefixes should be unique, so 7 total.
|
||||||
|
self.assertLen(all_cf_prefixes, 7)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# Load compressed models from tensorflow_hub
|
# Load compressed models from tensorflow_hub
|
||||||
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
|
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
|
@ -435,6 +435,7 @@ class _BertClassifier(TextClassifier):
|
||||||
seq_len=self._model_options.seq_len,
|
seq_len=self._model_options.seq_len,
|
||||||
do_lower_case=self._model_spec.do_lower_case,
|
do_lower_case=self._model_spec.do_lower_case,
|
||||||
uri=self._model_spec.downloaded_files.get_path(),
|
uri=self._model_spec.downloaded_files.get_path(),
|
||||||
|
model_name=self._model_spec.name,
|
||||||
)
|
)
|
||||||
return (self._text_preprocessor.preprocess(train_data),
|
return (self._text_preprocessor.preprocess(train_data),
|
||||||
self._text_preprocessor.preprocess(validation_data))
|
self._text_preprocessor.preprocess(validation_data))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user