1. Model maker core classifier change _metric_function field to _metric_functions in order to support having multiple metrics.

2. Add SparsePrecision, SparseRecall, BinarySparsePrecisionAtRecall, and BinarySparseRecallAtPrecision to the shared metrics library.
3. Add SparsePrecision, SparseRecall to text classifier, and have the option to evaluate the model with BinarySparsePrecisionAtRecall and BinarySparseRecallAtPrecision

PiperOrigin-RevId: 542376451
This commit is contained in:
MediaPipe Team 2023-06-21 15:15:30 -07:00 committed by Copybara-Service
parent 7edb6b8fcb
commit 895c685df6
8 changed files with 243 additions and 9 deletions

View File

@ -43,7 +43,7 @@ class Classifier(custom_model.CustomModel):
self._model: tf.keras.Model = None self._model: tf.keras.Model = None
self._optimizer: Union[str, tf.keras.optimizers.Optimizer] = None self._optimizer: Union[str, tf.keras.optimizers.Optimizer] = None
self._loss_function: Union[str, tf.keras.losses.Loss] = None self._loss_function: Union[str, tf.keras.losses.Loss] = None
self._metric_function: Union[str, tf.keras.metrics.Metric] = None self._metric_functions: Sequence[Union[str, tf.keras.metrics.Metric]] = None
self._callbacks: Sequence[tf.keras.callbacks.Callback] = None self._callbacks: Sequence[tf.keras.callbacks.Callback] = None
self._hparams: hp.BaseHParams = None self._hparams: hp.BaseHParams = None
self._history: tf.keras.callbacks.History = None self._history: tf.keras.callbacks.History = None
@ -92,7 +92,8 @@ class Classifier(custom_model.CustomModel):
self._model.compile( self._model.compile(
optimizer=self._optimizer, optimizer=self._optimizer,
loss=self._loss_function, loss=self._loss_function,
metrics=[self._metric_function]) metrics=self._metric_functions,
)
latest_checkpoint = ( latest_checkpoint = (
tf.train.latest_checkpoint(checkpoint_path) tf.train.latest_checkpoint(checkpoint_path)

View File

@ -80,6 +80,17 @@ py_test(
deps = [":loss_functions"], deps = [":loss_functions"],
) )
py_library(
name = "metrics",
srcs = ["metrics.py"],
)
py_test(
name = "metrics_test",
srcs = ["metrics_test.py"],
deps = [":metrics"],
)
py_library( py_library(
name = "quantization", name = "quantization",
srcs = ["quantization.py"], srcs = ["quantization.py"],

View File

@ -0,0 +1,104 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Metrics utility library."""
import tensorflow as tf
def _get_binary_sparse_metric(metric: tf.metrics.Metric):
"""Helper method to create a BinarySparse version of a tf.keras.Metric.
BinarySparse is an implementation where the update_state(y_true, y_pred) takes
in shapes y_true=(batch_size, 1) y_pred=(batch_size, 2). Note that this only
supports the binary classification case, and that class_id=0 is the negative
class and class_id=1 is the positive class.
Currently supported tf.metric.Metric classes
1. BinarySparseRecallAtPrecision
2. BinarySparsePrecisionAtRecall
Args:
metric: A tf.metric.Metric class for which we want to generate a
BinarySparse version of this metric.
Returns:
A class for the BinarySparse version of the specified tf.metrics.Metric
"""
class BinarySparseMetric(metric):
"""A BinarySparse wrapper class for a tf.keras.Metric.
This class has the same parameters and functions as the underlying
metric class. For example, the parameters for BinarySparseRecallAtPrecision
is the same as tf.keras.metrics.RecallAtPrecision. The only new constraint
is that class_id must be set to 1 (or not specified) for the Binary metric.
"""
def __init__(self, *args, **kwargs):
if 'class_id' in kwargs and kwargs['class_id'] != 1:
raise ValueError(
f'Custom BinarySparseMetric for class:{metric.__name__} is '
'only supported for class_id=1, got class_id='
f'{kwargs["class_id"]} instead'
)
else:
kwargs['class_id'] = 1
super().__init__(*args, **kwargs)
def update_state(self, y_true, y_pred, sample_weight=None):
y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
y_true_one_hot = tf.one_hot(y_true, 2)
return super().update_state(
y_true_one_hot, y_pred, sample_weight=sample_weight
)
return BinarySparseMetric
def _get_sparse_metric(metric: tf.metrics.Metric):
"""Helper method to create a Sparse version of a tf.keras.Metric.
Sparse is an implementation where the update_state(y_true, y_pred) takes in
shapes y_true=(batch_size, 1) and y_pred=(batch_size, num_classes).
Currently supported tf.metrics.Metric classes:
1. tf.metrics.Recall
2. tf.metrics.Precision
Args:
metric: A tf.metric.Metric class for which we want to generate a Sparse
version of this metric.
Returns:
A class for the Sparse version of the specified tf.keras.Metric.
"""
class SparseMetric(metric):
"""A Sparse wrapper class for a tf.keras.Metric."""
def update_state(self, y_true, y_pred, sample_weight=None):
y_pred = tf.math.argmax(y_pred, axis=-1)
return super().update_state(y_true, y_pred, sample_weight=sample_weight)
return SparseMetric
SparseRecall = _get_sparse_metric(tf.metrics.Recall)
SparsePrecision = _get_sparse_metric(tf.metrics.Precision)
BinarySparseRecallAtPrecision = _get_binary_sparse_metric(
tf.metrics.RecallAtPrecision
)
BinarySparsePrecisionAtRecall = _get_binary_sparse_metric(
tf.metrics.PrecisionAtRecall
)

View File

@ -0,0 +1,74 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import metrics
class SparseMetricTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self.y_true = [0, 0, 1, 1, 0, 1]
self.y_pred = [
[0.9, 0.1], # 0, 0 y
[0.8, 0.2], # 0, 0 y
[0.7, 0.3], # 0, 1 n
[0.6, 0.4], # 0, 1 n
[0.3, 0.7], # 1, 0 y
[0.3, 0.7], # 1, 1 y
]
self.num_classes = 3
def _assert_metric_equals(self, metric, value):
metric.update_state(self.y_true, self.y_pred)
self.assertEqual(metric.result(), value)
def test_sparse_recall(self):
metric = metrics.SparseRecall()
self._assert_metric_equals(metric, 1 / 3)
def test_sparse_precision(self):
metric = metrics.SparsePrecision()
self._assert_metric_equals(metric, 1 / 2)
def test_binary_sparse_recall_at_precision(self):
metric = metrics.BinarySparseRecallAtPrecision(1.0)
self._assert_metric_equals(metric, 0.0) # impossible to achieve precision=1
metric = metrics.BinarySparseRecallAtPrecision(0.4)
self._assert_metric_equals(metric, 1.0)
def test_binary_sparse_precision_at_recall(self):
metric = metrics.BinarySparsePrecisionAtRecall(1.0)
self._assert_metric_equals(metric, 3 / 4)
metric = metrics.BinarySparsePrecisionAtRecall(0.7)
self._assert_metric_equals(metric, 3 / 4)
def test_binary_sparse_precision_at_recall_class_id_error(self):
# class_id=1 case should not error
_ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=1)
# class_id=2 case should error
with self.assertRaisesRegex(
ValueError,
'Custom BinarySparseMetric for class:PrecisionAtRecall is only'
' supported for class_id=1, got class_id=2 instead',
):
_ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=2)
if __name__ == '__main__':
tf.test.main()

View File

@ -118,6 +118,7 @@ py_library(
"//mediapipe/model_maker/python/core:hyperparameters", "//mediapipe/model_maker/python/core:hyperparameters",
"//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/data:dataset",
"//mediapipe/model_maker/python/core/tasks:classifier", "//mediapipe/model_maker/python/core/tasks:classifier",
"//mediapipe/model_maker/python/core/utils:metrics",
"//mediapipe/model_maker/python/core/utils:model_util", "//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/model_maker/python/core/utils:quantization", "//mediapipe/model_maker/python/core/utils:quantization",
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer", "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",

View File

@ -24,6 +24,7 @@ import tensorflow_hub as hub
from mediapipe.model_maker.python.core import hyperparameters as hp from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.core.data import dataset as ds from mediapipe.model_maker.python.core.data import dataset as ds
from mediapipe.model_maker.python.core.tasks import classifier from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import metrics
from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
@ -123,12 +124,24 @@ class TextClassifier(classifier.Classifier):
return text_classifier return text_classifier
def evaluate(self, data: ds.Dataset, batch_size: int = 32) -> Any: def evaluate(
self,
data: ds.Dataset,
batch_size: int = 32,
desired_precisions: Optional[Sequence[float]] = None,
desired_recalls: Optional[Sequence[float]] = None,
) -> Any:
"""Overrides Classifier.evaluate(). """Overrides Classifier.evaluate().
Args: Args:
data: Evaluation dataset. Must be a TextClassifier Dataset. data: Evaluation dataset. Must be a TextClassifier Dataset.
batch_size: Number of samples per evaluation step. batch_size: Number of samples per evaluation step.
desired_precisions: If specified, adds a RecallAtPrecision metric per
desired_precisions[i] entry which tracks the recall given the constraint
on precision. Only supported for binary classification.
desired_recalls: If specified, adds a PrecisionAtRecall metric per
desired_recalls[i] entry which tracks the precision given the constraint
on recall. Only supported for binary classification.
Returns: Returns:
The loss value and accuracy. The loss value and accuracy.
@ -144,6 +157,28 @@ class TextClassifier(classifier.Classifier):
processed_data = self._text_preprocessor.preprocess(data) processed_data = self._text_preprocessor.preprocess(data)
dataset = processed_data.gen_tf_dataset(batch_size, is_training=False) dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)
additional_metrics = []
if desired_precisions and len(data.label_names) == 2:
for precision in desired_precisions:
additional_metrics.append(
metrics.BinarySparseRecallAtPrecision(
precision, name=f"recall_at_precision_{precision}"
)
)
if desired_recalls and len(data.label_names) == 2:
for recall in desired_recalls:
additional_metrics.append(
metrics.BinarySparsePrecisionAtRecall(
recall, name=f"precision_at_recall_{recall}"
)
)
metric_functions = self._metric_functions + additional_metrics
self._model.compile(
optimizer=self._optimizer,
loss=self._loss_function,
metrics=metric_functions,
)
return self._model.evaluate(dataset) return self._model.evaluate(dataset)
def export_model( def export_model(
@ -196,7 +231,11 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
super().__init__(model_spec, hparams, label_names) super().__init__(model_spec, hparams, label_names)
self._model_options = model_options self._model_options = model_options
self._loss_function = "sparse_categorical_crossentropy" self._loss_function = "sparse_categorical_crossentropy"
self._metric_function = "accuracy" self._metric_functions = [
"accuracy",
metrics.SparsePrecision(name="precision", dtype=tf.float32),
metrics.SparseRecall(name="recall", dtype=tf.float32),
]
self._text_preprocessor: ( self._text_preprocessor: (
preprocessor.AverageWordEmbeddingClassifierPreprocessor) = None preprocessor.AverageWordEmbeddingClassifierPreprocessor) = None
@ -312,9 +351,13 @@ class _BertClassifier(TextClassifier):
self._model_options = model_options self._model_options = model_options
with self._hparams.get_strategy().scope(): with self._hparams.get_strategy().scope():
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy() self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy( self._metric_functions = [
"test_accuracy", dtype=tf.float32 tf.keras.metrics.SparseCategoricalAccuracy(
) "test_accuracy", dtype=tf.float32
),
metrics.SparsePrecision(name="precision", dtype=tf.float32),
metrics.SparseRecall(name="recall", dtype=tf.float32),
]
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
@classmethod @classmethod

View File

@ -54,7 +54,7 @@ class GestureRecognizer(classifier.Classifier):
self._model_options = model_options self._model_options = model_options
self._hparams = hparams self._hparams = hparams
self._loss_function = loss_functions.FocalLoss(gamma=self._hparams.gamma) self._loss_function = loss_functions.FocalLoss(gamma=self._hparams.gamma)
self._metric_function = 'categorical_accuracy' self._metric_functions = ['categorical_accuracy']
self._optimizer = 'adam' self._optimizer = 'adam'
self._callbacks = self._get_callbacks() self._callbacks = self._get_callbacks()
self._history = None self._history = None

View File

@ -59,7 +59,7 @@ class ImageClassifier(classifier.Classifier):
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir) self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
self._loss_function = tf.keras.losses.CategoricalCrossentropy( self._loss_function = tf.keras.losses.CategoricalCrossentropy(
label_smoothing=self._hparams.label_smoothing) label_smoothing=self._hparams.label_smoothing)
self._metric_function = 'accuracy' self._metric_functions = ['accuracy']
self._history = None # Training history returned from `keras_model.fit`. self._history = None # Training history returned from `keras_model.fit`.
@classmethod @classmethod