1. Move evaluation onto GPU/TPU hardware if available.
2. Move desired_precision and desired_recall from evaluate to hyperparameters so recall@precision metrics will be reported for both training and evaluation. This also fixes a bug where recompiling the model with the previously initialized metric objects would not properly reset the metric states. 3. Remove redundant label_names from create_... class methods in text_classifier. This information is already provided by the datasets. 4. Change loss function to FocalLoss. 5. Re-enable text_classifier unit tests using ExBert 6. Add input names to avoid flaky auto-assigned input names. PiperOrigin-RevId: 550992146
This commit is contained in:
parent
85c3fed70a
commit
bd7888cc0c
|
@ -59,7 +59,7 @@ class FocalLoss(tf.keras.losses.Loss):
|
|||
"""
|
||||
|
||||
def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
|
||||
"""Constructor.
|
||||
"""Initializes FocalLoss.
|
||||
|
||||
Args:
|
||||
gamma: Focal loss gamma, as described in class docs.
|
||||
|
@ -115,6 +115,51 @@ class FocalLoss(tf.keras.losses.Loss):
|
|||
return tf.reduce_sum(losses) / batch_size
|
||||
|
||||
|
||||
class SparseFocalLoss(FocalLoss):
|
||||
"""Sparse implementation of Focal Loss.
|
||||
|
||||
This is the same as FocalLoss, except the labels are expected to be class ids
|
||||
instead of 1-hot encoded vectors. See FocalLoss class documentation defined
|
||||
in this same file for more details.
|
||||
|
||||
Example usage:
|
||||
>>> y_true = [1, 2]
|
||||
>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
|
||||
>>> gamma = 2
|
||||
>>> focal_loss = SparseFocalLoss(gamma, 3)
|
||||
>>> focal_loss(y_true, y_pred).numpy()
|
||||
0.9326
|
||||
|
||||
>>> # Calling with 'sample_weight'.
|
||||
>>> focal_loss(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()
|
||||
0.6528
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, gamma, num_classes, class_weight: Optional[Sequence[float]] = None
|
||||
):
|
||||
"""Initializes SparseFocalLoss.
|
||||
|
||||
Args:
|
||||
gamma: Focal loss gamma, as described in class docs.
|
||||
num_classes: Number of classes.
|
||||
class_weight: A weight to apply to the loss, one for each class. The
|
||||
weight is applied for each input where the ground truth label matches.
|
||||
"""
|
||||
super().__init__(gamma, class_weight=class_weight)
|
||||
self._num_classes = num_classes
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
y_true: tf.Tensor,
|
||||
y_pred: tf.Tensor,
|
||||
sample_weight: Optional[tf.Tensor] = None,
|
||||
) -> tf.Tensor:
|
||||
y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
|
||||
y_true_one_hot = tf.one_hot(y_true, self._num_classes)
|
||||
return super().__call__(y_true_one_hot, y_pred, sample_weight=sample_weight)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PerceptualLossWeight:
|
||||
"""The weight for each perceptual loss.
|
||||
|
|
|
@ -101,6 +101,23 @@ class FocalLossTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.assertNear(loss, expected_loss, 1e-4)
|
||||
|
||||
|
||||
class SparseFocalLossTest(tf.test.TestCase):
|
||||
|
||||
def test_sparse_focal_loss_matches_focal_loss(self):
|
||||
num_classes = 2
|
||||
y_pred = tf.constant([[0.8, 0.2], [0.3, 0.7]])
|
||||
y_true = tf.constant([1, 0])
|
||||
y_true_one_hot = tf.one_hot(y_true, num_classes)
|
||||
for gamma in [0.0, 0.5, 1.0]:
|
||||
expected_loss_fn = loss_functions.FocalLoss(gamma=gamma)
|
||||
loss_fn = loss_functions.SparseFocalLoss(
|
||||
gamma=gamma, num_classes=num_classes
|
||||
)
|
||||
expected_loss = expected_loss_fn(y_true_one_hot, y_pred)
|
||||
loss = loss_fn(y_true, y_pred)
|
||||
self.assertNear(loss, expected_loss, 1e-4)
|
||||
|
||||
|
||||
class MockPerceptualLoss(loss_functions.PerceptualLoss):
|
||||
"""A mock class with implementation of abstract methods for testing."""
|
||||
|
||||
|
|
|
@ -131,6 +131,7 @@ py_library(
|
|||
":text_classifier_options",
|
||||
"//mediapipe/model_maker/python/core/data:dataset",
|
||||
"//mediapipe/model_maker/python/core/tasks:classifier",
|
||||
"//mediapipe/model_maker/python/core/utils:loss_functions",
|
||||
"//mediapipe/model_maker/python/core/utils:metrics",
|
||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||
|
@ -154,6 +155,7 @@ py_test(
|
|||
],
|
||||
deps = [
|
||||
":text_classifier_import",
|
||||
"//mediapipe/model_maker/python/core/utils:loss_functions",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
import dataclasses
|
||||
import enum
|
||||
from typing import Union
|
||||
from typing import Sequence, Union
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
|
||||
|
@ -39,16 +39,34 @@ class BertHParams(hp.BaseHParams):
|
|||
|
||||
Attributes:
|
||||
learning_rate: Learning rate to use for gradient descent training.
|
||||
batch_size: Batch size for training.
|
||||
epochs: Number of training iterations over the dataset.
|
||||
optimizer: Optimizer to use for training. Only supported values are "adamw"
|
||||
and "lamb".
|
||||
end_learning_rate: End learning rate for linear decay. Defaults to 0.
|
||||
batch_size: Batch size for training. Defaults to 48.
|
||||
epochs: Number of training iterations over the dataset. Defaults to 2.
|
||||
optimizer: Optimizer to use for training. Supported values are defined in
|
||||
BertOptimizer enum: ADAMW and LAMB.
|
||||
weight_decay: Weight decay of the optimizer. Defaults to 0.01.
|
||||
desired_precisions: If specified, adds a RecallAtPrecision metric per
|
||||
desired_precisions[i] entry which tracks the recall given the constraint
|
||||
on precision. Only supported for binary classification.
|
||||
desired_recalls: If specified, adds a PrecisionAtRecall metric per
|
||||
desired_recalls[i] entry which tracks the precision given the constraint
|
||||
on recall. Only supported for binary classification.
|
||||
gamma: Gamma parameter for focal loss. To use cross entropy loss, set this
|
||||
value to 0. Defaults to 2.0.
|
||||
"""
|
||||
|
||||
learning_rate: float = 3e-5
|
||||
end_learning_rate: float = 0.0
|
||||
|
||||
batch_size: int = 48
|
||||
epochs: int = 2
|
||||
optimizer: BertOptimizer = BertOptimizer.ADAMW
|
||||
weight_decay: float = 0.01
|
||||
|
||||
desired_precisions: Sequence[float] = dataclasses.field(default_factory=list)
|
||||
desired_recalls: Sequence[float] = dataclasses.field(default_factory=list)
|
||||
|
||||
gamma: float = 2.0
|
||||
|
||||
|
||||
HParams = Union[BertHParams, AverageWordEmbeddingHParams]
|
||||
|
|
|
@ -79,11 +79,6 @@ mobilebert_classifier_spec = functools.partial(
|
|||
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
|
||||
),
|
||||
name='MobileBert',
|
||||
tflite_input_name={
|
||||
'ids': 'serving_default_input_1:0',
|
||||
'segment_ids': 'serving_default_input_2:0',
|
||||
'mask': 'serving_default_input_3:0',
|
||||
},
|
||||
)
|
||||
|
||||
exbert_classifier_spec = functools.partial(
|
||||
|
@ -93,11 +88,6 @@ exbert_classifier_spec = functools.partial(
|
|||
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
|
||||
),
|
||||
name='ExBert',
|
||||
tflite_input_name={
|
||||
'ids': 'serving_default_input_1:0',
|
||||
'segment_ids': 'serving_default_input_2:0',
|
||||
'mask': 'serving_default_input_3:0',
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -46,11 +46,13 @@ class ModelSpecTest(tf.test.TestCase):
|
|||
self.assertTrue(os.path.exists(model_spec_obj.downloaded_files.get_path()))
|
||||
self.assertTrue(model_spec_obj.do_lower_case)
|
||||
self.assertEqual(
|
||||
model_spec_obj.tflite_input_name, {
|
||||
'ids': 'serving_default_input_1:0',
|
||||
'mask': 'serving_default_input_3:0',
|
||||
'segment_ids': 'serving_default_input_2:0'
|
||||
})
|
||||
model_spec_obj.tflite_input_name,
|
||||
{
|
||||
'ids': 'serving_default_input_word_ids:0',
|
||||
'mask': 'serving_default_input_mask:0',
|
||||
'segment_ids': 'serving_default_input_type_ids:0',
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
model_spec_obj.model_options,
|
||||
classifier_model_options.BertModelOptions(
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
}
|
||||
},
|
||||
{
|
||||
"name": "mask",
|
||||
"description": "Mask with 1 for real tokens and 0 for padding tokens.",
|
||||
"name": "segment_ids",
|
||||
"description": "0 for the first sequence, 1 for the second sequence if exists.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
|
@ -27,8 +27,8 @@
|
|||
}
|
||||
},
|
||||
{
|
||||
"name": "segment_ids",
|
||||
"description": "0 for the first sequence, 1 for the second sequence if exists.",
|
||||
"name": "mask",
|
||||
"description": "Mask with 1 for real tokens and 0 for padding tokens.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
|
|
|
@ -24,6 +24,7 @@ import tensorflow_hub as hub
|
|||
|
||||
from mediapipe.model_maker.python.core.data import dataset as ds
|
||||
from mediapipe.model_maker.python.core.tasks import classifier
|
||||
from mediapipe.model_maker.python.core.utils import loss_functions
|
||||
from mediapipe.model_maker.python.core.utils import metrics
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
|
@ -116,17 +117,14 @@ class TextClassifier(classifier.Classifier):
|
|||
options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER
|
||||
or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER
|
||||
):
|
||||
text_classifier = (
|
||||
_BertClassifier.create_bert_classifier(train_data, validation_data,
|
||||
options,
|
||||
train_data.label_names))
|
||||
text_classifier = _BertClassifier.create_bert_classifier(
|
||||
train_data, validation_data, options
|
||||
)
|
||||
elif (options.supported_model ==
|
||||
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
|
||||
text_classifier = (
|
||||
_AverageWordEmbeddingClassifier
|
||||
.create_average_word_embedding_classifier(train_data, validation_data,
|
||||
options,
|
||||
train_data.label_names))
|
||||
text_classifier = _AverageWordEmbeddingClassifier.create_average_word_embedding_classifier(
|
||||
train_data, validation_data, options
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown model {options.supported_model}")
|
||||
|
||||
|
@ -166,28 +164,8 @@ class TextClassifier(classifier.Classifier):
|
|||
processed_data = self._text_preprocessor.preprocess(data)
|
||||
dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)
|
||||
|
||||
additional_metrics = []
|
||||
if desired_precisions and len(data.label_names) == 2:
|
||||
for precision in desired_precisions:
|
||||
additional_metrics.append(
|
||||
metrics.BinarySparseRecallAtPrecision(
|
||||
precision, name=f"recall_at_precision_{precision}"
|
||||
)
|
||||
)
|
||||
if desired_recalls and len(data.label_names) == 2:
|
||||
for recall in desired_recalls:
|
||||
additional_metrics.append(
|
||||
metrics.BinarySparsePrecisionAtRecall(
|
||||
recall, name=f"precision_at_recall_{recall}"
|
||||
)
|
||||
)
|
||||
metric_functions = self._metric_functions + additional_metrics
|
||||
self._model.compile(
|
||||
optimizer=self._optimizer,
|
||||
loss=self._loss_function,
|
||||
metrics=metric_functions,
|
||||
)
|
||||
return self._model.evaluate(dataset)
|
||||
with self._hparams.get_strategy().scope():
|
||||
return self._model.evaluate(dataset)
|
||||
|
||||
def export_model(
|
||||
self,
|
||||
|
@ -255,16 +233,17 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
|
|||
|
||||
@classmethod
|
||||
def create_average_word_embedding_classifier(
|
||||
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
|
||||
cls,
|
||||
train_data: text_ds.Dataset,
|
||||
validation_data: text_ds.Dataset,
|
||||
options: text_classifier_options.TextClassifierOptions,
|
||||
label_names: Sequence[str]) -> "_AverageWordEmbeddingClassifier":
|
||||
) -> "_AverageWordEmbeddingClassifier":
|
||||
"""Creates, trains, and returns an Average Word Embedding classifier.
|
||||
|
||||
Args:
|
||||
train_data: Training data.
|
||||
validation_data: Validation data.
|
||||
options: Options for creating and training the text classifier.
|
||||
label_names: Label names used in the data.
|
||||
|
||||
Returns:
|
||||
An Average Word Embedding classifier.
|
||||
|
@ -370,28 +349,25 @@ class _BertClassifier(TextClassifier):
|
|||
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
|
||||
self._model_options = model_options
|
||||
with self._hparams.get_strategy().scope():
|
||||
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||
self._metric_functions = [
|
||||
tf.keras.metrics.SparseCategoricalAccuracy(
|
||||
"test_accuracy", dtype=tf.float32
|
||||
),
|
||||
metrics.SparsePrecision(name="precision", dtype=tf.float32),
|
||||
metrics.SparseRecall(name="recall", dtype=tf.float32),
|
||||
]
|
||||
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
|
||||
self._loss_function = loss_functions.SparseFocalLoss(
|
||||
self._hparams.gamma, self._num_classes
|
||||
)
|
||||
self._metric_functions = self._create_metrics()
|
||||
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
|
||||
|
||||
@classmethod
|
||||
def create_bert_classifier(
|
||||
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
|
||||
cls,
|
||||
train_data: text_ds.Dataset,
|
||||
validation_data: text_ds.Dataset,
|
||||
options: text_classifier_options.TextClassifierOptions,
|
||||
label_names: Sequence[str]) -> "_BertClassifier":
|
||||
) -> "_BertClassifier":
|
||||
"""Creates, trains, and returns a BERT-based classifier.
|
||||
|
||||
Args:
|
||||
train_data: Training data.
|
||||
validation_data: Validation data.
|
||||
options: Options for creating and training the text classifier.
|
||||
label_names: Label names used in the data.
|
||||
|
||||
Returns:
|
||||
A BERT-based classifier.
|
||||
|
@ -437,8 +413,57 @@ class _BertClassifier(TextClassifier):
|
|||
uri=self._model_spec.downloaded_files.get_path(),
|
||||
model_name=self._model_spec.name,
|
||||
)
|
||||
return (self._text_preprocessor.preprocess(train_data),
|
||||
self._text_preprocessor.preprocess(validation_data))
|
||||
return (
|
||||
self._text_preprocessor.preprocess(train_data),
|
||||
self._text_preprocessor.preprocess(validation_data),
|
||||
)
|
||||
|
||||
def _create_metrics(self):
|
||||
"""Creates metrics for training and evaluation.
|
||||
|
||||
The default metrics are accuracy, precision, and recall.
|
||||
|
||||
For binary classification tasks only (num_classes=2):
|
||||
Users can configure PrecisionAtRecall and RecallAtPrecision metrics using
|
||||
the desired_presisions and desired_recalls fields in BertHParams.
|
||||
|
||||
Returns:
|
||||
A list of tf.keras.Metric subclasses which can be used with model.compile
|
||||
"""
|
||||
metric_functions = [
|
||||
tf.keras.metrics.SparseCategoricalAccuracy(
|
||||
"accuracy", dtype=tf.float32
|
||||
),
|
||||
metrics.SparsePrecision(name="precision", dtype=tf.float32),
|
||||
metrics.SparseRecall(name="recall", dtype=tf.float32),
|
||||
]
|
||||
if self._num_classes == 2:
|
||||
if self._hparams.desired_precisions:
|
||||
for desired_precision in self._hparams.desired_precisions:
|
||||
metric_functions.append(
|
||||
metrics.BinarySparseRecallAtPrecision(
|
||||
desired_precision,
|
||||
name=f"recall_at_precision_{desired_precision}",
|
||||
num_thresholds=1000,
|
||||
)
|
||||
)
|
||||
if self._hparams.desired_recalls:
|
||||
for desired_recall in self._hparams.desired_recalls:
|
||||
metric_functions.append(
|
||||
metrics.BinarySparseRecallAtPrecision(
|
||||
desired_recall,
|
||||
name=f"precision_at_recall_{desired_recall}",
|
||||
num_thresholds=1000,
|
||||
)
|
||||
)
|
||||
else:
|
||||
if self._hparams.desired_precisions or self._hparams.desired_recalls:
|
||||
raise ValueError(
|
||||
"desired_recalls and desired_precisions parameters are binary"
|
||||
" metrics and not supported for num_classes > 2. Found"
|
||||
f" num_classes: {self._num_classes}"
|
||||
)
|
||||
return metric_functions
|
||||
|
||||
def _create_model(self):
|
||||
"""Creates a BERT-based classifier model.
|
||||
|
@ -448,11 +473,20 @@ class _BertClassifier(TextClassifier):
|
|||
"""
|
||||
encoder_inputs = dict(
|
||||
input_word_ids=tf.keras.layers.Input(
|
||||
shape=(self._model_options.seq_len,), dtype=tf.int32),
|
||||
shape=(self._model_options.seq_len,),
|
||||
dtype=tf.int32,
|
||||
name="input_word_ids",
|
||||
),
|
||||
input_mask=tf.keras.layers.Input(
|
||||
shape=(self._model_options.seq_len,), dtype=tf.int32),
|
||||
shape=(self._model_options.seq_len,),
|
||||
dtype=tf.int32,
|
||||
name="input_mask",
|
||||
),
|
||||
input_type_ids=tf.keras.layers.Input(
|
||||
shape=(self._model_options.seq_len,), dtype=tf.int32),
|
||||
shape=(self._model_options.seq_len,),
|
||||
dtype=tf.int32,
|
||||
name="input_type_ids",
|
||||
),
|
||||
)
|
||||
encoder = hub.KerasLayer(
|
||||
self._model_spec.downloaded_files.get_path(),
|
||||
|
@ -494,16 +528,21 @@ class _BertClassifier(TextClassifier):
|
|||
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
|
||||
initial_learning_rate=initial_lr,
|
||||
decay_steps=total_steps,
|
||||
end_learning_rate=0.0,
|
||||
power=1.0)
|
||||
end_learning_rate=self._hparams.end_learning_rate,
|
||||
power=1.0,
|
||||
)
|
||||
if warmup_steps:
|
||||
lr_schedule = model_util.WarmUp(
|
||||
initial_learning_rate=initial_lr,
|
||||
decay_schedule_fn=lr_schedule,
|
||||
warmup_steps=warmup_steps)
|
||||
warmup_steps=warmup_steps,
|
||||
)
|
||||
if self._hparams.optimizer == hp.BertOptimizer.ADAMW:
|
||||
self._optimizer = tf.keras.optimizers.experimental.AdamW(
|
||||
lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0
|
||||
lr_schedule,
|
||||
weight_decay=self._hparams.weight_decay,
|
||||
epsilon=1e-6,
|
||||
global_clipnorm=1.0,
|
||||
)
|
||||
self._optimizer.exclude_from_weight_decay(
|
||||
var_names=["LayerNorm", "layer_norm", "bias"]
|
||||
|
@ -511,7 +550,7 @@ class _BertClassifier(TextClassifier):
|
|||
elif self._hparams.optimizer == hp.BertOptimizer.LAMB:
|
||||
self._optimizer = tfa_optimizers.LAMB(
|
||||
lr_schedule,
|
||||
weight_decay_rate=0.01,
|
||||
weight_decay_rate=self._hparams.weight_decay,
|
||||
epsilon=1e-6,
|
||||
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
|
||||
global_clipnorm=1.0,
|
||||
|
|
|
@ -16,17 +16,17 @@ import csv
|
|||
import filecmp
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest import mock as unittest_mock
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import loss_functions
|
||||
from mediapipe.model_maker.python.text import text_classifier
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
|
||||
@unittest.skip('b/275624089')
|
||||
class TextClassifierTest(tf.test.TestCase):
|
||||
class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
_AVERAGE_WORD_EMBEDDING_JSON_FILE = (
|
||||
test_utils.get_test_data_path('average_word_embedding_metadata.json'))
|
||||
|
@ -78,8 +78,8 @@ class TextClassifierTest(tf.test.TestCase):
|
|||
text_classifier.TextClassifier.create(train_data, validation_data,
|
||||
options))
|
||||
|
||||
_, accuracy = average_word_embedding_classifier.evaluate(validation_data)
|
||||
self.assertGreaterEqual(accuracy, 0.0)
|
||||
metrics = average_word_embedding_classifier.evaluate(validation_data)
|
||||
self.assertGreaterEqual(metrics[1], 0.0) # metrics[1] is accuracy
|
||||
|
||||
# Test export_model
|
||||
average_word_embedding_classifier.export_model()
|
||||
|
@ -98,12 +98,25 @@ class TextClassifierTest(tf.test.TestCase):
|
|||
filecmp.cmp(
|
||||
output_metadata_file,
|
||||
self._AVERAGE_WORD_EMBEDDING_JSON_FILE,
|
||||
shallow=False))
|
||||
shallow=False,
|
||||
)
|
||||
)
|
||||
|
||||
def test_create_and_train_bert(self):
|
||||
@parameterized.named_parameters(
|
||||
# Skipping mobilebert b/c OSS test timeout/flakiness: b/275624089
|
||||
# dict(
|
||||
# testcase_name='mobilebert',
|
||||
# supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
|
||||
# ),
|
||||
dict(
|
||||
testcase_name='exbert',
|
||||
supported_model=text_classifier.SupportedModels.EXBERT_CLASSIFIER,
|
||||
),
|
||||
)
|
||||
def test_create_and_train_bert(self, supported_model):
|
||||
train_data, validation_data = self._get_data()
|
||||
options = text_classifier.TextClassifierOptions(
|
||||
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
|
||||
supported_model=supported_model,
|
||||
model_options=text_classifier.BertModelOptions(
|
||||
do_fine_tuning=False, seq_len=2
|
||||
),
|
||||
|
@ -117,8 +130,8 @@ class TextClassifierTest(tf.test.TestCase):
|
|||
bert_classifier = text_classifier.TextClassifier.create(
|
||||
train_data, validation_data, options)
|
||||
|
||||
_, accuracy = bert_classifier.evaluate(validation_data)
|
||||
self.assertGreaterEqual(accuracy, 0.0)
|
||||
metrics = bert_classifier.evaluate(validation_data)
|
||||
self.assertGreaterEqual(metrics[1], 0.0) # metrics[1] is accuracy
|
||||
|
||||
# Test export_model
|
||||
bert_classifier.export_model()
|
||||
|
@ -142,45 +155,93 @@ class TextClassifierTest(tf.test.TestCase):
|
|||
)
|
||||
|
||||
def test_label_mismatch(self):
|
||||
options = (
|
||||
text_classifier.TextClassifierOptions(
|
||||
supported_model=(
|
||||
text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER)))
|
||||
options = text_classifier.TextClassifierOptions(
|
||||
supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER)
|
||||
)
|
||||
train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
|
||||
train_data = text_classifier.Dataset(train_tf_dataset, 1, ['foo'])
|
||||
train_data = text_classifier.Dataset(train_tf_dataset, ['foo'], 1)
|
||||
validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
|
||||
validation_data = text_classifier.Dataset(validation_tf_dataset, 1, ['bar'])
|
||||
validation_data = text_classifier.Dataset(validation_tf_dataset, ['bar'], 1)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'Training data label names .* not equal to validation data label names'
|
||||
'Training data label names .* not equal to validation data label names',
|
||||
):
|
||||
text_classifier.TextClassifier.create(train_data, validation_data,
|
||||
options)
|
||||
text_classifier.TextClassifier.create(
|
||||
train_data, validation_data, options
|
||||
)
|
||||
|
||||
def test_options_mismatch(self):
|
||||
train_data, validation_data = self._get_data()
|
||||
|
||||
avg_options = (
|
||||
text_classifier.TextClassifierOptions(
|
||||
supported_model=(
|
||||
text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER),
|
||||
model_options=text_classifier.AverageWordEmbeddingModelOptions()))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
|
||||
' SupportedModels.MOBILEBERT_CLASSIFIER'):
|
||||
text_classifier.TextClassifier.create(train_data, validation_data,
|
||||
avg_options)
|
||||
avg_options = text_classifier.TextClassifierOptions(
|
||||
supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER),
|
||||
model_options=text_classifier.AverageWordEmbeddingModelOptions(),
|
||||
)
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
|
||||
' SupportedModels.EXBERT_CLASSIFIER',
|
||||
):
|
||||
text_classifier.TextClassifier.create(
|
||||
train_data, validation_data, avg_options
|
||||
)
|
||||
|
||||
bert_options = (
|
||||
text_classifier.TextClassifierOptions(
|
||||
supported_model=(text_classifier.SupportedModels
|
||||
.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
|
||||
model_options=text_classifier.BertModelOptions()))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Expected MOBILEBERT_CLASSIFIER, got'
|
||||
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):
|
||||
text_classifier.TextClassifier.create(train_data, validation_data,
|
||||
bert_options)
|
||||
bert_options = text_classifier.TextClassifierOptions(
|
||||
supported_model=(
|
||||
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
|
||||
),
|
||||
model_options=text_classifier.BertModelOptions(),
|
||||
)
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
'Expected a Bert Classifier(MobileBERT or EXBERT), got'
|
||||
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER',
|
||||
):
|
||||
text_classifier.TextClassifier.create(
|
||||
train_data, validation_data, bert_options
|
||||
)
|
||||
|
||||
def test_bert_loss_and_metrics_creation(self):
|
||||
train_data, validation_data = self._get_data()
|
||||
supported_model = text_classifier.SupportedModels.EXBERT_CLASSIFIER
|
||||
hparams = text_classifier.BertHParams(
|
||||
desired_recalls=[0.2],
|
||||
desired_precisions=[0.9],
|
||||
epochs=1,
|
||||
batch_size=1,
|
||||
learning_rate=3e-5,
|
||||
distribution_strategy='off',
|
||||
gamma=3.5,
|
||||
)
|
||||
options = text_classifier.TextClassifierOptions(
|
||||
supported_model=supported_model, hparams=hparams
|
||||
)
|
||||
bert_classifier = text_classifier.TextClassifier.create(
|
||||
train_data, validation_data, options
|
||||
)
|
||||
loss_fn = bert_classifier._loss_function
|
||||
self.assertIsInstance(loss_fn, loss_functions.SparseFocalLoss)
|
||||
self.assertEqual(loss_fn._gamma, 3.5)
|
||||
self.assertEqual(loss_fn._num_classes, 2)
|
||||
metric_names = [m.name for m in bert_classifier._metric_functions]
|
||||
expected_metric_names = [
|
||||
'accuracy',
|
||||
'recall',
|
||||
'precision',
|
||||
'precision_at_recall_0.2',
|
||||
'recall_at_precision_0.9',
|
||||
]
|
||||
self.assertCountEqual(metric_names, expected_metric_names)
|
||||
|
||||
# Non-binary data
|
||||
tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
|
||||
data = text_classifier.Dataset(tf_dataset, ['foo', 'bar', 'baz'], 1)
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
'desired_recalls and desired_precisions parameters are binary metrics'
|
||||
' and not supported for num_classes > 2. Found num_classes: 3',
|
||||
):
|
||||
text_classifier.TextClassifier.create(data, data, options)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue
Block a user