From 895c685df6ee4eeb9fce5ccdb32c1dbeab2334a6 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 21 Jun 2023 15:15:30 -0700 Subject: [PATCH] 1. Model maker core classifier change _metric_function field to _metric_functions in order to support having multiple metrics. 2. Add SparsePrecision, SparseRecall, BinarySparsePrecisionAtRecall, and BinarySparseRecallAtPrecision to the shared metrics library. 3. Add SparsePrecision, SparseRecall to text classifier, and have the option to evaluate the model with BinarySparsePrecisionAtRecall and BinarySparseRecallAtPrecision PiperOrigin-RevId: 542376451 --- .../python/core/tasks/classifier.py | 5 +- mediapipe/model_maker/python/core/utils/BUILD | 11 ++ .../model_maker/python/core/utils/metrics.py | 104 ++++++++++++++++++ .../python/core/utils/metrics_test.py | 74 +++++++++++++ .../python/text/text_classifier/BUILD | 1 + .../text/text_classifier/text_classifier.py | 53 ++++++++- .../gesture_recognizer/gesture_recognizer.py | 2 +- .../image_classifier/image_classifier.py | 2 +- 8 files changed, 243 insertions(+), 9 deletions(-) create mode 100644 mediapipe/model_maker/python/core/utils/metrics.py create mode 100644 mediapipe/model_maker/python/core/utils/metrics_test.py diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py index 60c00f0de..a042c0ec7 100644 --- a/mediapipe/model_maker/python/core/tasks/classifier.py +++ b/mediapipe/model_maker/python/core/tasks/classifier.py @@ -43,7 +43,7 @@ class Classifier(custom_model.CustomModel): self._model: tf.keras.Model = None self._optimizer: Union[str, tf.keras.optimizers.Optimizer] = None self._loss_function: Union[str, tf.keras.losses.Loss] = None - self._metric_function: Union[str, tf.keras.metrics.Metric] = None + self._metric_functions: Sequence[Union[str, tf.keras.metrics.Metric]] = None self._callbacks: Sequence[tf.keras.callbacks.Callback] = None self._hparams: hp.BaseHParams = None self._history: tf.keras.callbacks.History = None @@ -92,7 +92,8 @@ class Classifier(custom_model.CustomModel): self._model.compile( optimizer=self._optimizer, loss=self._loss_function, - metrics=[self._metric_function]) + metrics=self._metric_functions, + ) latest_checkpoint = ( tf.train.latest_checkpoint(checkpoint_path) diff --git a/mediapipe/model_maker/python/core/utils/BUILD b/mediapipe/model_maker/python/core/utils/BUILD index ef9cab290..81bd68d3e 100644 --- a/mediapipe/model_maker/python/core/utils/BUILD +++ b/mediapipe/model_maker/python/core/utils/BUILD @@ -80,6 +80,17 @@ py_test( deps = [":loss_functions"], ) +py_library( + name = "metrics", + srcs = ["metrics.py"], +) + +py_test( + name = "metrics_test", + srcs = ["metrics_test.py"], + deps = [":metrics"], +) + py_library( name = "quantization", srcs = ["quantization.py"], diff --git a/mediapipe/model_maker/python/core/utils/metrics.py b/mediapipe/model_maker/python/core/utils/metrics.py new file mode 100644 index 000000000..310146168 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/metrics.py @@ -0,0 +1,104 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Metrics utility library.""" + +import tensorflow as tf + + +def _get_binary_sparse_metric(metric: tf.metrics.Metric): + """Helper method to create a BinarySparse version of a tf.keras.Metric. + + BinarySparse is an implementation where the update_state(y_true, y_pred) takes + in shapes y_true=(batch_size, 1) y_pred=(batch_size, 2). Note that this only + supports the binary classification case, and that class_id=0 is the negative + class and class_id=1 is the positive class. + + Currently supported tf.metric.Metric classes + 1. BinarySparseRecallAtPrecision + 2. BinarySparsePrecisionAtRecall + + Args: + metric: A tf.metric.Metric class for which we want to generate a + BinarySparse version of this metric. + + Returns: + A class for the BinarySparse version of the specified tf.metrics.Metric + """ + + class BinarySparseMetric(metric): + """A BinarySparse wrapper class for a tf.keras.Metric. + + This class has the same parameters and functions as the underlying + metric class. For example, the parameters for BinarySparseRecallAtPrecision + is the same as tf.keras.metrics.RecallAtPrecision. The only new constraint + is that class_id must be set to 1 (or not specified) for the Binary metric. + """ + + def __init__(self, *args, **kwargs): + if 'class_id' in kwargs and kwargs['class_id'] != 1: + raise ValueError( + f'Custom BinarySparseMetric for class:{metric.__name__} is ' + 'only supported for class_id=1, got class_id=' + f'{kwargs["class_id"]} instead' + ) + else: + kwargs['class_id'] = 1 + super().__init__(*args, **kwargs) + + def update_state(self, y_true, y_pred, sample_weight=None): + y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32) + y_true_one_hot = tf.one_hot(y_true, 2) + return super().update_state( + y_true_one_hot, y_pred, sample_weight=sample_weight + ) + + return BinarySparseMetric + + +def _get_sparse_metric(metric: tf.metrics.Metric): + """Helper method to create a Sparse version of a tf.keras.Metric. + + Sparse is an implementation where the update_state(y_true, y_pred) takes in + shapes y_true=(batch_size, 1) and y_pred=(batch_size, num_classes). + + Currently supported tf.metrics.Metric classes: + 1. tf.metrics.Recall + 2. tf.metrics.Precision + + Args: + metric: A tf.metric.Metric class for which we want to generate a Sparse + version of this metric. + + Returns: + A class for the Sparse version of the specified tf.keras.Metric. + """ + + class SparseMetric(metric): + """A Sparse wrapper class for a tf.keras.Metric.""" + + def update_state(self, y_true, y_pred, sample_weight=None): + y_pred = tf.math.argmax(y_pred, axis=-1) + return super().update_state(y_true, y_pred, sample_weight=sample_weight) + + return SparseMetric + + +SparseRecall = _get_sparse_metric(tf.metrics.Recall) +SparsePrecision = _get_sparse_metric(tf.metrics.Precision) +BinarySparseRecallAtPrecision = _get_binary_sparse_metric( + tf.metrics.RecallAtPrecision +) +BinarySparsePrecisionAtRecall = _get_binary_sparse_metric( + tf.metrics.PrecisionAtRecall +) diff --git a/mediapipe/model_maker/python/core/utils/metrics_test.py b/mediapipe/model_maker/python/core/utils/metrics_test.py new file mode 100644 index 000000000..842335273 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/metrics_test.py @@ -0,0 +1,74 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from absl.testing import parameterized +import tensorflow as tf + +from mediapipe.model_maker.python.core.utils import metrics + + +class SparseMetricTest(tf.test.TestCase, parameterized.TestCase): + + def setUp(self): + super().setUp() + self.y_true = [0, 0, 1, 1, 0, 1] + self.y_pred = [ + [0.9, 0.1], # 0, 0 y + [0.8, 0.2], # 0, 0 y + [0.7, 0.3], # 0, 1 n + [0.6, 0.4], # 0, 1 n + [0.3, 0.7], # 1, 0 y + [0.3, 0.7], # 1, 1 y + ] + self.num_classes = 3 + + def _assert_metric_equals(self, metric, value): + metric.update_state(self.y_true, self.y_pred) + self.assertEqual(metric.result(), value) + + def test_sparse_recall(self): + metric = metrics.SparseRecall() + self._assert_metric_equals(metric, 1 / 3) + + def test_sparse_precision(self): + metric = metrics.SparsePrecision() + self._assert_metric_equals(metric, 1 / 2) + + def test_binary_sparse_recall_at_precision(self): + metric = metrics.BinarySparseRecallAtPrecision(1.0) + self._assert_metric_equals(metric, 0.0) # impossible to achieve precision=1 + metric = metrics.BinarySparseRecallAtPrecision(0.4) + self._assert_metric_equals(metric, 1.0) + + def test_binary_sparse_precision_at_recall(self): + metric = metrics.BinarySparsePrecisionAtRecall(1.0) + self._assert_metric_equals(metric, 3 / 4) + metric = metrics.BinarySparsePrecisionAtRecall(0.7) + self._assert_metric_equals(metric, 3 / 4) + + def test_binary_sparse_precision_at_recall_class_id_error(self): + # class_id=1 case should not error + _ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=1) + # class_id=2 case should error + with self.assertRaisesRegex( + ValueError, + 'Custom BinarySparseMetric for class:PrecisionAtRecall is only' + ' supported for class_id=1, got class_id=2 instead', + ): + _ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=2) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/text/text_classifier/BUILD b/mediapipe/model_maker/python/text/text_classifier/BUILD index 9fe96849b..26412d2cb 100644 --- a/mediapipe/model_maker/python/text/text_classifier/BUILD +++ b/mediapipe/model_maker/python/text/text_classifier/BUILD @@ -118,6 +118,7 @@ py_library( "//mediapipe/model_maker/python/core:hyperparameters", "//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/tasks:classifier", + "//mediapipe/model_maker/python/core/utils:metrics", "//mediapipe/model_maker/python/core/utils:model_util", "//mediapipe/model_maker/python/core/utils:quantization", "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer", diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py index a6762176b..59369931d 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py @@ -24,6 +24,7 @@ import tensorflow_hub as hub from mediapipe.model_maker.python.core import hyperparameters as hp from mediapipe.model_maker.python.core.data import dataset as ds from mediapipe.model_maker.python.core.tasks import classifier +from mediapipe.model_maker.python.core.utils import metrics from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds @@ -123,12 +124,24 @@ class TextClassifier(classifier.Classifier): return text_classifier - def evaluate(self, data: ds.Dataset, batch_size: int = 32) -> Any: + def evaluate( + self, + data: ds.Dataset, + batch_size: int = 32, + desired_precisions: Optional[Sequence[float]] = None, + desired_recalls: Optional[Sequence[float]] = None, + ) -> Any: """Overrides Classifier.evaluate(). Args: data: Evaluation dataset. Must be a TextClassifier Dataset. batch_size: Number of samples per evaluation step. + desired_precisions: If specified, adds a RecallAtPrecision metric per + desired_precisions[i] entry which tracks the recall given the constraint + on precision. Only supported for binary classification. + desired_recalls: If specified, adds a PrecisionAtRecall metric per + desired_recalls[i] entry which tracks the precision given the constraint + on recall. Only supported for binary classification. Returns: The loss value and accuracy. @@ -144,6 +157,28 @@ class TextClassifier(classifier.Classifier): processed_data = self._text_preprocessor.preprocess(data) dataset = processed_data.gen_tf_dataset(batch_size, is_training=False) + + additional_metrics = [] + if desired_precisions and len(data.label_names) == 2: + for precision in desired_precisions: + additional_metrics.append( + metrics.BinarySparseRecallAtPrecision( + precision, name=f"recall_at_precision_{precision}" + ) + ) + if desired_recalls and len(data.label_names) == 2: + for recall in desired_recalls: + additional_metrics.append( + metrics.BinarySparsePrecisionAtRecall( + recall, name=f"precision_at_recall_{recall}" + ) + ) + metric_functions = self._metric_functions + additional_metrics + self._model.compile( + optimizer=self._optimizer, + loss=self._loss_function, + metrics=metric_functions, + ) return self._model.evaluate(dataset) def export_model( @@ -196,7 +231,11 @@ class _AverageWordEmbeddingClassifier(TextClassifier): super().__init__(model_spec, hparams, label_names) self._model_options = model_options self._loss_function = "sparse_categorical_crossentropy" - self._metric_function = "accuracy" + self._metric_functions = [ + "accuracy", + metrics.SparsePrecision(name="precision", dtype=tf.float32), + metrics.SparseRecall(name="recall", dtype=tf.float32), + ] self._text_preprocessor: ( preprocessor.AverageWordEmbeddingClassifierPreprocessor) = None @@ -312,9 +351,13 @@ class _BertClassifier(TextClassifier): self._model_options = model_options with self._hparams.get_strategy().scope(): self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy() - self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy( - "test_accuracy", dtype=tf.float32 - ) + self._metric_functions = [ + tf.keras.metrics.SparseCategoricalAccuracy( + "test_accuracy", dtype=tf.float32 + ), + metrics.SparsePrecision(name="precision", dtype=tf.float32), + metrics.SparseRecall(name="recall", dtype=tf.float32), + ] self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None @classmethod diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py index 66934304a..8335968b7 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py @@ -54,7 +54,7 @@ class GestureRecognizer(classifier.Classifier): self._model_options = model_options self._hparams = hparams self._loss_function = loss_functions.FocalLoss(gamma=self._hparams.gamma) - self._metric_function = 'categorical_accuracy' + self._metric_functions = ['categorical_accuracy'] self._optimizer = 'adam' self._callbacks = self._get_callbacks() self._history = None diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py index 3838a5a1a..8acf59f66 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py @@ -59,7 +59,7 @@ class ImageClassifier(classifier.Classifier): self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir) self._loss_function = tf.keras.losses.CategoricalCrossentropy( label_smoothing=self._hparams.label_smoothing) - self._metric_function = 'accuracy' + self._metric_functions = ['accuracy'] self._history = None # Training history returned from `keras_model.fit`. @classmethod