mediapipe/mediapipe/model_maker/python/text/text_classifier/text_classifier.py

# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""API for text classification."""
import abc
import os
import tempfile
from typing import Any, Optional, Sequence, Tuple
import tensorflow as tf
from tensorflow_addons import optimizers as tfa_optimizers
import tensorflow_hub as hub
from mediapipe.model_maker.python.core.data import dataset as ds
from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import metrics
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
from mediapipe.model_maker.python.text.text_classifier import preprocessor
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.metadata.metadata_writers import text_classifier as text_classifier_writer


def _validate(options: text_classifier_options.TextClassifierOptions):
  """Validates that `model_options` and `supported_model` are compatible.

  Args:
    options: Options for creating and training a text classifier.

  Raises:
    ValueError if there is a mismatch between `model_options` and
      `supported_model`.
  """
if options.model_options is None:
return
if (isinstance(options.model_options, mo.AverageWordEmbeddingModelOptions) and
(options.supported_model !=
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
f" got {options.supported_model}")
if isinstance(options.model_options, mo.BertModelOptions) and (
options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER
and options.supported_model != ms.SupportedModels.EXBERT_CLASSIFIER
):
raise ValueError(
"Expected a Bert Classifier(MobileBERT or EXBERT), got "
f"{options.supported_model}"
)
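
# A minimal sketch of the mismatch that `_validate` guards against
# (hypothetical option values, shown for illustration only):
#
#   options = text_classifier_options.TextClassifierOptions(
#       supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER,
#       model_options=mo.AverageWordEmbeddingModelOptions(),
#   )
#   _validate(options)  # Raises ValueError: average-word-embedding options
#                       # were paired with a BERT-family supported_model.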


class TextClassifier(classifier.Classifier):
  """API for creating and training a text classification model."""

  def __init__(
self, model_spec: Any, label_names: Sequence[str], shuffle: bool
):
super().__init__(
model_spec=model_spec, label_names=label_names, shuffle=shuffle
)
self._model_spec = model_spec
    self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None

  @classmethod
  def create(
      cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
      options: text_classifier_options.TextClassifierOptions
  ) -> "TextClassifier":
    """Factory function that creates and trains a text classifier.

    Note that `train_data` and `validation_data` are expected to share the same
    `label_names` since they should be split from the same dataset.

    Args:
      train_data: Training data.
      validation_data: Validation data.
      options: Options for creating and training the text classifier.

    Returns:
      A text classifier.

    Raises:
      ValueError if `train_data` and `validation_data` do not have the
        same label_names or `options` contains an unknown `supported_model`.
    """
if train_data.label_names != validation_data.label_names:
raise ValueError(
f"Training data label names {train_data.label_names} not equal to "
f"validation data label names {validation_data.label_names}")
_validate(options)
if options.model_options is None:
options.model_options = options.supported_model.value().model_options
if options.hparams is None:
options.hparams = options.supported_model.value().hparams
if (
options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER
or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER
):
text_classifier = (
_BertClassifier.create_bert_classifier(train_data, validation_data,
options,
train_data.label_names))
elif (options.supported_model ==
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
text_classifier = (
_AverageWordEmbeddingClassifier
.create_average_word_embedding_classifier(train_data, validation_data,
options,
train_data.label_names))
else:
raise ValueError(f"Unknown model {options.supported_model}")
return text_classifier
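
  # A typical end-to-end flow (sketch; the CSV file and column names are
  # hypothetical, and `CSVParameters`/`Dataset.from_csv` are assumed from the
  # text classifier dataset module):
  #
  #   csv_params = text_ds.CSVParameters(text_column="text",
  #                                      label_column="label")
  #   data = text_ds.Dataset.from_csv("reviews.csv", csv_params=csv_params)
  #   train_data, validation_data = data.split(0.9)
  #   options = text_classifier_options.TextClassifierOptions(
  #       supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
  #   )
  #   model = TextClassifier.create(train_data, validation_data, options)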

  def evaluate(
      self,
      data: ds.Dataset,
      batch_size: int = 32,
      desired_precisions: Optional[Sequence[float]] = None,
      desired_recalls: Optional[Sequence[float]] = None,
  ) -> Any:
    """Overrides Classifier.evaluate().

    Args:
      data: Evaluation dataset. Must be a TextClassifier Dataset.
      batch_size: Number of samples per evaluation step.
      desired_precisions: If specified, adds a RecallAtPrecision metric per
        desired_precisions[i] entry which tracks the recall given the
        constraint on precision. Only supported for binary classification.
      desired_recalls: If specified, adds a PrecisionAtRecall metric per
        desired_recalls[i] entry which tracks the precision given the
        constraint on recall. Only supported for binary classification.

    Returns:
      The loss value and accuracy.

    Raises:
      ValueError if `data` is not a TextClassifier Dataset.
    """
# This override is needed because TextClassifier preprocesses its data
# outside of the `gen_tf_dataset()` method. The preprocess call also
# requires a TextClassifier Dataset instead of a core Dataset.
if not isinstance(data, text_ds.Dataset):
raise ValueError("Need a TextClassifier Dataset.")
processed_data = self._text_preprocessor.preprocess(data)
dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)
additional_metrics = []
if desired_precisions and len(data.label_names) == 2:
for precision in desired_precisions:
additional_metrics.append(
metrics.BinarySparseRecallAtPrecision(
precision, name=f"recall_at_precision_{precision}"
)
)
if desired_recalls and len(data.label_names) == 2:
for recall in desired_recalls:
additional_metrics.append(
metrics.BinarySparsePrecisionAtRecall(
recall, name=f"precision_at_recall_{recall}"
)
)
metric_functions = self._metric_functions + additional_metrics
self._model.compile(
optimizer=self._optimizer,
loss=self._loss_function,
metrics=metric_functions,
)
return self._model.evaluate(dataset)
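
  # Sketch: for a binary classification dataset, constrained metrics can be
  # requested at evaluation time (threshold values are illustrative):
  #
  #   results = model.evaluate(
  #       validation_data,
  #       desired_precisions=[0.9],  # adds a recall_at_precision_0.9 metric
  #       desired_recalls=[0.8],     # adds a precision_at_recall_0.8 metric
  #   )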

  def export_model(
      self,
      model_name: str = "model.tflite",
      quantization_config: Optional[quantization.QuantizationConfig] = None):
    """Converts and saves the model to a TFLite file with metadata included.

    Note that only the TFLite file is needed for deployment. This function
    also saves a metadata.json file to the same directory as the TFLite file,
    which can be used to interpret the metadata content in the TFLite file.

    Args:
      model_name: File name to save TFLite model with metadata. The full
        export path is {self._hparams.export_dir}/{model_name}.
      quantization_config: The configuration for model quantization.
    """
tflite_file = os.path.join(self._hparams.export_dir, model_name)
tf.io.gfile.makedirs(os.path.dirname(tflite_file))
metadata_file = os.path.join(self._hparams.export_dir, "metadata.json")
tflite_model = model_util.convert_to_tflite(
model=self._model, quantization_config=quantization_config)
vocab_filepath = os.path.join(tempfile.mkdtemp(), "vocab.txt")
self._save_vocab(vocab_filepath)
writer = self._get_metadata_writer(tflite_model, vocab_filepath)
tflite_model_with_metadata, metadata_json = writer.populate()
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
with tf.io.gfile.GFile(metadata_file, "w") as f:
f.write(metadata_json)
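
  # Export sketch, optionally with post-training quantization; assumes the
  # `for_dynamic()` helper from the shared quantization utilities:
  #
  #   model.export_model()  # writes {export_dir}/model.tflite + metadata.json
  #   model.export_model(
  #       model_name="model_quantized.tflite",
  #       quantization_config=quantization.QuantizationConfig.for_dynamic(),
  #   )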

  @abc.abstractmethod
  def _save_vocab(self, vocab_filepath: str):
    """Saves the preprocessor's vocab to `vocab_filepath`."""

  @abc.abstractmethod
  def _get_metadata_writer(self, tflite_model: bytearray, vocab_filepath: str):
    """Gets the metadata writer for the text classifier TFLite model."""


class _AverageWordEmbeddingClassifier(TextClassifier):
  """APIs to help create and train an Average Word Embedding text classifier."""

  _DELIM_REGEX_PATTERN = r"[^\w\']+"

  def __init__(
self,
model_spec: ms.AverageWordEmbeddingClassifierSpec,
model_options: mo.AverageWordEmbeddingModelOptions,
hparams: hp.AverageWordEmbeddingHParams,
label_names: Sequence[str],
):
super().__init__(model_spec, label_names, hparams.shuffle)
self._model_options = model_options
self._hparams = hparams
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
self._loss_function = "sparse_categorical_crossentropy"
self._metric_functions = [
"accuracy",
metrics.SparsePrecision(name="precision", dtype=tf.float32),
metrics.SparseRecall(name="recall", dtype=tf.float32),
]
self._text_preprocessor: (
        preprocessor.AverageWordEmbeddingClassifierPreprocessor) = None

  @classmethod
  def create_average_word_embedding_classifier(
      cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
      options: text_classifier_options.TextClassifierOptions,
      label_names: Sequence[str]) -> "_AverageWordEmbeddingClassifier":
    """Creates, trains, and returns an Average Word Embedding classifier.

    Args:
      train_data: Training data.
      validation_data: Validation data.
      options: Options for creating and training the text classifier.
      label_names: Label names used in the data.

    Returns:
      An Average Word Embedding classifier.
    """
average_word_embedding_classifier = _AverageWordEmbeddingClassifier(
model_spec=options.supported_model.value(),
model_options=options.model_options,
hparams=options.hparams,
label_names=train_data.label_names)
average_word_embedding_classifier._create_and_train_model(
train_data, validation_data)
    return average_word_embedding_classifier

  def _create_and_train_model(self, train_data: text_ds.Dataset,
                              validation_data: text_ds.Dataset):
    """Creates the Average Word Embedding classifier keras model and trains it.

    Args:
      train_data: Training data.
      validation_data: Validation data.
    """
(processed_train_data, processed_validation_data) = (
self._load_and_run_preprocessor(train_data, validation_data))
self._create_model()
self._optimizer = "rmsprop"
    self._train_model(processed_train_data, processed_validation_data)

  def _load_and_run_preprocessor(
      self, train_data: text_ds.Dataset, validation_data: text_ds.Dataset
  ) -> Tuple[text_ds.Dataset, text_ds.Dataset]:
    """Runs an AverageWordEmbeddingClassifierPreprocessor on the data.

    Args:
      train_data: Training data.
      validation_data: Validation data.

    Returns:
      Preprocessed training data and preprocessed validation data.
    """
train_texts = [text.numpy()[0] for text, _ in train_data.gen_tf_dataset()]
validation_texts = [
text.numpy()[0] for text, _ in validation_data.gen_tf_dataset()
]
self._text_preprocessor = (
preprocessor.AverageWordEmbeddingClassifierPreprocessor(
seq_len=self._model_options.seq_len,
do_lower_case=self._model_options.do_lower_case,
texts=train_texts + validation_texts,
vocab_size=self._model_options.vocab_size))
return self._text_preprocessor.preprocess(
        train_data), self._text_preprocessor.preprocess(validation_data)

  def _create_model(self):
    """Creates an Average Word Embedding model."""
self._model = tf.keras.Sequential([
tf.keras.layers.InputLayer(
input_shape=[self._model_options.seq_len],
dtype=tf.int32,
name="input_ids",
),
tf.keras.layers.Embedding(
len(self._text_preprocessor.get_vocab()),
self._model_options.wordvec_dim,
input_length=self._model_options.seq_len,
),
tf.keras.layers.GlobalAveragePooling1D(),
tf.keras.layers.Dense(
self._model_options.wordvec_dim, activation=tf.nn.relu
),
tf.keras.layers.Dropout(self._model_options.dropout_rate),
tf.keras.layers.Dense(self._num_classes, activation="softmax"),
])
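
  # Shape walkthrough (illustrative; dimensions come from the model options):
  # input_ids (batch, seq_len) int32 -> Embedding (batch, seq_len,
  # wordvec_dim) -> GlobalAveragePooling1D (batch, wordvec_dim) -> ReLU Dense
  # (batch, wordvec_dim) -> Dropout -> softmax Dense (batch, num_classes).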

  def _save_vocab(self, vocab_filepath: str):
with tf.io.gfile.GFile(vocab_filepath, "w") as f:
for token, index in self._text_preprocessor.get_vocab().items():
f.write(f"{token} {index}\n")

  def _get_metadata_writer(self, tflite_model: bytearray, vocab_filepath: str):
return text_classifier_writer.MetadataWriter.create_for_regex_model(
model_buffer=tflite_model,
regex_tokenizer=metadata_writer.RegexTokenizer(
# TODO: Align with MediaPipe's RegexTokenizer.
delim_regex_pattern=self._DELIM_REGEX_PATTERN,
vocab_file_path=vocab_filepath),
labels=metadata_writer.Labels().add(list(self._label_names)))


class _BertClassifier(TextClassifier):
  """APIs to help create and train a BERT-based text classifier."""

  _INITIALIZER_RANGE = 0.02

  def __init__(
self,
model_spec: ms.BertClassifierSpec,
model_options: mo.BertModelOptions,
hparams: hp.BertHParams,
label_names: Sequence[str],
):
super().__init__(model_spec, label_names, hparams.shuffle)
self._hparams = hparams
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
self._model_options = model_options
with self._hparams.get_strategy().scope():
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
self._metric_functions = [
tf.keras.metrics.SparseCategoricalAccuracy(
"test_accuracy", dtype=tf.float32
),
metrics.SparsePrecision(name="precision", dtype=tf.float32),
metrics.SparseRecall(name="recall", dtype=tf.float32),
]
    self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None

  @classmethod
  def create_bert_classifier(
      cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
      options: text_classifier_options.TextClassifierOptions,
      label_names: Sequence[str]) -> "_BertClassifier":
    """Creates, trains, and returns a BERT-based classifier.

    Args:
      train_data: Training data.
      validation_data: Validation data.
      options: Options for creating and training the text classifier.
      label_names: Label names used in the data.

    Returns:
      A BERT-based classifier.
    """
bert_classifier = _BertClassifier(
model_spec=options.supported_model.value(),
model_options=options.model_options,
hparams=options.hparams,
label_names=train_data.label_names)
bert_classifier._create_and_train_model(train_data, validation_data)
    return bert_classifier

  def _create_and_train_model(self, train_data: text_ds.Dataset,
                              validation_data: text_ds.Dataset):
    """Creates the BERT-based classifier keras model and trains it.

    Args:
      train_data: Training data.
      validation_data: Validation data.
    """
(processed_train_data, processed_validation_data) = (
self._load_and_run_preprocessor(train_data, validation_data))
with self._hparams.get_strategy().scope():
self._create_model()
self._create_optimizer(processed_train_data)
    self._train_model(processed_train_data, processed_validation_data)

  def _load_and_run_preprocessor(
      self, train_data: text_ds.Dataset, validation_data: text_ds.Dataset
  ) -> Tuple[text_ds.Dataset, text_ds.Dataset]:
    """Loads a BertClassifierPreprocessor and runs it on the data.

    Args:
      train_data: Training data.
      validation_data: Validation data.

    Returns:
      Preprocessed training data and preprocessed validation data.
    """
self._text_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=self._model_options.seq_len,
do_lower_case=self._model_spec.do_lower_case,
uri=self._model_spec.downloaded_files.get_path(),
)
return (self._text_preprocessor.preprocess(train_data),
self._text_preprocessor.preprocess(validation_data))
def _create_model(self):
"""Creates a BERT-based classifier model.
The model architecture consists of stacking a dense classification layer and
dropout layer on top of the BERT encoder outputs.
"""
encoder_inputs = dict(
input_word_ids=tf.keras.layers.Input(
shape=(self._model_options.seq_len,), dtype=tf.int32),
input_mask=tf.keras.layers.Input(
shape=(self._model_options.seq_len,), dtype=tf.int32),
input_type_ids=tf.keras.layers.Input(
shape=(self._model_options.seq_len,), dtype=tf.int32),
)
encoder = hub.KerasLayer(
self._model_spec.downloaded_files.get_path(),
trainable=self._model_options.do_fine_tuning,
)
encoder_outputs = encoder(encoder_inputs)
pooled_output = encoder_outputs["pooled_output"]
output = tf.keras.layers.Dropout(rate=self._model_options.dropout_rate)(
pooled_output)
initializer = tf.keras.initializers.TruncatedNormal(
stddev=self._INITIALIZER_RANGE)
output = tf.keras.layers.Dense(
self._num_classes,
kernel_initializer=initializer,
name="output",
activation="softmax",
dtype=tf.float32)(
output)
self._model = tf.keras.Model(inputs=encoder_inputs, outputs=output)
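
  # I/O sketch for the assembled model (input names follow the
  # `encoder_inputs` dict keys above; tensors are illustrative):
  #
  #   probs = self._model({
  #       "input_word_ids": ids,       # (batch, seq_len) int32 token ids
  #       "input_mask": mask,          # (batch, seq_len) 1=token, 0=padding
  #       "input_type_ids": type_ids,  # (batch, seq_len) segment ids
  #   })  # -> (batch, num_classes) softmax probabilities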

  def _create_optimizer(self, train_data: text_ds.Dataset):
    """Loads an optimizer with a learning rate schedule.

    The decay steps in the learning rate schedule depend on
    `steps_per_epoch`, which may depend on the size of the training data.

    Args:
      train_data: Training data.
    """
self._hparams.steps_per_epoch = model_util.get_steps_per_epoch(
steps_per_epoch=self._hparams.steps_per_epoch,
batch_size=self._hparams.batch_size,
train_data=train_data)
total_steps = self._hparams.steps_per_epoch * self._hparams.epochs
warmup_steps = int(total_steps * 0.1)
initial_lr = self._hparams.learning_rate
# Implements linear decay of the learning rate.
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=initial_lr,
decay_steps=total_steps,
end_learning_rate=0.0,
power=1.0)
if warmup_steps:
lr_schedule = model_util.WarmUp(
initial_learning_rate=initial_lr,
decay_schedule_fn=lr_schedule,
warmup_steps=warmup_steps)
if self._hparams.optimizer == hp.BertOptimizer.ADAMW:
self._optimizer = tf.keras.optimizers.experimental.AdamW(
lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0
)
self._optimizer.exclude_from_weight_decay(
var_names=["LayerNorm", "layer_norm", "bias"]
)
elif self._hparams.optimizer == hp.BertOptimizer.LAMB:
self._optimizer = tfa_optimizers.LAMB(
lr_schedule,
weight_decay_rate=0.01,
epsilon=1e-6,
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
global_clipnorm=1.0,
)
else:
raise ValueError(
"BertHParams.optimizer must be set to ADAM or "
f"LAMB. Got {self._hparams.optimizer}."
)
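
  # Worked example of the schedule (hypothetical sizes; assumes
  # `get_steps_per_epoch` derives steps from len(train_data) / batch_size):
  # with 4,800 training examples, batch_size=48 and epochs=3,
  # steps_per_epoch = 100, total_steps = 300, and warmup_steps =
  # int(300 * 0.1) = 30. The learning rate ramps linearly from 0 to
  # `learning_rate` over the first 30 steps, then decays linearly to 0 over
  # the remaining 270.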

  def _save_vocab(self, vocab_filepath: str):
tf.io.gfile.copy(
self._text_preprocessor.get_vocab_file(),
vocab_filepath,
overwrite=True)

  def _get_metadata_writer(self, tflite_model: bytearray, vocab_filepath: str):
return text_classifier_writer.MetadataWriter.create_for_bert_model(
model_buffer=tflite_model,
tokenizer=metadata_writer.BertTokenizer(vocab_filepath),
labels=metadata_writer.Labels().add(list(self._label_names)),
ids_name=self._model_spec.tflite_input_name["ids"],
mask_name=self._model_spec.tflite_input_name["mask"],
segment_name=self._model_spec.tflite_input_name["segment_ids"])