From d772bf8134d54d7106632205fd13b15ddd9532d2 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Fri, 10 Nov 2023 09:01:33 -0800
Subject: [PATCH] Add BinaryAUC metric and Best Checkpoint callback to Text
 Classifier

PiperOrigin-RevId: 581276382
---
 .../model_maker/python/core/utils/metrics.py    | 17 +++++++++++++++++
 .../python/core/utils/metrics_test.py           | 12 ++++++++----
 .../text/text_classifier/text_classifier.py     | 17 ++++++++++++++---
 .../text_classifier/text_classifier_test.py     |  1 +
 4 files changed, 40 insertions(+), 7 deletions(-)

diff --git a/mediapipe/model_maker/python/core/utils/metrics.py b/mediapipe/model_maker/python/core/utils/metrics.py
index 310146168..cf0be6d08 100644
--- a/mediapipe/model_maker/python/core/utils/metrics.py
+++ b/mediapipe/model_maker/python/core/utils/metrics.py
@@ -94,6 +94,23 @@ def _get_sparse_metric(metric: tf.metrics.Metric):
   return SparseMetric
 
 
+class BinaryAUC(tf.keras.metrics.AUC):
+  """A Binary AUC metric for binary classification tasks.
+
+  For update state, the shapes of y_true and y_pred are expected to be:
+    - y_true: [batch_size x 1] array of 0 for negatives and 1 for positives
+    - y_pred: [batch_size x 2] array of probabilities where y_pred[:,0] are the
+      probabilities of the 0th(negative) class and y_pred[:,1] are the
+      probabilities of the 1st(positive) class
+
+  See https://www.tensorflow.org/api_docs/python/tf/keras/metrics/AUC for
+  details.
+  """
+
+  def update_state(self, y_true, y_pred, sample_weight=None):
+    super().update_state(y_true, y_pred[:, 1], sample_weight)
+
+
 SparseRecall = _get_sparse_metric(tf.metrics.Recall)
 SparsePrecision = _get_sparse_metric(tf.metrics.Precision)
 BinarySparseRecallAtPrecision = _get_binary_sparse_metric(
diff --git a/mediapipe/model_maker/python/core/utils/metrics_test.py b/mediapipe/model_maker/python/core/utils/metrics_test.py
index 842335273..2ea8769d2 100644
--- a/mediapipe/model_maker/python/core/utils/metrics_test.py
+++ b/mediapipe/model_maker/python/core/utils/metrics_test.py
@@ -14,6 +14,7 @@
 
 
 from absl.testing import parameterized
+import numpy as np
 import tensorflow as tf
 
 from mediapipe.model_maker.python.core.utils import metrics
@@ -23,16 +24,15 @@ class SparseMetricTest(tf.test.TestCase, parameterized.TestCase):
 
   def setUp(self):
     super().setUp()
-    self.y_true = [0, 0, 1, 1, 0, 1]
-    self.y_pred = [
+    self.y_true = np.array([0, 0, 1, 1, 0, 1])
+    self.y_pred = np.array([
         [0.9, 0.1],  # 0, 0 y
         [0.8, 0.2],  # 0, 0 y
         [0.7, 0.3],  # 0, 1 n
         [0.6, 0.4],  # 0, 1 n
         [0.3, 0.7],  # 1, 0 y
         [0.3, 0.7],  # 1, 1 y
-    ]
-    self.num_classes = 3
+    ])
 
   def _assert_metric_equals(self, metric, value):
     metric.update_state(self.y_true, self.y_pred)
@@ -69,6 +69,10 @@ class SparseMetricTest(tf.test.TestCase, parameterized.TestCase):
     ):
       _ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=2)
 
+  def test_binary_auc(self):
+    metric = metrics.BinaryAUC(num_thresholds=1000)
+    self._assert_metric_equals(metric, 0.7222222)
+
 
 if __name__ == '__main__':
   tf.test.main()
diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py
index 386e9360e..aea9224ff 100644
--- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py
+++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py
@@ -372,9 +372,19 @@ class _BertClassifier(TextClassifier):
   ):
     super().__init__(model_spec, label_names, hparams.shuffle)
     self._hparams = hparams
-    self._callbacks = model_util.get_default_callbacks(
-        self._hparams.export_dir, self._hparams.checkpoint_frequency
-    )
+    self._callbacks = list(
+        model_util.get_default_callbacks(
+            self._hparams.export_dir, self._hparams.checkpoint_frequency
+        )
+    ) + [
+        tf.keras.callbacks.ModelCheckpoint(
+            os.path.join(self._hparams.export_dir, "best_model"),
+            monitor="val_auc",
+            mode="max",
+            save_best_only=True,
+            save_weights_only=False,
+        )
+    ]
     self._model_options = model_options
     self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
     with self._hparams.get_strategy().scope():
@@ -465,6 +475,7 @@ class _BertClassifier(TextClassifier):
           ),
           metrics.SparsePrecision(name="precision", dtype=tf.float32),
           metrics.SparseRecall(name="recall", dtype=tf.float32),
+          metrics.BinaryAUC(name="auc", num_thresholds=1000),
       ]
       if self._num_classes == 2:
         if self._hparams.desired_precisions:
diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py
index fdc2613a9..02b3fe4d5 100644
--- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py
+++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py
@@ -230,6 +230,7 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
             'accuracy',
             'recall',
             'precision',
+            'auc',
             'precision_at_recall_0.2',
             'recall_at_precision_0.9',
         ]