1. Move evaluation onto GPU/TPU hardware if available.

2. Move desired_precision and desired_recall from evaluate to hyperparameters so recall@precision metrics will be reported for both training and evaluation. This also fixes a bug where recompiling the model with the previously initialized metric objects would not properly reset the metric states.
3. Remove redundant label_names from create_... class methods in text_classifier. This information is already provided by the datasets.
4. Change loss function to FocalLoss.
5. Re-enable text_classifier unit tests using ExBert
6. Add input names to avoid flaky auto-assigned input names.

PiperOrigin-RevId: 550992146
This commit is contained in:
MediaPipe Team 2023-07-25 14:10:03 -07:00 committed by Copybara-Service
parent 85c3fed70a
commit bd7888cc0c
9 changed files with 294 additions and 120 deletions

View File

@ -59,7 +59,7 @@ class FocalLoss(tf.keras.losses.Loss):
""" """
def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None): def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
"""Constructor. """Initializes FocalLoss.
Args: Args:
gamma: Focal loss gamma, as described in class docs. gamma: Focal loss gamma, as described in class docs.
@ -115,6 +115,51 @@ class FocalLoss(tf.keras.losses.Loss):
return tf.reduce_sum(losses) / batch_size return tf.reduce_sum(losses) / batch_size
class SparseFocalLoss(FocalLoss):
"""Sparse implementation of Focal Loss.
This is the same as FocalLoss, except the labels are expected to be class ids
instead of 1-hot encoded vectors. See FocalLoss class documentation defined
in this same file for more details.
Example usage:
>>> y_true = [1, 2]
>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
>>> gamma = 2
>>> focal_loss = SparseFocalLoss(gamma, 3)
>>> focal_loss(y_true, y_pred).numpy()
0.9326
>>> # Calling with 'sample_weight'.
>>> focal_loss(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()
0.6528
"""
def __init__(
self, gamma, num_classes, class_weight: Optional[Sequence[float]] = None
):
"""Initializes SparseFocalLoss.
Args:
gamma: Focal loss gamma, as described in class docs.
num_classes: Number of classes.
class_weight: A weight to apply to the loss, one for each class. The
weight is applied for each input where the ground truth label matches.
"""
super().__init__(gamma, class_weight=class_weight)
self._num_classes = num_classes
def __call__(
self,
y_true: tf.Tensor,
y_pred: tf.Tensor,
sample_weight: Optional[tf.Tensor] = None,
) -> tf.Tensor:
y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
y_true_one_hot = tf.one_hot(y_true, self._num_classes)
return super().__call__(y_true_one_hot, y_pred, sample_weight=sample_weight)
@dataclasses.dataclass @dataclasses.dataclass
class PerceptualLossWeight: class PerceptualLossWeight:
"""The weight for each perceptual loss. """The weight for each perceptual loss.

View File

@ -101,6 +101,23 @@ class FocalLossTest(tf.test.TestCase, parameterized.TestCase):
self.assertNear(loss, expected_loss, 1e-4) self.assertNear(loss, expected_loss, 1e-4)
class SparseFocalLossTest(tf.test.TestCase):
def test_sparse_focal_loss_matches_focal_loss(self):
num_classes = 2
y_pred = tf.constant([[0.8, 0.2], [0.3, 0.7]])
y_true = tf.constant([1, 0])
y_true_one_hot = tf.one_hot(y_true, num_classes)
for gamma in [0.0, 0.5, 1.0]:
expected_loss_fn = loss_functions.FocalLoss(gamma=gamma)
loss_fn = loss_functions.SparseFocalLoss(
gamma=gamma, num_classes=num_classes
)
expected_loss = expected_loss_fn(y_true_one_hot, y_pred)
loss = loss_fn(y_true, y_pred)
self.assertNear(loss, expected_loss, 1e-4)
class MockPerceptualLoss(loss_functions.PerceptualLoss): class MockPerceptualLoss(loss_functions.PerceptualLoss):
"""A mock class with implementation of abstract methods for testing.""" """A mock class with implementation of abstract methods for testing."""

View File

@ -131,6 +131,7 @@ py_library(
":text_classifier_options", ":text_classifier_options",
"//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/data:dataset",
"//mediapipe/model_maker/python/core/tasks:classifier", "//mediapipe/model_maker/python/core/tasks:classifier",
"//mediapipe/model_maker/python/core/utils:loss_functions",
"//mediapipe/model_maker/python/core/utils:metrics", "//mediapipe/model_maker/python/core/utils:metrics",
"//mediapipe/model_maker/python/core/utils:model_util", "//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/model_maker/python/core/utils:quantization", "//mediapipe/model_maker/python/core/utils:quantization",
@ -154,6 +155,7 @@ py_test(
], ],
deps = [ deps = [
":text_classifier_import", ":text_classifier_import",
"//mediapipe/model_maker/python/core/utils:loss_functions",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
], ],
) )

View File

@ -15,7 +15,7 @@
import dataclasses import dataclasses
import enum import enum
from typing import Union from typing import Sequence, Union
from mediapipe.model_maker.python.core import hyperparameters as hp from mediapipe.model_maker.python.core import hyperparameters as hp
@ -39,16 +39,34 @@ class BertHParams(hp.BaseHParams):
Attributes: Attributes:
learning_rate: Learning rate to use for gradient descent training. learning_rate: Learning rate to use for gradient descent training.
batch_size: Batch size for training. end_learning_rate: End learning rate for linear decay. Defaults to 0.
epochs: Number of training iterations over the dataset. batch_size: Batch size for training. Defaults to 48.
optimizer: Optimizer to use for training. Only supported values are "adamw" epochs: Number of training iterations over the dataset. Defaults to 2.
and "lamb". optimizer: Optimizer to use for training. Supported values are defined in
BertOptimizer enum: ADAMW and LAMB.
weight_decay: Weight decay of the optimizer. Defaults to 0.01.
desired_precisions: If specified, adds a RecallAtPrecision metric per
desired_precisions[i] entry which tracks the recall given the constraint
on precision. Only supported for binary classification.
desired_recalls: If specified, adds a PrecisionAtRecall metric per
desired_recalls[i] entry which tracks the precision given the constraint
on recall. Only supported for binary classification.
gamma: Gamma parameter for focal loss. To use cross entropy loss, set this
value to 0. Defaults to 2.0.
""" """
learning_rate: float = 3e-5 learning_rate: float = 3e-5
end_learning_rate: float = 0.0
batch_size: int = 48 batch_size: int = 48
epochs: int = 2 epochs: int = 2
optimizer: BertOptimizer = BertOptimizer.ADAMW optimizer: BertOptimizer = BertOptimizer.ADAMW
weight_decay: float = 0.01
desired_precisions: Sequence[float] = dataclasses.field(default_factory=list)
desired_recalls: Sequence[float] = dataclasses.field(default_factory=list)
gamma: float = 2.0
HParams = Union[BertHParams, AverageWordEmbeddingHParams] HParams = Union[BertHParams, AverageWordEmbeddingHParams]

View File

@ -79,11 +79,6 @@ mobilebert_classifier_spec = functools.partial(
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off' epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
), ),
name='MobileBert', name='MobileBert',
tflite_input_name={
'ids': 'serving_default_input_1:0',
'segment_ids': 'serving_default_input_2:0',
'mask': 'serving_default_input_3:0',
},
) )
exbert_classifier_spec = functools.partial( exbert_classifier_spec = functools.partial(
@ -93,11 +88,6 @@ exbert_classifier_spec = functools.partial(
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off' epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
), ),
name='ExBert', name='ExBert',
tflite_input_name={
'ids': 'serving_default_input_1:0',
'segment_ids': 'serving_default_input_2:0',
'mask': 'serving_default_input_3:0',
},
) )

View File

@ -46,11 +46,13 @@ class ModelSpecTest(tf.test.TestCase):
self.assertTrue(os.path.exists(model_spec_obj.downloaded_files.get_path())) self.assertTrue(os.path.exists(model_spec_obj.downloaded_files.get_path()))
self.assertTrue(model_spec_obj.do_lower_case) self.assertTrue(model_spec_obj.do_lower_case)
self.assertEqual( self.assertEqual(
model_spec_obj.tflite_input_name, { model_spec_obj.tflite_input_name,
'ids': 'serving_default_input_1:0', {
'mask': 'serving_default_input_3:0', 'ids': 'serving_default_input_word_ids:0',
'segment_ids': 'serving_default_input_2:0' 'mask': 'serving_default_input_mask:0',
}) 'segment_ids': 'serving_default_input_type_ids:0',
},
)
self.assertEqual( self.assertEqual(
model_spec_obj.model_options, model_spec_obj.model_options,
classifier_model_options.BertModelOptions( classifier_model_options.BertModelOptions(

View File

@ -16,8 +16,8 @@
} }
}, },
{ {
"name": "mask", "name": "segment_ids",
"description": "Mask with 1 for real tokens and 0 for padding tokens.", "description": "0 for the first sequence, 1 for the second sequence if exists.",
"content": { "content": {
"content_properties_type": "FeatureProperties", "content_properties_type": "FeatureProperties",
"content_properties": { "content_properties": {
@ -27,8 +27,8 @@
} }
}, },
{ {
"name": "segment_ids", "name": "mask",
"description": "0 for the first sequence, 1 for the second sequence if exists.", "description": "Mask with 1 for real tokens and 0 for padding tokens.",
"content": { "content": {
"content_properties_type": "FeatureProperties", "content_properties_type": "FeatureProperties",
"content_properties": { "content_properties": {

View File

@ -24,6 +24,7 @@ import tensorflow_hub as hub
from mediapipe.model_maker.python.core.data import dataset as ds from mediapipe.model_maker.python.core.data import dataset as ds
from mediapipe.model_maker.python.core.tasks import classifier from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import loss_functions
from mediapipe.model_maker.python.core.utils import metrics from mediapipe.model_maker.python.core.utils import metrics
from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.core.utils import quantization
@ -116,17 +117,14 @@ class TextClassifier(classifier.Classifier):
options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER
or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER
): ):
text_classifier = ( text_classifier = _BertClassifier.create_bert_classifier(
_BertClassifier.create_bert_classifier(train_data, validation_data, train_data, validation_data, options
options, )
train_data.label_names))
elif (options.supported_model == elif (options.supported_model ==
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER): ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
text_classifier = ( text_classifier = _AverageWordEmbeddingClassifier.create_average_word_embedding_classifier(
_AverageWordEmbeddingClassifier train_data, validation_data, options
.create_average_word_embedding_classifier(train_data, validation_data, )
options,
train_data.label_names))
else: else:
raise ValueError(f"Unknown model {options.supported_model}") raise ValueError(f"Unknown model {options.supported_model}")
@ -166,28 +164,8 @@ class TextClassifier(classifier.Classifier):
processed_data = self._text_preprocessor.preprocess(data) processed_data = self._text_preprocessor.preprocess(data)
dataset = processed_data.gen_tf_dataset(batch_size, is_training=False) dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)
additional_metrics = [] with self._hparams.get_strategy().scope():
if desired_precisions and len(data.label_names) == 2: return self._model.evaluate(dataset)
for precision in desired_precisions:
additional_metrics.append(
metrics.BinarySparseRecallAtPrecision(
precision, name=f"recall_at_precision_{precision}"
)
)
if desired_recalls and len(data.label_names) == 2:
for recall in desired_recalls:
additional_metrics.append(
metrics.BinarySparsePrecisionAtRecall(
recall, name=f"precision_at_recall_{recall}"
)
)
metric_functions = self._metric_functions + additional_metrics
self._model.compile(
optimizer=self._optimizer,
loss=self._loss_function,
metrics=metric_functions,
)
return self._model.evaluate(dataset)
def export_model( def export_model(
self, self,
@ -255,16 +233,17 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
@classmethod @classmethod
def create_average_word_embedding_classifier( def create_average_word_embedding_classifier(
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset, cls,
train_data: text_ds.Dataset,
validation_data: text_ds.Dataset,
options: text_classifier_options.TextClassifierOptions, options: text_classifier_options.TextClassifierOptions,
label_names: Sequence[str]) -> "_AverageWordEmbeddingClassifier": ) -> "_AverageWordEmbeddingClassifier":
"""Creates, trains, and returns an Average Word Embedding classifier. """Creates, trains, and returns an Average Word Embedding classifier.
Args: Args:
train_data: Training data. train_data: Training data.
validation_data: Validation data. validation_data: Validation data.
options: Options for creating and training the text classifier. options: Options for creating and training the text classifier.
label_names: Label names used in the data.
Returns: Returns:
An Average Word Embedding classifier. An Average Word Embedding classifier.
@ -370,28 +349,25 @@ class _BertClassifier(TextClassifier):
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir) self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
self._model_options = model_options self._model_options = model_options
with self._hparams.get_strategy().scope(): with self._hparams.get_strategy().scope():
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy() self._loss_function = loss_functions.SparseFocalLoss(
self._metric_functions = [ self._hparams.gamma, self._num_classes
tf.keras.metrics.SparseCategoricalAccuracy( )
"test_accuracy", dtype=tf.float32 self._metric_functions = self._create_metrics()
), self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
metrics.SparsePrecision(name="precision", dtype=tf.float32),
metrics.SparseRecall(name="recall", dtype=tf.float32),
]
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
@classmethod @classmethod
def create_bert_classifier( def create_bert_classifier(
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset, cls,
train_data: text_ds.Dataset,
validation_data: text_ds.Dataset,
options: text_classifier_options.TextClassifierOptions, options: text_classifier_options.TextClassifierOptions,
label_names: Sequence[str]) -> "_BertClassifier": ) -> "_BertClassifier":
"""Creates, trains, and returns a BERT-based classifier. """Creates, trains, and returns a BERT-based classifier.
Args: Args:
train_data: Training data. train_data: Training data.
validation_data: Validation data. validation_data: Validation data.
options: Options for creating and training the text classifier. options: Options for creating and training the text classifier.
label_names: Label names used in the data.
Returns: Returns:
A BERT-based classifier. A BERT-based classifier.
@ -437,8 +413,57 @@ class _BertClassifier(TextClassifier):
uri=self._model_spec.downloaded_files.get_path(), uri=self._model_spec.downloaded_files.get_path(),
model_name=self._model_spec.name, model_name=self._model_spec.name,
) )
return (self._text_preprocessor.preprocess(train_data), return (
self._text_preprocessor.preprocess(validation_data)) self._text_preprocessor.preprocess(train_data),
self._text_preprocessor.preprocess(validation_data),
)
def _create_metrics(self):
"""Creates metrics for training and evaluation.
The default metrics are accuracy, precision, and recall.
For binary classification tasks only (num_classes=2):
Users can configure PrecisionAtRecall and RecallAtPrecision metrics using
the desired_presisions and desired_recalls fields in BertHParams.
Returns:
A list of tf.keras.Metric subclasses which can be used with model.compile
"""
metric_functions = [
tf.keras.metrics.SparseCategoricalAccuracy(
"accuracy", dtype=tf.float32
),
metrics.SparsePrecision(name="precision", dtype=tf.float32),
metrics.SparseRecall(name="recall", dtype=tf.float32),
]
if self._num_classes == 2:
if self._hparams.desired_precisions:
for desired_precision in self._hparams.desired_precisions:
metric_functions.append(
metrics.BinarySparseRecallAtPrecision(
desired_precision,
name=f"recall_at_precision_{desired_precision}",
num_thresholds=1000,
)
)
if self._hparams.desired_recalls:
for desired_recall in self._hparams.desired_recalls:
metric_functions.append(
metrics.BinarySparseRecallAtPrecision(
desired_recall,
name=f"precision_at_recall_{desired_recall}",
num_thresholds=1000,
)
)
else:
if self._hparams.desired_precisions or self._hparams.desired_recalls:
raise ValueError(
"desired_recalls and desired_precisions parameters are binary"
" metrics and not supported for num_classes > 2. Found"
f" num_classes: {self._num_classes}"
)
return metric_functions
def _create_model(self): def _create_model(self):
"""Creates a BERT-based classifier model. """Creates a BERT-based classifier model.
@ -448,11 +473,20 @@ class _BertClassifier(TextClassifier):
""" """
encoder_inputs = dict( encoder_inputs = dict(
input_word_ids=tf.keras.layers.Input( input_word_ids=tf.keras.layers.Input(
shape=(self._model_options.seq_len,), dtype=tf.int32), shape=(self._model_options.seq_len,),
dtype=tf.int32,
name="input_word_ids",
),
input_mask=tf.keras.layers.Input( input_mask=tf.keras.layers.Input(
shape=(self._model_options.seq_len,), dtype=tf.int32), shape=(self._model_options.seq_len,),
dtype=tf.int32,
name="input_mask",
),
input_type_ids=tf.keras.layers.Input( input_type_ids=tf.keras.layers.Input(
shape=(self._model_options.seq_len,), dtype=tf.int32), shape=(self._model_options.seq_len,),
dtype=tf.int32,
name="input_type_ids",
),
) )
encoder = hub.KerasLayer( encoder = hub.KerasLayer(
self._model_spec.downloaded_files.get_path(), self._model_spec.downloaded_files.get_path(),
@ -494,16 +528,21 @@ class _BertClassifier(TextClassifier):
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay( lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=initial_lr, initial_learning_rate=initial_lr,
decay_steps=total_steps, decay_steps=total_steps,
end_learning_rate=0.0, end_learning_rate=self._hparams.end_learning_rate,
power=1.0) power=1.0,
)
if warmup_steps: if warmup_steps:
lr_schedule = model_util.WarmUp( lr_schedule = model_util.WarmUp(
initial_learning_rate=initial_lr, initial_learning_rate=initial_lr,
decay_schedule_fn=lr_schedule, decay_schedule_fn=lr_schedule,
warmup_steps=warmup_steps) warmup_steps=warmup_steps,
)
if self._hparams.optimizer == hp.BertOptimizer.ADAMW: if self._hparams.optimizer == hp.BertOptimizer.ADAMW:
self._optimizer = tf.keras.optimizers.experimental.AdamW( self._optimizer = tf.keras.optimizers.experimental.AdamW(
lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0 lr_schedule,
weight_decay=self._hparams.weight_decay,
epsilon=1e-6,
global_clipnorm=1.0,
) )
self._optimizer.exclude_from_weight_decay( self._optimizer.exclude_from_weight_decay(
var_names=["LayerNorm", "layer_norm", "bias"] var_names=["LayerNorm", "layer_norm", "bias"]
@ -511,7 +550,7 @@ class _BertClassifier(TextClassifier):
elif self._hparams.optimizer == hp.BertOptimizer.LAMB: elif self._hparams.optimizer == hp.BertOptimizer.LAMB:
self._optimizer = tfa_optimizers.LAMB( self._optimizer = tfa_optimizers.LAMB(
lr_schedule, lr_schedule,
weight_decay_rate=0.01, weight_decay_rate=self._hparams.weight_decay,
epsilon=1e-6, epsilon=1e-6,
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"], exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
global_clipnorm=1.0, global_clipnorm=1.0,

View File

@ -16,17 +16,17 @@ import csv
import filecmp import filecmp
import os import os
import tempfile import tempfile
import unittest
from unittest import mock as unittest_mock from unittest import mock as unittest_mock
from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.core.utils import loss_functions
from mediapipe.model_maker.python.text import text_classifier from mediapipe.model_maker.python.text import text_classifier
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
@unittest.skip('b/275624089') class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
class TextClassifierTest(tf.test.TestCase):
_AVERAGE_WORD_EMBEDDING_JSON_FILE = ( _AVERAGE_WORD_EMBEDDING_JSON_FILE = (
test_utils.get_test_data_path('average_word_embedding_metadata.json')) test_utils.get_test_data_path('average_word_embedding_metadata.json'))
@ -78,8 +78,8 @@ class TextClassifierTest(tf.test.TestCase):
text_classifier.TextClassifier.create(train_data, validation_data, text_classifier.TextClassifier.create(train_data, validation_data,
options)) options))
_, accuracy = average_word_embedding_classifier.evaluate(validation_data) metrics = average_word_embedding_classifier.evaluate(validation_data)
self.assertGreaterEqual(accuracy, 0.0) self.assertGreaterEqual(metrics[1], 0.0) # metrics[1] is accuracy
# Test export_model # Test export_model
average_word_embedding_classifier.export_model() average_word_embedding_classifier.export_model()
@ -98,12 +98,25 @@ class TextClassifierTest(tf.test.TestCase):
filecmp.cmp( filecmp.cmp(
output_metadata_file, output_metadata_file,
self._AVERAGE_WORD_EMBEDDING_JSON_FILE, self._AVERAGE_WORD_EMBEDDING_JSON_FILE,
shallow=False)) shallow=False,
)
)
def test_create_and_train_bert(self): @parameterized.named_parameters(
# Skipping mobilebert b/c OSS test timeout/flakiness: b/275624089
# dict(
# testcase_name='mobilebert',
# supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
# ),
dict(
testcase_name='exbert',
supported_model=text_classifier.SupportedModels.EXBERT_CLASSIFIER,
),
)
def test_create_and_train_bert(self, supported_model):
train_data, validation_data = self._get_data() train_data, validation_data = self._get_data()
options = text_classifier.TextClassifierOptions( options = text_classifier.TextClassifierOptions(
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER, supported_model=supported_model,
model_options=text_classifier.BertModelOptions( model_options=text_classifier.BertModelOptions(
do_fine_tuning=False, seq_len=2 do_fine_tuning=False, seq_len=2
), ),
@ -117,8 +130,8 @@ class TextClassifierTest(tf.test.TestCase):
bert_classifier = text_classifier.TextClassifier.create( bert_classifier = text_classifier.TextClassifier.create(
train_data, validation_data, options) train_data, validation_data, options)
_, accuracy = bert_classifier.evaluate(validation_data) metrics = bert_classifier.evaluate(validation_data)
self.assertGreaterEqual(accuracy, 0.0) self.assertGreaterEqual(metrics[1], 0.0) # metrics[1] is accuracy
# Test export_model # Test export_model
bert_classifier.export_model() bert_classifier.export_model()
@ -142,45 +155,93 @@ class TextClassifierTest(tf.test.TestCase):
) )
def test_label_mismatch(self): def test_label_mismatch(self):
options = ( options = text_classifier.TextClassifierOptions(
text_classifier.TextClassifierOptions( supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER)
supported_model=( )
text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER)))
train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]]) train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
train_data = text_classifier.Dataset(train_tf_dataset, 1, ['foo']) train_data = text_classifier.Dataset(train_tf_dataset, ['foo'], 1)
validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]]) validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
validation_data = text_classifier.Dataset(validation_tf_dataset, 1, ['bar']) validation_data = text_classifier.Dataset(validation_tf_dataset, ['bar'], 1)
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, ValueError,
'Training data label names .* not equal to validation data label names' 'Training data label names .* not equal to validation data label names',
): ):
text_classifier.TextClassifier.create(train_data, validation_data, text_classifier.TextClassifier.create(
options) train_data, validation_data, options
)
def test_options_mismatch(self): def test_options_mismatch(self):
train_data, validation_data = self._get_data() train_data, validation_data = self._get_data()
avg_options = ( avg_options = text_classifier.TextClassifierOptions(
text_classifier.TextClassifierOptions( supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER),
supported_model=( model_options=text_classifier.AverageWordEmbeddingModelOptions(),
text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER), )
model_options=text_classifier.AverageWordEmbeddingModelOptions())) with self.assertRaisesWithLiteralMatch(
with self.assertRaisesRegex( ValueError,
ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got' 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
' SupportedModels.MOBILEBERT_CLASSIFIER'): ' SupportedModels.EXBERT_CLASSIFIER',
text_classifier.TextClassifier.create(train_data, validation_data, ):
avg_options) text_classifier.TextClassifier.create(
train_data, validation_data, avg_options
)
bert_options = ( bert_options = text_classifier.TextClassifierOptions(
text_classifier.TextClassifierOptions( supported_model=(
supported_model=(text_classifier.SupportedModels text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
.AVERAGE_WORD_EMBEDDING_CLASSIFIER), ),
model_options=text_classifier.BertModelOptions())) model_options=text_classifier.BertModelOptions(),
with self.assertRaisesRegex( )
ValueError, 'Expected MOBILEBERT_CLASSIFIER, got' with self.assertRaisesWithLiteralMatch(
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'): ValueError,
text_classifier.TextClassifier.create(train_data, validation_data, 'Expected a Bert Classifier(MobileBERT or EXBERT), got'
bert_options) ' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER',
):
text_classifier.TextClassifier.create(
train_data, validation_data, bert_options
)
def test_bert_loss_and_metrics_creation(self):
train_data, validation_data = self._get_data()
supported_model = text_classifier.SupportedModels.EXBERT_CLASSIFIER
hparams = text_classifier.BertHParams(
desired_recalls=[0.2],
desired_precisions=[0.9],
epochs=1,
batch_size=1,
learning_rate=3e-5,
distribution_strategy='off',
gamma=3.5,
)
options = text_classifier.TextClassifierOptions(
supported_model=supported_model, hparams=hparams
)
bert_classifier = text_classifier.TextClassifier.create(
train_data, validation_data, options
)
loss_fn = bert_classifier._loss_function
self.assertIsInstance(loss_fn, loss_functions.SparseFocalLoss)
self.assertEqual(loss_fn._gamma, 3.5)
self.assertEqual(loss_fn._num_classes, 2)
metric_names = [m.name for m in bert_classifier._metric_functions]
expected_metric_names = [
'accuracy',
'recall',
'precision',
'precision_at_recall_0.2',
'recall_at_precision_0.9',
]
self.assertCountEqual(metric_names, expected_metric_names)
# Non-binary data
tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
data = text_classifier.Dataset(tf_dataset, ['foo', 'bar', 'baz'], 1)
with self.assertRaisesWithLiteralMatch(
ValueError,
'desired_recalls and desired_precisions parameters are binary metrics'
' and not supported for num_classes > 2. Found num_classes: 3',
):
text_classifier.TextClassifier.create(data, data, options)
if __name__ == '__main__': if __name__ == '__main__':