1. Model maker core classifier change _metric_function field to _metric_functions in order to support having multiple metrics.
2. Add SparsePrecision, SparseRecall, BinarySparsePrecisionAtRecall, and BinarySparseRecallAtPrecision to the shared metrics library. 3. Add SparsePrecision, SparseRecall to text classifier, and have the option to evaluate the model with BinarySparsePrecisionAtRecall and BinarySparseRecallAtPrecision PiperOrigin-RevId: 542376451
This commit is contained in:
parent
7edb6b8fcb
commit
895c685df6
|
@ -43,7 +43,7 @@ class Classifier(custom_model.CustomModel):
|
||||||
self._model: tf.keras.Model = None
|
self._model: tf.keras.Model = None
|
||||||
self._optimizer: Union[str, tf.keras.optimizers.Optimizer] = None
|
self._optimizer: Union[str, tf.keras.optimizers.Optimizer] = None
|
||||||
self._loss_function: Union[str, tf.keras.losses.Loss] = None
|
self._loss_function: Union[str, tf.keras.losses.Loss] = None
|
||||||
self._metric_function: Union[str, tf.keras.metrics.Metric] = None
|
self._metric_functions: Sequence[Union[str, tf.keras.metrics.Metric]] = None
|
||||||
self._callbacks: Sequence[tf.keras.callbacks.Callback] = None
|
self._callbacks: Sequence[tf.keras.callbacks.Callback] = None
|
||||||
self._hparams: hp.BaseHParams = None
|
self._hparams: hp.BaseHParams = None
|
||||||
self._history: tf.keras.callbacks.History = None
|
self._history: tf.keras.callbacks.History = None
|
||||||
|
@ -92,7 +92,8 @@ class Classifier(custom_model.CustomModel):
|
||||||
self._model.compile(
|
self._model.compile(
|
||||||
optimizer=self._optimizer,
|
optimizer=self._optimizer,
|
||||||
loss=self._loss_function,
|
loss=self._loss_function,
|
||||||
metrics=[self._metric_function])
|
metrics=self._metric_functions,
|
||||||
|
)
|
||||||
|
|
||||||
latest_checkpoint = (
|
latest_checkpoint = (
|
||||||
tf.train.latest_checkpoint(checkpoint_path)
|
tf.train.latest_checkpoint(checkpoint_path)
|
||||||
|
|
|
@ -80,6 +80,17 @@ py_test(
|
||||||
deps = [":loss_functions"],
|
deps = [":loss_functions"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "metrics",
|
||||||
|
srcs = ["metrics.py"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "metrics_test",
|
||||||
|
srcs = ["metrics_test.py"],
|
||||||
|
deps = [":metrics"],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "quantization",
|
name = "quantization",
|
||||||
srcs = ["quantization.py"],
|
srcs = ["quantization.py"],
|
||||||
|
|
104
mediapipe/model_maker/python/core/utils/metrics.py
Normal file
104
mediapipe/model_maker/python/core/utils/metrics.py
Normal file
|
@ -0,0 +1,104 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Metrics utility library."""
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
def _get_binary_sparse_metric(metric: tf.metrics.Metric):
|
||||||
|
"""Helper method to create a BinarySparse version of a tf.keras.Metric.
|
||||||
|
|
||||||
|
BinarySparse is an implementation where the update_state(y_true, y_pred) takes
|
||||||
|
in shapes y_true=(batch_size, 1) y_pred=(batch_size, 2). Note that this only
|
||||||
|
supports the binary classification case, and that class_id=0 is the negative
|
||||||
|
class and class_id=1 is the positive class.
|
||||||
|
|
||||||
|
Currently supported tf.metric.Metric classes
|
||||||
|
1. BinarySparseRecallAtPrecision
|
||||||
|
2. BinarySparsePrecisionAtRecall
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metric: A tf.metric.Metric class for which we want to generate a
|
||||||
|
BinarySparse version of this metric.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A class for the BinarySparse version of the specified tf.metrics.Metric
|
||||||
|
"""
|
||||||
|
|
||||||
|
class BinarySparseMetric(metric):
|
||||||
|
"""A BinarySparse wrapper class for a tf.keras.Metric.
|
||||||
|
|
||||||
|
This class has the same parameters and functions as the underlying
|
||||||
|
metric class. For example, the parameters for BinarySparseRecallAtPrecision
|
||||||
|
is the same as tf.keras.metrics.RecallAtPrecision. The only new constraint
|
||||||
|
is that class_id must be set to 1 (or not specified) for the Binary metric.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
if 'class_id' in kwargs and kwargs['class_id'] != 1:
|
||||||
|
raise ValueError(
|
||||||
|
f'Custom BinarySparseMetric for class:{metric.__name__} is '
|
||||||
|
'only supported for class_id=1, got class_id='
|
||||||
|
f'{kwargs["class_id"]} instead'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
kwargs['class_id'] = 1
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def update_state(self, y_true, y_pred, sample_weight=None):
|
||||||
|
y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
|
||||||
|
y_true_one_hot = tf.one_hot(y_true, 2)
|
||||||
|
return super().update_state(
|
||||||
|
y_true_one_hot, y_pred, sample_weight=sample_weight
|
||||||
|
)
|
||||||
|
|
||||||
|
return BinarySparseMetric
|
||||||
|
|
||||||
|
|
||||||
|
def _get_sparse_metric(metric: tf.metrics.Metric):
|
||||||
|
"""Helper method to create a Sparse version of a tf.keras.Metric.
|
||||||
|
|
||||||
|
Sparse is an implementation where the update_state(y_true, y_pred) takes in
|
||||||
|
shapes y_true=(batch_size, 1) and y_pred=(batch_size, num_classes).
|
||||||
|
|
||||||
|
Currently supported tf.metrics.Metric classes:
|
||||||
|
1. tf.metrics.Recall
|
||||||
|
2. tf.metrics.Precision
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metric: A tf.metric.Metric class for which we want to generate a Sparse
|
||||||
|
version of this metric.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A class for the Sparse version of the specified tf.keras.Metric.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class SparseMetric(metric):
|
||||||
|
"""A Sparse wrapper class for a tf.keras.Metric."""
|
||||||
|
|
||||||
|
def update_state(self, y_true, y_pred, sample_weight=None):
|
||||||
|
y_pred = tf.math.argmax(y_pred, axis=-1)
|
||||||
|
return super().update_state(y_true, y_pred, sample_weight=sample_weight)
|
||||||
|
|
||||||
|
return SparseMetric
|
||||||
|
|
||||||
|
|
||||||
|
SparseRecall = _get_sparse_metric(tf.metrics.Recall)
|
||||||
|
SparsePrecision = _get_sparse_metric(tf.metrics.Precision)
|
||||||
|
BinarySparseRecallAtPrecision = _get_binary_sparse_metric(
|
||||||
|
tf.metrics.RecallAtPrecision
|
||||||
|
)
|
||||||
|
BinarySparsePrecisionAtRecall = _get_binary_sparse_metric(
|
||||||
|
tf.metrics.PrecisionAtRecall
|
||||||
|
)
|
74
mediapipe/model_maker/python/core/utils/metrics_test.py
Normal file
74
mediapipe/model_maker/python/core/utils/metrics_test.py
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.utils import metrics
|
||||||
|
|
||||||
|
|
||||||
|
class SparseMetricTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.y_true = [0, 0, 1, 1, 0, 1]
|
||||||
|
self.y_pred = [
|
||||||
|
[0.9, 0.1], # 0, 0 y
|
||||||
|
[0.8, 0.2], # 0, 0 y
|
||||||
|
[0.7, 0.3], # 0, 1 n
|
||||||
|
[0.6, 0.4], # 0, 1 n
|
||||||
|
[0.3, 0.7], # 1, 0 y
|
||||||
|
[0.3, 0.7], # 1, 1 y
|
||||||
|
]
|
||||||
|
self.num_classes = 3
|
||||||
|
|
||||||
|
def _assert_metric_equals(self, metric, value):
|
||||||
|
metric.update_state(self.y_true, self.y_pred)
|
||||||
|
self.assertEqual(metric.result(), value)
|
||||||
|
|
||||||
|
def test_sparse_recall(self):
|
||||||
|
metric = metrics.SparseRecall()
|
||||||
|
self._assert_metric_equals(metric, 1 / 3)
|
||||||
|
|
||||||
|
def test_sparse_precision(self):
|
||||||
|
metric = metrics.SparsePrecision()
|
||||||
|
self._assert_metric_equals(metric, 1 / 2)
|
||||||
|
|
||||||
|
def test_binary_sparse_recall_at_precision(self):
|
||||||
|
metric = metrics.BinarySparseRecallAtPrecision(1.0)
|
||||||
|
self._assert_metric_equals(metric, 0.0) # impossible to achieve precision=1
|
||||||
|
metric = metrics.BinarySparseRecallAtPrecision(0.4)
|
||||||
|
self._assert_metric_equals(metric, 1.0)
|
||||||
|
|
||||||
|
def test_binary_sparse_precision_at_recall(self):
|
||||||
|
metric = metrics.BinarySparsePrecisionAtRecall(1.0)
|
||||||
|
self._assert_metric_equals(metric, 3 / 4)
|
||||||
|
metric = metrics.BinarySparsePrecisionAtRecall(0.7)
|
||||||
|
self._assert_metric_equals(metric, 3 / 4)
|
||||||
|
|
||||||
|
def test_binary_sparse_precision_at_recall_class_id_error(self):
|
||||||
|
# class_id=1 case should not error
|
||||||
|
_ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=1)
|
||||||
|
# class_id=2 case should error
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError,
|
||||||
|
'Custom BinarySparseMetric for class:PrecisionAtRecall is only'
|
||||||
|
' supported for class_id=1, got class_id=2 instead',
|
||||||
|
):
|
||||||
|
_ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
|
@ -118,6 +118,7 @@ py_library(
|
||||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
"//mediapipe/model_maker/python/core/data:dataset",
|
"//mediapipe/model_maker/python/core/data:dataset",
|
||||||
"//mediapipe/model_maker/python/core/tasks:classifier",
|
"//mediapipe/model_maker/python/core/tasks:classifier",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:metrics",
|
||||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||||
"//mediapipe/model_maker/python/core/utils:quantization",
|
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||||
|
|
|
@ -24,6 +24,7 @@ import tensorflow_hub as hub
|
||||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
from mediapipe.model_maker.python.core.data import dataset as ds
|
from mediapipe.model_maker.python.core.data import dataset as ds
|
||||||
from mediapipe.model_maker.python.core.tasks import classifier
|
from mediapipe.model_maker.python.core.tasks import classifier
|
||||||
|
from mediapipe.model_maker.python.core.utils import metrics
|
||||||
from mediapipe.model_maker.python.core.utils import model_util
|
from mediapipe.model_maker.python.core.utils import model_util
|
||||||
from mediapipe.model_maker.python.core.utils import quantization
|
from mediapipe.model_maker.python.core.utils import quantization
|
||||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
|
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
|
||||||
|
@ -123,12 +124,24 @@ class TextClassifier(classifier.Classifier):
|
||||||
|
|
||||||
return text_classifier
|
return text_classifier
|
||||||
|
|
||||||
def evaluate(self, data: ds.Dataset, batch_size: int = 32) -> Any:
|
def evaluate(
|
||||||
|
self,
|
||||||
|
data: ds.Dataset,
|
||||||
|
batch_size: int = 32,
|
||||||
|
desired_precisions: Optional[Sequence[float]] = None,
|
||||||
|
desired_recalls: Optional[Sequence[float]] = None,
|
||||||
|
) -> Any:
|
||||||
"""Overrides Classifier.evaluate().
|
"""Overrides Classifier.evaluate().
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: Evaluation dataset. Must be a TextClassifier Dataset.
|
data: Evaluation dataset. Must be a TextClassifier Dataset.
|
||||||
batch_size: Number of samples per evaluation step.
|
batch_size: Number of samples per evaluation step.
|
||||||
|
desired_precisions: If specified, adds a RecallAtPrecision metric per
|
||||||
|
desired_precisions[i] entry which tracks the recall given the constraint
|
||||||
|
on precision. Only supported for binary classification.
|
||||||
|
desired_recalls: If specified, adds a PrecisionAtRecall metric per
|
||||||
|
desired_recalls[i] entry which tracks the precision given the constraint
|
||||||
|
on recall. Only supported for binary classification.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The loss value and accuracy.
|
The loss value and accuracy.
|
||||||
|
@ -144,6 +157,28 @@ class TextClassifier(classifier.Classifier):
|
||||||
|
|
||||||
processed_data = self._text_preprocessor.preprocess(data)
|
processed_data = self._text_preprocessor.preprocess(data)
|
||||||
dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)
|
dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)
|
||||||
|
|
||||||
|
additional_metrics = []
|
||||||
|
if desired_precisions and len(data.label_names) == 2:
|
||||||
|
for precision in desired_precisions:
|
||||||
|
additional_metrics.append(
|
||||||
|
metrics.BinarySparseRecallAtPrecision(
|
||||||
|
precision, name=f"recall_at_precision_{precision}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if desired_recalls and len(data.label_names) == 2:
|
||||||
|
for recall in desired_recalls:
|
||||||
|
additional_metrics.append(
|
||||||
|
metrics.BinarySparsePrecisionAtRecall(
|
||||||
|
recall, name=f"precision_at_recall_{recall}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
metric_functions = self._metric_functions + additional_metrics
|
||||||
|
self._model.compile(
|
||||||
|
optimizer=self._optimizer,
|
||||||
|
loss=self._loss_function,
|
||||||
|
metrics=metric_functions,
|
||||||
|
)
|
||||||
return self._model.evaluate(dataset)
|
return self._model.evaluate(dataset)
|
||||||
|
|
||||||
def export_model(
|
def export_model(
|
||||||
|
@ -196,7 +231,11 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
|
||||||
super().__init__(model_spec, hparams, label_names)
|
super().__init__(model_spec, hparams, label_names)
|
||||||
self._model_options = model_options
|
self._model_options = model_options
|
||||||
self._loss_function = "sparse_categorical_crossentropy"
|
self._loss_function = "sparse_categorical_crossentropy"
|
||||||
self._metric_function = "accuracy"
|
self._metric_functions = [
|
||||||
|
"accuracy",
|
||||||
|
metrics.SparsePrecision(name="precision", dtype=tf.float32),
|
||||||
|
metrics.SparseRecall(name="recall", dtype=tf.float32),
|
||||||
|
]
|
||||||
self._text_preprocessor: (
|
self._text_preprocessor: (
|
||||||
preprocessor.AverageWordEmbeddingClassifierPreprocessor) = None
|
preprocessor.AverageWordEmbeddingClassifierPreprocessor) = None
|
||||||
|
|
||||||
|
@ -312,9 +351,13 @@ class _BertClassifier(TextClassifier):
|
||||||
self._model_options = model_options
|
self._model_options = model_options
|
||||||
with self._hparams.get_strategy().scope():
|
with self._hparams.get_strategy().scope():
|
||||||
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
|
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||||
self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy(
|
self._metric_functions = [
|
||||||
|
tf.keras.metrics.SparseCategoricalAccuracy(
|
||||||
"test_accuracy", dtype=tf.float32
|
"test_accuracy", dtype=tf.float32
|
||||||
)
|
),
|
||||||
|
metrics.SparsePrecision(name="precision", dtype=tf.float32),
|
||||||
|
metrics.SparseRecall(name="recall", dtype=tf.float32),
|
||||||
|
]
|
||||||
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
|
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -54,7 +54,7 @@ class GestureRecognizer(classifier.Classifier):
|
||||||
self._model_options = model_options
|
self._model_options = model_options
|
||||||
self._hparams = hparams
|
self._hparams = hparams
|
||||||
self._loss_function = loss_functions.FocalLoss(gamma=self._hparams.gamma)
|
self._loss_function = loss_functions.FocalLoss(gamma=self._hparams.gamma)
|
||||||
self._metric_function = 'categorical_accuracy'
|
self._metric_functions = ['categorical_accuracy']
|
||||||
self._optimizer = 'adam'
|
self._optimizer = 'adam'
|
||||||
self._callbacks = self._get_callbacks()
|
self._callbacks = self._get_callbacks()
|
||||||
self._history = None
|
self._history = None
|
||||||
|
|
|
@ -59,7 +59,7 @@ class ImageClassifier(classifier.Classifier):
|
||||||
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
|
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
|
||||||
self._loss_function = tf.keras.losses.CategoricalCrossentropy(
|
self._loss_function = tf.keras.losses.CategoricalCrossentropy(
|
||||||
label_smoothing=self._hparams.label_smoothing)
|
label_smoothing=self._hparams.label_smoothing)
|
||||||
self._metric_function = 'accuracy'
|
self._metric_functions = ['accuracy']
|
||||||
self._history = None # Training history returned from `keras_model.fit`.
|
self._history = None # Training history returned from `keras_model.fit`.
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
Loading…
Reference in New Issue
Block a user