Add BinaryAUC metric and Best Checkpoint callback to Text Classifier
PiperOrigin-RevId: 581276382
This commit is contained in:
parent
fd4859c178
commit
d772bf8134
|
@ -94,6 +94,23 @@ def _get_sparse_metric(metric: tf.metrics.Metric):
|
||||||
return SparseMetric
|
return SparseMetric
|
||||||
|
|
||||||
|
|
||||||
|
class BinaryAUC(tf.keras.metrics.AUC):
  """A Binary AUC metric for binary classification tasks.

  For update state, the shapes of y_true and y_pred are expected to be:
  - y_true: [batch_size x 1] array of 0 for negatives and 1 for positives
  - y_pred: [batch_size x 2] array of probabilities where y_pred[:,0] are the
    probabilities of the 0th(negative) class and y_pred[:,1] are the
    probabilities of the 1st(positive) class

  See https://www.tensorflow.org/api_docs/python/tf/keras/metrics/AUC for
  details.
  """

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulates AUC statistics from the positive-class column of y_pred.

    Args:
      y_true: Ground-truth labels, 0 for negatives and 1 for positives.
      y_pred: Two-column array of per-class probabilities; only column 1 (the
        positive class) is forwarded to the parent metric.
      sample_weight: Optional per-example weights, forwarded unchanged.

    Returns:
      Whatever the parent `update_state` returns, so that any update op it
      produces is not dropped by this override.
    """
    # Slice out the positive-class probability; the parent AUC metric expects a
    # single score per example rather than a per-class distribution.
    return super().update_state(y_true, y_pred[:, 1], sample_weight)
|
||||||
|
|
||||||
|
|
||||||
SparseRecall = _get_sparse_metric(tf.metrics.Recall)
|
SparseRecall = _get_sparse_metric(tf.metrics.Recall)
|
||||||
SparsePrecision = _get_sparse_metric(tf.metrics.Precision)
|
SparsePrecision = _get_sparse_metric(tf.metrics.Precision)
|
||||||
BinarySparseRecallAtPrecision = _get_binary_sparse_metric(
|
BinarySparseRecallAtPrecision = _get_binary_sparse_metric(
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
|
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core.utils import metrics
|
from mediapipe.model_maker.python.core.utils import metrics
|
||||||
|
@ -23,16 +24,15 @@ class SparseMetricTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.y_true = [0, 0, 1, 1, 0, 1]
|
self.y_true = np.array([0, 0, 1, 1, 0, 1])
|
||||||
self.y_pred = [
|
self.y_pred = np.array([
|
||||||
[0.9, 0.1], # 0, 0 y
|
[0.9, 0.1], # 0, 0 y
|
||||||
[0.8, 0.2], # 0, 0 y
|
[0.8, 0.2], # 0, 0 y
|
||||||
[0.7, 0.3], # 0, 1 n
|
[0.7, 0.3], # 0, 1 n
|
||||||
[0.6, 0.4], # 0, 1 n
|
[0.6, 0.4], # 0, 1 n
|
||||||
[0.3, 0.7], # 1, 0 y
|
[0.3, 0.7], # 1, 0 y
|
||||||
[0.3, 0.7], # 1, 1 y
|
[0.3, 0.7], # 1, 1 y
|
||||||
]
|
])
|
||||||
self.num_classes = 3
|
|
||||||
|
|
||||||
def _assert_metric_equals(self, metric, value):
|
def _assert_metric_equals(self, metric, value):
|
||||||
metric.update_state(self.y_true, self.y_pred)
|
metric.update_state(self.y_true, self.y_pred)
|
||||||
|
@ -69,6 +69,10 @@ class SparseMetricTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
):
|
):
|
||||||
_ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=2)
|
_ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=2)
|
||||||
|
|
||||||
|
def test_binary_auc(self):
  # For the setUp fixture the positive-class scores are
  # [0.1, 0.2, 0.3, 0.4, 0.7, 0.7] with labels [0, 0, 1, 1, 0, 1]: of the 9
  # positive/negative pairs, 6 are ranked correctly and 1 is tied (counted as
  # half), giving 6.5 / 9 ~= 0.7222222.
  auc_metric = metrics.BinaryAUC(num_thresholds=1000)
  expected_auc = 0.7222222
  self._assert_metric_equals(auc_metric, expected_auc)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
|
|
@ -372,9 +372,19 @@ class _BertClassifier(TextClassifier):
|
||||||
):
|
):
|
||||||
super().__init__(model_spec, label_names, hparams.shuffle)
|
super().__init__(model_spec, label_names, hparams.shuffle)
|
||||||
self._hparams = hparams
|
self._hparams = hparams
|
||||||
self._callbacks = model_util.get_default_callbacks(
|
self._callbacks = list(
|
||||||
self._hparams.export_dir, self._hparams.checkpoint_frequency
|
model_util.get_default_callbacks(
|
||||||
)
|
self._hparams.export_dir, self._hparams.checkpoint_frequency
|
||||||
|
)
|
||||||
|
) + [
|
||||||
|
tf.keras.callbacks.ModelCheckpoint(
|
||||||
|
os.path.join(self._hparams.export_dir, "best_model"),
|
||||||
|
monitor="val_auc",
|
||||||
|
mode="max",
|
||||||
|
save_best_only=True,
|
||||||
|
save_weights_only=False,
|
||||||
|
)
|
||||||
|
]
|
||||||
self._model_options = model_options
|
self._model_options = model_options
|
||||||
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
|
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
|
||||||
with self._hparams.get_strategy().scope():
|
with self._hparams.get_strategy().scope():
|
||||||
|
@ -465,6 +475,7 @@ class _BertClassifier(TextClassifier):
|
||||||
),
|
),
|
||||||
metrics.SparsePrecision(name="precision", dtype=tf.float32),
|
metrics.SparsePrecision(name="precision", dtype=tf.float32),
|
||||||
metrics.SparseRecall(name="recall", dtype=tf.float32),
|
metrics.SparseRecall(name="recall", dtype=tf.float32),
|
||||||
|
metrics.BinaryAUC(name="auc", num_thresholds=1000),
|
||||||
]
|
]
|
||||||
if self._num_classes == 2:
|
if self._num_classes == 2:
|
||||||
if self._hparams.desired_precisions:
|
if self._hparams.desired_precisions:
|
||||||
|
|
|
@ -230,6 +230,7 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
'accuracy',
|
'accuracy',
|
||||||
'recall',
|
'recall',
|
||||||
'precision',
|
'precision',
|
||||||
|
'auc',
|
||||||
'precision_at_recall_0.2',
|
'precision_at_recall_0.2',
|
||||||
'recall_at_precision_0.9',
|
'recall_at_precision_0.9',
|
||||||
]
|
]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user