Add BinaryAUC metric and Best Checkpoint callback to Text Classifier

PiperOrigin-RevId: 581276382
This commit is contained in:
MediaPipe Team 2023-11-10 09:01:33 -08:00 committed by Copybara-Service
parent fd4859c178
commit d772bf8134
4 changed files with 40 additions and 7 deletions

View File

@ -94,6 +94,23 @@ def _get_sparse_metric(metric: tf.metrics.Metric):
return SparseMetric return SparseMetric
class BinaryAUC(tf.keras.metrics.AUC):
  """AUC metric for binary classifiers that emit two-column probabilities.

  The parent `tf.keras.metrics.AUC` expects a single score per example. This
  subclass adapts it to models whose output has one column per class by
  forwarding only the positive-class column.

  For `update_state`, the shapes of `y_true` and `y_pred` are expected to be:
    - y_true: [batch_size x 1] array of 0 for negatives and 1 for positives.
    - y_pred: [batch_size x 2] array of probabilities where y_pred[:, 0] are
      the probabilities of the 0th (negative) class and y_pred[:, 1] are the
      probabilities of the 1st (positive) class.

  See https://www.tensorflow.org/api_docs/python/tf/keras/metrics/AUC for
  details on the constructor arguments and on how the AUC is approximated.
  """

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulates confusion-matrix statistics from positive-class scores.

    Args:
      y_true: Ground-truth labels (0 for negative, 1 for positive).
      y_pred: Per-class probabilities; column 1 is used as the score.
      sample_weight: Optional per-example weights, forwarded unchanged.

    Returns:
      Whatever the parent `AUC.update_state` returns.
    """
    # Convert first so that plain nested lists (which do not support the
    # `[:, 1]` multi-axis index) are accepted alongside arrays and tensors.
    positive_scores = tf.convert_to_tensor(y_pred)[:, 1]
    # Propagate the parent's return value instead of discarding it.
    return super().update_state(
        y_true, positive_scores, sample_weight=sample_weight
    )
SparseRecall = _get_sparse_metric(tf.metrics.Recall) SparseRecall = _get_sparse_metric(tf.metrics.Recall)
SparsePrecision = _get_sparse_metric(tf.metrics.Precision) SparsePrecision = _get_sparse_metric(tf.metrics.Precision)
BinarySparseRecallAtPrecision = _get_binary_sparse_metric( BinarySparseRecallAtPrecision = _get_binary_sparse_metric(

View File

@ -14,6 +14,7 @@
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.core.utils import metrics from mediapipe.model_maker.python.core.utils import metrics
@ -23,16 +24,15 @@ class SparseMetricTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.y_true = [0, 0, 1, 1, 0, 1] self.y_true = np.array([0, 0, 1, 1, 0, 1])
self.y_pred = [ self.y_pred = np.array([
[0.9, 0.1], # 0, 0 y [0.9, 0.1], # 0, 0 y
[0.8, 0.2], # 0, 0 y [0.8, 0.2], # 0, 0 y
[0.7, 0.3], # 0, 1 n [0.7, 0.3], # 0, 1 n
[0.6, 0.4], # 0, 1 n [0.6, 0.4], # 0, 1 n
[0.3, 0.7], # 1, 0 y [0.3, 0.7], # 1, 0 y
[0.3, 0.7], # 1, 1 y [0.3, 0.7], # 1, 1 y
] ])
self.num_classes = 3
def _assert_metric_equals(self, metric, value): def _assert_metric_equals(self, metric, value):
metric.update_state(self.y_true, self.y_pred) metric.update_state(self.y_true, self.y_pred)
@ -69,6 +69,10 @@ class SparseMetricTest(tf.test.TestCase, parameterized.TestCase):
): ):
_ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=2) _ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=2)
def test_binary_auc(self):
  """BinaryAUC on the shared fixture should yield the expected AUC value."""
  # 0.7222222 is 6.5 / 9: the fraction of (positive, negative) pairs in the
  # setUp fixture that are ranked correctly, with the one tie counted as half.
  self._assert_metric_equals(
      metrics.BinaryAUC(num_thresholds=1000),
      0.7222222,
  )
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()

View File

@ -372,9 +372,19 @@ class _BertClassifier(TextClassifier):
): ):
super().__init__(model_spec, label_names, hparams.shuffle) super().__init__(model_spec, label_names, hparams.shuffle)
self._hparams = hparams self._hparams = hparams
self._callbacks = model_util.get_default_callbacks( self._callbacks = list(
model_util.get_default_callbacks(
self._hparams.export_dir, self._hparams.checkpoint_frequency self._hparams.export_dir, self._hparams.checkpoint_frequency
) )
) + [
tf.keras.callbacks.ModelCheckpoint(
os.path.join(self._hparams.export_dir, "best_model"),
monitor="val_auc",
mode="max",
save_best_only=True,
save_weights_only=False,
)
]
self._model_options = model_options self._model_options = model_options
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
with self._hparams.get_strategy().scope(): with self._hparams.get_strategy().scope():
@ -465,6 +475,7 @@ class _BertClassifier(TextClassifier):
), ),
metrics.SparsePrecision(name="precision", dtype=tf.float32), metrics.SparsePrecision(name="precision", dtype=tf.float32),
metrics.SparseRecall(name="recall", dtype=tf.float32), metrics.SparseRecall(name="recall", dtype=tf.float32),
metrics.BinaryAUC(name="auc", num_thresholds=1000),
] ]
if self._num_classes == 2: if self._num_classes == 2:
if self._hparams.desired_precisions: if self._hparams.desired_precisions:

View File

@ -230,6 +230,7 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
'accuracy', 'accuracy',
'recall', 'recall',
'precision', 'precision',
'auc',
'precision_at_recall_0.2', 'precision_at_recall_0.2',
'recall_at_precision_0.9', 'recall_at_precision_0.9',
] ]