diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py index 60c00f0de..a042c0ec7 100644 --- a/mediapipe/model_maker/python/core/tasks/classifier.py +++ b/mediapipe/model_maker/python/core/tasks/classifier.py @@ -43,7 +43,7 @@ class Classifier(custom_model.CustomModel): self._model: tf.keras.Model = None self._optimizer: Union[str, tf.keras.optimizers.Optimizer] = None self._loss_function: Union[str, tf.keras.losses.Loss] = None - self._metric_function: Union[str, tf.keras.metrics.Metric] = None + self._metric_functions: Sequence[Union[str, tf.keras.metrics.Metric]] = None self._callbacks: Sequence[tf.keras.callbacks.Callback] = None self._hparams: hp.BaseHParams = None self._history: tf.keras.callbacks.History = None @@ -92,7 +92,8 @@ class Classifier(custom_model.CustomModel): self._model.compile( optimizer=self._optimizer, loss=self._loss_function, - metrics=[self._metric_function]) + metrics=self._metric_functions, + ) latest_checkpoint = ( tf.train.latest_checkpoint(checkpoint_path) diff --git a/mediapipe/model_maker/python/core/utils/BUILD b/mediapipe/model_maker/python/core/utils/BUILD index ef9cab290..81bd68d3e 100644 --- a/mediapipe/model_maker/python/core/utils/BUILD +++ b/mediapipe/model_maker/python/core/utils/BUILD @@ -80,6 +80,17 @@ py_test( deps = [":loss_functions"], ) +py_library( + name = "metrics", + srcs = ["metrics.py"], +) + +py_test( + name = "metrics_test", + srcs = ["metrics_test.py"], + deps = [":metrics"], +) + py_library( name = "quantization", srcs = ["quantization.py"], diff --git a/mediapipe/model_maker/python/core/utils/metrics.py b/mediapipe/model_maker/python/core/utils/metrics.py new file mode 100644 index 000000000..310146168 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/metrics.py @@ -0,0 +1,104 @@ +# Copyright 2023 The MediaPipe Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Metrics utility library.""" + +import tensorflow as tf + + +def _get_binary_sparse_metric(metric: tf.metrics.Metric): + """Helper method to create a BinarySparse version of a tf.keras.Metric. + + BinarySparse is an implementation where the update_state(y_true, y_pred) takes + in shapes y_true=(batch_size, 1) y_pred=(batch_size, 2). Note that this only + supports the binary classification case, and that class_id=0 is the negative + class and class_id=1 is the positive class. + + Currently supported tf.metrics.Metric classes: + 1. tf.metrics.RecallAtPrecision + 2. tf.metrics.PrecisionAtRecall + + Args: + metric: A tf.metrics.Metric class for which we want to generate a + BinarySparse version of this metric. + + Returns: + A class for the BinarySparse version of the specified tf.metrics.Metric + """ + + class BinarySparseMetric(metric): + """A BinarySparse wrapper class for a tf.keras.Metric. + + This class has the same parameters and functions as the underlying + metric class. For example, the parameters for BinarySparseRecallAtPrecision + are the same as tf.keras.metrics.RecallAtPrecision. The only new constraint + is that class_id must be set to 1 (or not specified) for the Binary metric. 
+ """ + + def __init__(self, *args, **kwargs): + if 'class_id' in kwargs and kwargs['class_id'] != 1: + raise ValueError( + f'Custom BinarySparseMetric for class:{metric.__name__} is ' + 'only supported for class_id=1, got class_id=' + f'{kwargs["class_id"]} instead' + ) + else: + kwargs['class_id'] = 1 + super().__init__(*args, **kwargs) + + def update_state(self, y_true, y_pred, sample_weight=None): + y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32) + y_true_one_hot = tf.one_hot(y_true, 2) + return super().update_state( + y_true_one_hot, y_pred, sample_weight=sample_weight + ) + + return BinarySparseMetric + + +def _get_sparse_metric(metric: tf.metrics.Metric): + """Helper method to create a Sparse version of a tf.keras.Metric. + + Sparse is an implementation where the update_state(y_true, y_pred) takes in + shapes y_true=(batch_size, 1) and y_pred=(batch_size, num_classes). + + Currently supported tf.metrics.Metric classes: + 1. tf.metrics.Recall + 2. tf.metrics.Precision + + Args: + metric: A tf.metrics.Metric class for which we want to generate a Sparse + version of this metric. + + Returns: + A class for the Sparse version of the specified tf.keras.Metric. 
+ """ + + class SparseMetric(metric): + """A Sparse wrapper class for a tf.keras.Metric.""" + + def update_state(self, y_true, y_pred, sample_weight=None): + y_pred = tf.math.argmax(y_pred, axis=-1) + return super().update_state(y_true, y_pred, sample_weight=sample_weight) + + return SparseMetric + + +SparseRecall = _get_sparse_metric(tf.metrics.Recall) +SparsePrecision = _get_sparse_metric(tf.metrics.Precision) +BinarySparseRecallAtPrecision = _get_binary_sparse_metric( + tf.metrics.RecallAtPrecision +) +BinarySparsePrecisionAtRecall = _get_binary_sparse_metric( + tf.metrics.PrecisionAtRecall +) diff --git a/mediapipe/model_maker/python/core/utils/metrics_test.py b/mediapipe/model_maker/python/core/utils/metrics_test.py new file mode 100644 index 000000000..842335273 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/metrics_test.py @@ -0,0 +1,74 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from absl.testing import parameterized +import tensorflow as tf + +from mediapipe.model_maker.python.core.utils import metrics + + +class SparseMetricTest(tf.test.TestCase, parameterized.TestCase): + + def setUp(self): + super().setUp() + self.y_true = [0, 0, 1, 1, 0, 1] + self.y_pred = [ + [0.9, 0.1], # 0, 0 y + [0.8, 0.2], # 0, 0 y + [0.7, 0.3], # 0, 1 n + [0.6, 0.4], # 0, 1 n + [0.3, 0.7], # 1, 0 n + [0.3, 0.7], # 1, 1 y + ] + self.num_classes = 3 + + def _assert_metric_equals(self, metric, value): + metric.update_state(self.y_true, self.y_pred) + self.assertEqual(metric.result(), value) + + def test_sparse_recall(self): + metric = metrics.SparseRecall() + self._assert_metric_equals(metric, 1 / 3) + + def test_sparse_precision(self): + metric = metrics.SparsePrecision() + self._assert_metric_equals(metric, 1 / 2) + + def test_binary_sparse_recall_at_precision(self): + metric = metrics.BinarySparseRecallAtPrecision(1.0) + self._assert_metric_equals(metric, 0.0) # impossible to achieve precision=1 + metric = metrics.BinarySparseRecallAtPrecision(0.4) + self._assert_metric_equals(metric, 1.0) + + def test_binary_sparse_precision_at_recall(self): + metric = metrics.BinarySparsePrecisionAtRecall(1.0) + self._assert_metric_equals(metric, 3 / 4) + metric = metrics.BinarySparsePrecisionAtRecall(0.7) + self._assert_metric_equals(metric, 3 / 4) + + def test_binary_sparse_precision_at_recall_class_id_error(self): + # class_id=1 case should not error + _ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=1) + # class_id=2 case should error + with self.assertRaisesRegex( + ValueError, + 'Custom BinarySparseMetric for class:PrecisionAtRecall is only' + ' supported for class_id=1, got class_id=2 instead', + ): + _ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=2) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/text/text_classifier/BUILD b/mediapipe/model_maker/python/text/text_classifier/BUILD index 9fe96849b..26412d2cb 
100644 --- a/mediapipe/model_maker/python/text/text_classifier/BUILD +++ b/mediapipe/model_maker/python/text/text_classifier/BUILD @@ -118,6 +118,7 @@ py_library( "//mediapipe/model_maker/python/core:hyperparameters", "//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/tasks:classifier", + "//mediapipe/model_maker/python/core/utils:metrics", "//mediapipe/model_maker/python/core/utils:model_util", "//mediapipe/model_maker/python/core/utils:quantization", "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer", diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py index a6762176b..59369931d 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py @@ -24,6 +24,7 @@ import tensorflow_hub as hub from mediapipe.model_maker.python.core import hyperparameters as hp from mediapipe.model_maker.python.core.data import dataset as ds from mediapipe.model_maker.python.core.tasks import classifier +from mediapipe.model_maker.python.core.utils import metrics from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds @@ -123,12 +124,24 @@ class TextClassifier(classifier.Classifier): return text_classifier - def evaluate(self, data: ds.Dataset, batch_size: int = 32) -> Any: + def evaluate( + self, + data: ds.Dataset, + batch_size: int = 32, + desired_precisions: Optional[Sequence[float]] = None, + desired_recalls: Optional[Sequence[float]] = None, + ) -> Any: """Overrides Classifier.evaluate(). Args: data: Evaluation dataset. Must be a TextClassifier Dataset. batch_size: Number of samples per evaluation step. 
+ desired_precisions: If specified, adds a RecallAtPrecision metric per + desired_precisions[i] entry which tracks the recall given the constraint + on precision. Only supported for binary classification. + desired_recalls: If specified, adds a PrecisionAtRecall metric per + desired_recalls[i] entry which tracks the precision given the constraint + on recall. Only supported for binary classification. Returns: The loss value and accuracy. @@ -144,6 +157,28 @@ class TextClassifier(classifier.Classifier): processed_data = self._text_preprocessor.preprocess(data) dataset = processed_data.gen_tf_dataset(batch_size, is_training=False) + + additional_metrics = [] + if desired_precisions and len(data.label_names) == 2: + for precision in desired_precisions: + additional_metrics.append( + metrics.BinarySparseRecallAtPrecision( + precision, name=f"recall_at_precision_{precision}" + ) + ) + if desired_recalls and len(data.label_names) == 2: + for recall in desired_recalls: + additional_metrics.append( + metrics.BinarySparsePrecisionAtRecall( + recall, name=f"precision_at_recall_{recall}" + ) + ) + metric_functions = self._metric_functions + additional_metrics + self._model.compile( + optimizer=self._optimizer, + loss=self._loss_function, + metrics=metric_functions, + ) return self._model.evaluate(dataset) def export_model( @@ -196,7 +231,11 @@ class _AverageWordEmbeddingClassifier(TextClassifier): super().__init__(model_spec, hparams, label_names) self._model_options = model_options self._loss_function = "sparse_categorical_crossentropy" - self._metric_function = "accuracy" + self._metric_functions = [ + "accuracy", + metrics.SparsePrecision(name="precision", dtype=tf.float32), + metrics.SparseRecall(name="recall", dtype=tf.float32), + ] self._text_preprocessor: ( preprocessor.AverageWordEmbeddingClassifierPreprocessor) = None @@ -312,9 +351,13 @@ class _BertClassifier(TextClassifier): self._model_options = model_options with self._hparams.get_strategy().scope(): 
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy() - self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy( - "test_accuracy", dtype=tf.float32 - ) + self._metric_functions = [ + tf.keras.metrics.SparseCategoricalAccuracy( + "test_accuracy", dtype=tf.float32 + ), + metrics.SparsePrecision(name="precision", dtype=tf.float32), + metrics.SparseRecall(name="recall", dtype=tf.float32), + ] self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None @classmethod diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py index 66934304a..8335968b7 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py @@ -54,7 +54,7 @@ class GestureRecognizer(classifier.Classifier): self._model_options = model_options self._hparams = hparams self._loss_function = loss_functions.FocalLoss(gamma=self._hparams.gamma) - self._metric_function = 'categorical_accuracy' + self._metric_functions = ['categorical_accuracy'] self._optimizer = 'adam' self._callbacks = self._get_callbacks() self._history = None diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py index 3838a5a1a..8acf59f66 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py @@ -59,7 +59,7 @@ class ImageClassifier(classifier.Classifier): self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir) self._loss_function = tf.keras.losses.CategoricalCrossentropy( label_smoothing=self._hparams.label_smoothing) - self._metric_function = 'accuracy' + self._metric_functions = ['accuracy'] self._history = None # Training history returned from 
`keras_model.fit`. @classmethod