1. Move evaluation onto GPU/TPU hardware if available.
2. Move desired_precision and desired_recall from evaluate to hyperparameters so recall@precision metrics will be reported for both training and evaluation. This also fixes a bug where recompiling the model with the previously initialized metric objects would not properly reset the metric states.
3. Remove redundant label_names from create_... class methods in text_classifier. This information is already provided by the datasets.
4. Change loss function to FocalLoss.
5. Re-enable text_classifier unit tests using ExBert.
6. Add input names to avoid flaky auto-assigned input names.

PiperOrigin-RevId: 550992146
This commit is contained in:
parent 85c3fed70a
commit bd7888cc0c
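For context on item 4: focal loss scales the standard cross entropy term by (1 - p_t)^gamma, so well-classified examples contribute less and training focuses on hard ones. A minimal standalone sketch of the per-example computation (illustrative only, not the MediaPipe implementation; the function name is hypothetical):

    import tensorflow as tf

    def focal_loss_per_example(y_true_one_hot, y_pred_probs, gamma=2.0):
      # FL(p_t) = -(1 - p_t)^gamma * log(p_t); gamma=0 recovers cross entropy.
      p_t = tf.reduce_sum(y_true_one_hot * y_pred_probs, axis=-1)
      p_t = tf.clip_by_value(p_t, 1e-7, 1.0)  # guard against log(0)
      return -tf.pow(1.0 - p_t, gamma) * tf.math.log(p_t)

    y_true = tf.one_hot([1, 2], depth=3)
    y_pred = tf.constant([[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]])
    # Small loss for the confident first example, large for the misclassified second.
    print(focal_loss_per_example(y_true, y_pred).numpy())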
@@ -59,7 +59,7 @@ class FocalLoss(tf.keras.losses.Loss):
   """
 
   def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
-    """Constructor.
+    """Initializes FocalLoss.
 
     Args:
       gamma: Focal loss gamma, as described in class docs.
@@ -115,6 +115,51 @@ class FocalLoss(tf.keras.losses.Loss):
     return tf.reduce_sum(losses) / batch_size
 
 
+class SparseFocalLoss(FocalLoss):
+  """Sparse implementation of Focal Loss.
+
+  This is the same as FocalLoss, except the labels are expected to be class ids
+  instead of 1-hot encoded vectors. See FocalLoss class documentation defined
+  in this same file for more details.
+
+  Example usage:
+  >>> y_true = [1, 2]
+  >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
+  >>> gamma = 2
+  >>> focal_loss = SparseFocalLoss(gamma, 3)
+  >>> focal_loss(y_true, y_pred).numpy()
+  0.9326
+
+  >>> # Calling with 'sample_weight'.
+  >>> focal_loss(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()
+  0.6528
+  """
+
+  def __init__(
+      self, gamma, num_classes, class_weight: Optional[Sequence[float]] = None
+  ):
+    """Initializes SparseFocalLoss.
+
+    Args:
+      gamma: Focal loss gamma, as described in class docs.
+      num_classes: Number of classes.
+      class_weight: A weight to apply to the loss, one for each class. The
+        weight is applied for each input where the ground truth label matches.
+    """
+    super().__init__(gamma, class_weight=class_weight)
+    self._num_classes = num_classes
+
+  def __call__(
+      self,
+      y_true: tf.Tensor,
+      y_pred: tf.Tensor,
+      sample_weight: Optional[tf.Tensor] = None,
+  ) -> tf.Tensor:
+    y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
+    y_true_one_hot = tf.one_hot(y_true, self._num_classes)
+    return super().__call__(y_true_one_hot, y_pred, sample_weight=sample_weight)
+
+
 @dataclasses.dataclass
 class PerceptualLossWeight:
   """The weight for each perceptual loss.
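Since SparseFocalLoss only one-hot encodes the labels before deferring to FocalLoss, the two losses should agree numerically; the new unit test below checks exactly this, and a standalone version of the check would look like (a sketch, using the imports this commit adds elsewhere):

    import tensorflow as tf
    from mediapipe.model_maker.python.core.utils import loss_functions

    y_true = tf.constant([1, 0])
    y_pred = tf.constant([[0.8, 0.2], [0.3, 0.7]])
    sparse = loss_functions.SparseFocalLoss(gamma=2.0, num_classes=2)(y_true, y_pred)
    dense = loss_functions.FocalLoss(gamma=2.0)(tf.one_hot(y_true, 2), y_pred)
    tf.debugging.assert_near(sparse, dense)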
@@ -101,6 +101,23 @@ class FocalLossTest(tf.test.TestCase, parameterized.TestCase):
     self.assertNear(loss, expected_loss, 1e-4)
 
 
+class SparseFocalLossTest(tf.test.TestCase):
+
+  def test_sparse_focal_loss_matches_focal_loss(self):
+    num_classes = 2
+    y_pred = tf.constant([[0.8, 0.2], [0.3, 0.7]])
+    y_true = tf.constant([1, 0])
+    y_true_one_hot = tf.one_hot(y_true, num_classes)
+    for gamma in [0.0, 0.5, 1.0]:
+      expected_loss_fn = loss_functions.FocalLoss(gamma=gamma)
+      loss_fn = loss_functions.SparseFocalLoss(
+          gamma=gamma, num_classes=num_classes
+      )
+      expected_loss = expected_loss_fn(y_true_one_hot, y_pred)
+      loss = loss_fn(y_true, y_pred)
+      self.assertNear(loss, expected_loss, 1e-4)
+
+
 class MockPerceptualLoss(loss_functions.PerceptualLoss):
   """A mock class with implementation of abstract methods for testing."""
@@ -131,6 +131,7 @@ py_library(
         ":text_classifier_options",
         "//mediapipe/model_maker/python/core/data:dataset",
         "//mediapipe/model_maker/python/core/tasks:classifier",
+        "//mediapipe/model_maker/python/core/utils:loss_functions",
         "//mediapipe/model_maker/python/core/utils:metrics",
         "//mediapipe/model_maker/python/core/utils:model_util",
         "//mediapipe/model_maker/python/core/utils:quantization",
@@ -154,6 +155,7 @@ py_test(
     ],
     deps = [
         ":text_classifier_import",
+        "//mediapipe/model_maker/python/core/utils:loss_functions",
        "//mediapipe/tasks/python/test:test_utils",
     ],
 )
@@ -15,7 +15,7 @@
 
 import dataclasses
 import enum
-from typing import Union
+from typing import Sequence, Union
 
 from mediapipe.model_maker.python.core import hyperparameters as hp
 
@@ -39,16 +39,34 @@ class BertHParams(hp.BaseHParams):
 
   Attributes:
     learning_rate: Learning rate to use for gradient descent training.
-    batch_size: Batch size for training.
-    epochs: Number of training iterations over the dataset.
-    optimizer: Optimizer to use for training. Only supported values are "adamw"
-      and "lamb".
+    end_learning_rate: End learning rate for linear decay. Defaults to 0.
+    batch_size: Batch size for training. Defaults to 48.
+    epochs: Number of training iterations over the dataset. Defaults to 2.
+    optimizer: Optimizer to use for training. Supported values are defined in
+      BertOptimizer enum: ADAMW and LAMB.
+    weight_decay: Weight decay of the optimizer. Defaults to 0.01.
+    desired_precisions: If specified, adds a RecallAtPrecision metric per
+      desired_precisions[i] entry which tracks the recall given the constraint
+      on precision. Only supported for binary classification.
+    desired_recalls: If specified, adds a PrecisionAtRecall metric per
+      desired_recalls[i] entry which tracks the precision given the constraint
+      on recall. Only supported for binary classification.
+    gamma: Gamma parameter for focal loss. To use cross entropy loss, set this
+      value to 0. Defaults to 2.0.
   """
 
   learning_rate: float = 3e-5
+  end_learning_rate: float = 0.0
+
   batch_size: int = 48
   epochs: int = 2
   optimizer: BertOptimizer = BertOptimizer.ADAMW
+  weight_decay: float = 0.01
+
+  desired_precisions: Sequence[float] = dataclasses.field(default_factory=list)
+  desired_recalls: Sequence[float] = dataclasses.field(default_factory=list)
+
+  gamma: float = 2.0
 
 
 HParams = Union[BertHParams, AverageWordEmbeddingHParams]
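With these fields, the precision/recall trade-off metrics are configured once and reported during both training and evaluation, instead of being passed to evaluate. An illustrative configuration mirroring the new unit test later in this commit (the values are arbitrary):

    from mediapipe.model_maker.python.text import text_classifier

    hparams = text_classifier.BertHParams(
        learning_rate=3e-5,
        batch_size=48,
        epochs=2,
        gamma=2.0,                 # focal loss focusing parameter; 0 falls back to cross entropy
        desired_precisions=[0.9],  # adds recall_at_precision_0.9 (binary tasks only)
        desired_recalls=[0.2],     # adds precision_at_recall_0.2 (binary tasks only)
    )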
@@ -79,11 +79,6 @@ mobilebert_classifier_spec = functools.partial(
         epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
     ),
     name='MobileBert',
-    tflite_input_name={
-        'ids': 'serving_default_input_1:0',
-        'segment_ids': 'serving_default_input_2:0',
-        'mask': 'serving_default_input_3:0',
-    },
 )
 
 exbert_classifier_spec = functools.partial(
@@ -93,11 +88,6 @@ exbert_classifier_spec = functools.partial(
         epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
     ),
     name='ExBert',
-    tflite_input_name={
-        'ids': 'serving_default_input_1:0',
-        'segment_ids': 'serving_default_input_2:0',
-        'mask': 'serving_default_input_3:0',
-    },
 )
 
@@ -46,11 +46,13 @@ class ModelSpecTest(tf.test.TestCase):
     self.assertTrue(os.path.exists(model_spec_obj.downloaded_files.get_path()))
     self.assertTrue(model_spec_obj.do_lower_case)
     self.assertEqual(
-        model_spec_obj.tflite_input_name, {
-            'ids': 'serving_default_input_1:0',
-            'mask': 'serving_default_input_3:0',
-            'segment_ids': 'serving_default_input_2:0'
-        })
+        model_spec_obj.tflite_input_name,
+        {
+            'ids': 'serving_default_input_word_ids:0',
+            'mask': 'serving_default_input_mask:0',
+            'segment_ids': 'serving_default_input_type_ids:0',
+        },
+    )
     self.assertEqual(
         model_spec_obj.model_options,
         classifier_model_options.BertModelOptions(
@@ -16,8 +16,8 @@
           }
         },
         {
-          "name": "mask",
-          "description": "Mask with 1 for real tokens and 0 for padding tokens.",
+          "name": "segment_ids",
+          "description": "0 for the first sequence, 1 for the second sequence if exists.",
           "content": {
             "content_properties_type": "FeatureProperties",
             "content_properties": {
|
@ -27,8 +27,8 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "segment_ids",
|
"name": "mask",
|
||||||
"description": "0 for the first sequence, 1 for the second sequence if exists.",
|
"description": "Mask with 1 for real tokens and 0 for padding tokens.",
|
||||||
"content": {
|
"content": {
|
||||||
"content_properties_type": "FeatureProperties",
|
"content_properties_type": "FeatureProperties",
|
||||||
"content_properties": {
|
"content_properties": {
|
||||||
|
|
|
@@ -24,6 +24,7 @@ import tensorflow_hub as hub
 
 from mediapipe.model_maker.python.core.data import dataset as ds
 from mediapipe.model_maker.python.core.tasks import classifier
+from mediapipe.model_maker.python.core.utils import loss_functions
 from mediapipe.model_maker.python.core.utils import metrics
 from mediapipe.model_maker.python.core.utils import model_util
 from mediapipe.model_maker.python.core.utils import quantization
@@ -116,17 +117,14 @@ class TextClassifier(classifier.Classifier):
         options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER
         or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER
     ):
-      text_classifier = (
-          _BertClassifier.create_bert_classifier(train_data, validation_data,
-                                                 options,
-                                                 train_data.label_names))
+      text_classifier = _BertClassifier.create_bert_classifier(
+          train_data, validation_data, options
+      )
     elif (options.supported_model ==
           ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
-      text_classifier = (
-          _AverageWordEmbeddingClassifier
-          .create_average_word_embedding_classifier(train_data, validation_data,
-                                                    options,
-                                                    train_data.label_names))
+      text_classifier = _AverageWordEmbeddingClassifier.create_average_word_embedding_classifier(
+          train_data, validation_data, options
+      )
     else:
       raise ValueError(f"Unknown model {options.supported_model}")
@@ -166,27 +164,7 @@ class TextClassifier(classifier.Classifier):
     processed_data = self._text_preprocessor.preprocess(data)
     dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)
 
-    additional_metrics = []
-    if desired_precisions and len(data.label_names) == 2:
-      for precision in desired_precisions:
-        additional_metrics.append(
-            metrics.BinarySparseRecallAtPrecision(
-                precision, name=f"recall_at_precision_{precision}"
-            )
-        )
-    if desired_recalls and len(data.label_names) == 2:
-      for recall in desired_recalls:
-        additional_metrics.append(
-            metrics.BinarySparsePrecisionAtRecall(
-                recall, name=f"precision_at_recall_{recall}"
-            )
-        )
-    metric_functions = self._metric_functions + additional_metrics
-    self._model.compile(
-        optimizer=self._optimizer,
-        loss=self._loss_function,
-        metrics=metric_functions,
-    )
-    return self._model.evaluate(dataset)
+    with self._hparams.get_strategy().scope():
+      return self._model.evaluate(dataset)
 
   def export_model(
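The strategy scope around model.evaluate is what moves evaluation onto GPU/TPU when a distribution strategy is configured (item 1 of the commit message). A generic TensorFlow sketch of the pattern, with a hypothetical toy model rather than MediaPipe code:

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()  # uses available GPUs, falls back to CPU
    with strategy.scope():
      model = tf.keras.Sequential([tf.keras.layers.Dense(2, activation="softmax")])
      model.compile(loss="sparse_categorical_crossentropy", metrics=["accuracy"])
      model.evaluate(tf.random.normal([8, 4]), tf.zeros([8], dtype=tf.int32), verbose=0)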
@@ -255,16 +233,17 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
 
   @classmethod
   def create_average_word_embedding_classifier(
-      cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
+      cls,
+      train_data: text_ds.Dataset,
+      validation_data: text_ds.Dataset,
       options: text_classifier_options.TextClassifierOptions,
-      label_names: Sequence[str]) -> "_AverageWordEmbeddingClassifier":
+  ) -> "_AverageWordEmbeddingClassifier":
     """Creates, trains, and returns an Average Word Embedding classifier.
 
     Args:
       train_data: Training data.
       validation_data: Validation data.
       options: Options for creating and training the text classifier.
-      label_names: Label names used in the data.
 
     Returns:
       An Average Word Embedding classifier.
@@ -370,28 +349,25 @@ class _BertClassifier(TextClassifier):
     self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
     self._model_options = model_options
     with self._hparams.get_strategy().scope():
-      self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
-      self._metric_functions = [
-          tf.keras.metrics.SparseCategoricalAccuracy(
-              "test_accuracy", dtype=tf.float32
-          ),
-          metrics.SparsePrecision(name="precision", dtype=tf.float32),
-          metrics.SparseRecall(name="recall", dtype=tf.float32),
-      ]
+      self._loss_function = loss_functions.SparseFocalLoss(
+          self._hparams.gamma, self._num_classes
+      )
+      self._metric_functions = self._create_metrics()
     self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
 
   @classmethod
   def create_bert_classifier(
-      cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
+      cls,
+      train_data: text_ds.Dataset,
+      validation_data: text_ds.Dataset,
       options: text_classifier_options.TextClassifierOptions,
-      label_names: Sequence[str]) -> "_BertClassifier":
+  ) -> "_BertClassifier":
     """Creates, trains, and returns a BERT-based classifier.
 
     Args:
       train_data: Training data.
       validation_data: Validation data.
       options: Options for creating and training the text classifier.
-      label_names: Label names used in the data.
 
     Returns:
       A BERT-based classifier.
@@ -437,8 +413,57 @@ class _BertClassifier(TextClassifier):
         uri=self._model_spec.downloaded_files.get_path(),
         model_name=self._model_spec.name,
     )
-    return (self._text_preprocessor.preprocess(train_data),
-            self._text_preprocessor.preprocess(validation_data))
+    return (
+        self._text_preprocessor.preprocess(train_data),
+        self._text_preprocessor.preprocess(validation_data),
+    )
+
+  def _create_metrics(self):
+    """Creates metrics for training and evaluation.
+
+    The default metrics are accuracy, precision, and recall.
+
+    For binary classification tasks only (num_classes=2):
+      Users can configure PrecisionAtRecall and RecallAtPrecision metrics using
+      the desired_precisions and desired_recalls fields in BertHParams.
+
+    Returns:
+      A list of tf.keras.Metric subclasses which can be used with model.compile.
+    """
+    metric_functions = [
+        tf.keras.metrics.SparseCategoricalAccuracy(
+            "accuracy", dtype=tf.float32
+        ),
+        metrics.SparsePrecision(name="precision", dtype=tf.float32),
+        metrics.SparseRecall(name="recall", dtype=tf.float32),
+    ]
+    if self._num_classes == 2:
+      if self._hparams.desired_precisions:
+        for desired_precision in self._hparams.desired_precisions:
+          metric_functions.append(
+              metrics.BinarySparseRecallAtPrecision(
+                  desired_precision,
+                  name=f"recall_at_precision_{desired_precision}",
+                  num_thresholds=1000,
+              )
+          )
+      if self._hparams.desired_recalls:
+        for desired_recall in self._hparams.desired_recalls:
+          metric_functions.append(
+              metrics.BinarySparsePrecisionAtRecall(
+                  desired_recall,
+                  name=f"precision_at_recall_{desired_recall}",
+                  num_thresholds=1000,
+              )
+          )
+    else:
+      if self._hparams.desired_precisions or self._hparams.desired_recalls:
+        raise ValueError(
+            "desired_recalls and desired_precisions parameters are binary"
+            " metrics and not supported for num_classes > 2. Found"
+            f" num_classes: {self._num_classes}"
+        )
+    return metric_functions
 
   def _create_model(self):
     """Creates a BERT-based classifier model.
@@ -448,11 +473,20 @@ class _BertClassifier(TextClassifier):
     """
     encoder_inputs = dict(
         input_word_ids=tf.keras.layers.Input(
-            shape=(self._model_options.seq_len,), dtype=tf.int32),
+            shape=(self._model_options.seq_len,),
+            dtype=tf.int32,
+            name="input_word_ids",
+        ),
         input_mask=tf.keras.layers.Input(
-            shape=(self._model_options.seq_len,), dtype=tf.int32),
+            shape=(self._model_options.seq_len,),
+            dtype=tf.int32,
+            name="input_mask",
+        ),
         input_type_ids=tf.keras.layers.Input(
-            shape=(self._model_options.seq_len,), dtype=tf.int32),
+            shape=(self._model_options.seq_len,),
+            dtype=tf.int32,
+            name="input_type_ids",
+        ),
     )
     encoder = hub.KerasLayer(
         self._model_spec.downloaded_files.get_path(),
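Keras derives missing input names from a process-wide counter ("input_1", "input_2", ...), so the exported tensor names varied with how many layers had been created earlier in the process; pinning name= makes them deterministic (item 6 of the commit message). A small illustration (behavior sketched for TF2 Keras):

    import tensorflow as tf

    auto = tf.keras.layers.Input(shape=(4,))                      # auto-named, e.g. "input_1"
    named = tf.keras.layers.Input(shape=(4,), name="input_mask")  # always "input_mask"
    model = tf.keras.Model(inputs=[auto, named], outputs=[auto, named])
    print(model.input_names)  # e.g. ['input_1', 'input_mask']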
@@ -494,16 +528,21 @@ class _BertClassifier(TextClassifier):
     lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
         initial_learning_rate=initial_lr,
         decay_steps=total_steps,
-        end_learning_rate=0.0,
-        power=1.0)
+        end_learning_rate=self._hparams.end_learning_rate,
+        power=1.0,
+    )
     if warmup_steps:
       lr_schedule = model_util.WarmUp(
           initial_learning_rate=initial_lr,
           decay_schedule_fn=lr_schedule,
-          warmup_steps=warmup_steps)
+          warmup_steps=warmup_steps,
+      )
     if self._hparams.optimizer == hp.BertOptimizer.ADAMW:
       self._optimizer = tf.keras.optimizers.experimental.AdamW(
-          lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0
+          lr_schedule,
+          weight_decay=self._hparams.weight_decay,
+          epsilon=1e-6,
+          global_clipnorm=1.0,
       )
       self._optimizer.exclude_from_weight_decay(
           var_names=["LayerNorm", "layer_norm", "bias"]
@@ -511,7 +550,7 @@ class _BertClassifier(TextClassifier):
     elif self._hparams.optimizer == hp.BertOptimizer.LAMB:
       self._optimizer = tfa_optimizers.LAMB(
           lr_schedule,
-          weight_decay_rate=0.01,
+          weight_decay_rate=self._hparams.weight_decay,
           epsilon=1e-6,
           exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
           global_clipnorm=1.0,
@@ -16,17 +16,17 @@ import csv
 import filecmp
 import os
 import tempfile
-import unittest
 from unittest import mock as unittest_mock
 
+from absl.testing import parameterized
 import tensorflow as tf
 
+from mediapipe.model_maker.python.core.utils import loss_functions
 from mediapipe.model_maker.python.text import text_classifier
 from mediapipe.tasks.python.test import test_utils
 
 
-@unittest.skip('b/275624089')
-class TextClassifierTest(tf.test.TestCase):
+class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
 
   _AVERAGE_WORD_EMBEDDING_JSON_FILE = (
       test_utils.get_test_data_path('average_word_embedding_metadata.json'))
@@ -78,8 +78,8 @@ class TextClassifierTest(tf.test.TestCase):
         text_classifier.TextClassifier.create(train_data, validation_data,
                                               options))
 
-    _, accuracy = average_word_embedding_classifier.evaluate(validation_data)
-    self.assertGreaterEqual(accuracy, 0.0)
+    metrics = average_word_embedding_classifier.evaluate(validation_data)
+    self.assertGreaterEqual(metrics[1], 0.0)  # metrics[1] is accuracy
 
     # Test export_model
     average_word_embedding_classifier.export_model()
@@ -98,12 +98,25 @@ class TextClassifierTest(tf.test.TestCase):
         filecmp.cmp(
             output_metadata_file,
             self._AVERAGE_WORD_EMBEDDING_JSON_FILE,
-            shallow=False))
+            shallow=False,
+        )
+    )
 
-  def test_create_and_train_bert(self):
+  @parameterized.named_parameters(
+      # Skipping mobilebert b/c OSS test timeout/flakiness: b/275624089
+      # dict(
+      #     testcase_name='mobilebert',
+      #     supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
+      # ),
+      dict(
+          testcase_name='exbert',
+          supported_model=text_classifier.SupportedModels.EXBERT_CLASSIFIER,
+      ),
+  )
+  def test_create_and_train_bert(self, supported_model):
     train_data, validation_data = self._get_data()
     options = text_classifier.TextClassifierOptions(
-        supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
+        supported_model=supported_model,
         model_options=text_classifier.BertModelOptions(
             do_fine_tuning=False, seq_len=2
         ),
@@ -117,8 +130,8 @@ class TextClassifierTest(tf.test.TestCase):
     bert_classifier = text_classifier.TextClassifier.create(
         train_data, validation_data, options)
 
-    _, accuracy = bert_classifier.evaluate(validation_data)
-    self.assertGreaterEqual(accuracy, 0.0)
+    metrics = bert_classifier.evaluate(validation_data)
+    self.assertGreaterEqual(metrics[1], 0.0)  # metrics[1] is accuracy
 
     # Test export_model
     bert_classifier.export_model()
@@ -142,45 +155,93 @@ class TextClassifierTest(tf.test.TestCase):
     )
 
   def test_label_mismatch(self):
-    options = (
-        text_classifier.TextClassifierOptions(
-            supported_model=(
-                text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER)))
+    options = text_classifier.TextClassifierOptions(
+        supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER)
+    )
     train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
-    train_data = text_classifier.Dataset(train_tf_dataset, 1, ['foo'])
+    train_data = text_classifier.Dataset(train_tf_dataset, ['foo'], 1)
     validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
-    validation_data = text_classifier.Dataset(validation_tf_dataset, 1, ['bar'])
+    validation_data = text_classifier.Dataset(validation_tf_dataset, ['bar'], 1)
     with self.assertRaisesRegex(
         ValueError,
-        'Training data label names .* not equal to validation data label names'
+        'Training data label names .* not equal to validation data label names',
     ):
-      text_classifier.TextClassifier.create(train_data, validation_data,
-                                            options)
+      text_classifier.TextClassifier.create(
+          train_data, validation_data, options
+      )
 
   def test_options_mismatch(self):
     train_data, validation_data = self._get_data()
 
-    avg_options = (
-        text_classifier.TextClassifierOptions(
-            supported_model=(
-                text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER),
-            model_options=text_classifier.AverageWordEmbeddingModelOptions()))
-    with self.assertRaisesRegex(
-        ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
-        ' SupportedModels.MOBILEBERT_CLASSIFIER'):
-      text_classifier.TextClassifier.create(train_data, validation_data,
-                                            avg_options)
+    avg_options = text_classifier.TextClassifierOptions(
+        supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER),
+        model_options=text_classifier.AverageWordEmbeddingModelOptions(),
+    )
+    with self.assertRaisesWithLiteralMatch(
+        ValueError,
+        'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
+        ' SupportedModels.EXBERT_CLASSIFIER',
+    ):
+      text_classifier.TextClassifier.create(
+          train_data, validation_data, avg_options
+      )
 
-    bert_options = (
-        text_classifier.TextClassifierOptions(
-            supported_model=(text_classifier.SupportedModels
-                             .AVERAGE_WORD_EMBEDDING_CLASSIFIER),
-            model_options=text_classifier.BertModelOptions()))
-    with self.assertRaisesRegex(
-        ValueError, 'Expected MOBILEBERT_CLASSIFIER, got'
-        ' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):
-      text_classifier.TextClassifier.create(train_data, validation_data,
-                                            bert_options)
+    bert_options = text_classifier.TextClassifierOptions(
+        supported_model=(
+            text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
+        ),
+        model_options=text_classifier.BertModelOptions(),
+    )
+    with self.assertRaisesWithLiteralMatch(
+        ValueError,
+        'Expected a Bert Classifier(MobileBERT or EXBERT), got'
+        ' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER',
+    ):
+      text_classifier.TextClassifier.create(
+          train_data, validation_data, bert_options
+      )
+
+  def test_bert_loss_and_metrics_creation(self):
+    train_data, validation_data = self._get_data()
+    supported_model = text_classifier.SupportedModels.EXBERT_CLASSIFIER
+    hparams = text_classifier.BertHParams(
+        desired_recalls=[0.2],
+        desired_precisions=[0.9],
+        epochs=1,
+        batch_size=1,
+        learning_rate=3e-5,
+        distribution_strategy='off',
+        gamma=3.5,
+    )
+    options = text_classifier.TextClassifierOptions(
+        supported_model=supported_model, hparams=hparams
+    )
+    bert_classifier = text_classifier.TextClassifier.create(
+        train_data, validation_data, options
+    )
+    loss_fn = bert_classifier._loss_function
+    self.assertIsInstance(loss_fn, loss_functions.SparseFocalLoss)
+    self.assertEqual(loss_fn._gamma, 3.5)
+    self.assertEqual(loss_fn._num_classes, 2)
+    metric_names = [m.name for m in bert_classifier._metric_functions]
+    expected_metric_names = [
+        'accuracy',
+        'recall',
+        'precision',
+        'precision_at_recall_0.2',
+        'recall_at_precision_0.9',
+    ]
+    self.assertCountEqual(metric_names, expected_metric_names)
+
+    # Non-binary data
+    tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
+    data = text_classifier.Dataset(tf_dataset, ['foo', 'bar', 'baz'], 1)
+    with self.assertRaisesWithLiteralMatch(
+        ValueError,
+        'desired_recalls and desired_precisions parameters are binary metrics'
+        ' and not supported for num_classes > 2. Found num_classes: 3',
+    ):
+      text_classifier.TextClassifier.create(data, data, options)
 
 
 if __name__ == '__main__':