Add BinaryAUC metric and Best Checkpoint callback to Text Classifier
PiperOrigin-RevId: 581276382
This commit is contained in:
parent
fd4859c178
commit
d772bf8134
|
@ -94,6 +94,23 @@ def _get_sparse_metric(metric: tf.metrics.Metric):
|
|||
return SparseMetric
|
||||
|
||||
|
||||
class BinaryAUC(tf.keras.metrics.AUC):
|
||||
"""A Binary AUC metric for binary classification tasks.
|
||||
|
||||
For update state, the shapes of y_true and y_pred are expected to be:
|
||||
- y_true: [batch_size x 1] array of 0 for negatives and 1 for positives
|
||||
- y_pred: [batch_size x 2] array of probabilities where y_pred[:,0] are the
|
||||
probabilities of the 0th(negative) class and y_pred[:,1] are the
|
||||
probabilities of the 1st(positive) class
|
||||
|
||||
See https://www.tensorflow.org/api_docs/python/tf/keras/metrics/AUC for
|
||||
details.
|
||||
"""
|
||||
|
||||
def update_state(self, y_true, y_pred, sample_weight=None):
|
||||
super().update_state(y_true, y_pred[:, 1], sample_weight)
|
||||
|
||||
|
||||
SparseRecall = _get_sparse_metric(tf.metrics.Recall)
|
||||
SparsePrecision = _get_sparse_metric(tf.metrics.Precision)
|
||||
BinarySparseRecallAtPrecision = _get_binary_sparse_metric(
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import metrics
|
||||
|
@ -23,16 +24,15 @@ class SparseMetricTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.y_true = [0, 0, 1, 1, 0, 1]
|
||||
self.y_pred = [
|
||||
self.y_true = np.array([0, 0, 1, 1, 0, 1])
|
||||
self.y_pred = np.array([
|
||||
[0.9, 0.1], # 0, 0 y
|
||||
[0.8, 0.2], # 0, 0 y
|
||||
[0.7, 0.3], # 0, 1 n
|
||||
[0.6, 0.4], # 0, 1 n
|
||||
[0.3, 0.7], # 1, 0 y
|
||||
[0.3, 0.7], # 1, 1 y
|
||||
]
|
||||
self.num_classes = 3
|
||||
])
|
||||
|
||||
def _assert_metric_equals(self, metric, value):
|
||||
metric.update_state(self.y_true, self.y_pred)
|
||||
|
@ -69,6 +69,10 @@ class SparseMetricTest(tf.test.TestCase, parameterized.TestCase):
|
|||
):
|
||||
_ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=2)
|
||||
|
||||
def test_binary_auc(self):
|
||||
metric = metrics.BinaryAUC(num_thresholds=1000)
|
||||
self._assert_metric_equals(metric, 0.7222222)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
|
|
@ -372,9 +372,19 @@ class _BertClassifier(TextClassifier):
|
|||
):
|
||||
super().__init__(model_spec, label_names, hparams.shuffle)
|
||||
self._hparams = hparams
|
||||
self._callbacks = model_util.get_default_callbacks(
|
||||
self._callbacks = list(
|
||||
model_util.get_default_callbacks(
|
||||
self._hparams.export_dir, self._hparams.checkpoint_frequency
|
||||
)
|
||||
) + [
|
||||
tf.keras.callbacks.ModelCheckpoint(
|
||||
os.path.join(self._hparams.export_dir, "best_model"),
|
||||
monitor="val_auc",
|
||||
mode="max",
|
||||
save_best_only=True,
|
||||
save_weights_only=False,
|
||||
)
|
||||
]
|
||||
self._model_options = model_options
|
||||
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
|
||||
with self._hparams.get_strategy().scope():
|
||||
|
@ -465,6 +475,7 @@ class _BertClassifier(TextClassifier):
|
|||
),
|
||||
metrics.SparsePrecision(name="precision", dtype=tf.float32),
|
||||
metrics.SparseRecall(name="recall", dtype=tf.float32),
|
||||
metrics.BinaryAUC(name="auc", num_thresholds=1000),
|
||||
]
|
||||
if self._num_classes == 2:
|
||||
if self._hparams.desired_precisions:
|
||||
|
|
|
@ -230,6 +230,7 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
|||
'accuracy',
|
||||
'recall',
|
||||
'precision',
|
||||
'auc',
|
||||
'precision_at_recall_0.2',
|
||||
'recall_at_precision_0.9',
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue
Block a user