1. Move evaluation onto GPU/TPU hardware if available.
2. Move desired_precision and desired_recall from evaluate to the hyperparameters so recall@precision metrics are reported for both training and evaluation. This also fixes a bug where recompiling the model with previously initialized metric objects would not properly reset the metric states.
3. Remove the redundant label_names argument from the create_... class methods in text_classifier; this information is already provided by the datasets.
4. Change the loss function to FocalLoss.
5. Re-enable the text_classifier unit tests using ExBert.
6. Add input names to avoid flaky auto-assigned input names.

PiperOrigin-RevId: 550992146
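For reference, focal loss scales each example's cross entropy by (1 - p_t)^gamma, down-weighting easy examples; gamma = 0 recovers plain cross entropy (see the new `gamma` hyperparameter below). A quick hand check of the doctest values in the new SparseFocalLoss docstring, assuming the sum-over-batch reduction used by FocalLoss:

    import math

    # Per-example focal loss: -(1 - p_t)**gamma * log(p_t), where p_t is the
    # predicted probability of the true class; averaged over the batch.
    p_t = [0.95, 0.1]  # from the doctest: y_true = [1, 2] against y_pred
    gamma = 2
    losses = [-((1 - p) ** gamma) * math.log(p) for p in p_t]
    print(round(sum(losses) / len(losses), 4))  # 0.9326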
This commit is contained in:
parent 85c3fed70a
commit bd7888cc0c
@@ -59,7 +59,7 @@ class FocalLoss(tf.keras.losses.Loss):
   """
 
   def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
-    """Constructor.
+    """Initializes FocalLoss.
 
     Args:
       gamma: Focal loss gamma, as described in class docs.
@@ -115,6 +115,51 @@ class FocalLoss(tf.keras.losses.Loss):
     return tf.reduce_sum(losses) / batch_size
 
 
+class SparseFocalLoss(FocalLoss):
+  """Sparse implementation of Focal Loss.
+
+  This is the same as FocalLoss, except the labels are expected to be class ids
+  instead of 1-hot encoded vectors. See FocalLoss class documentation defined
+  in this same file for more details.
+
+  Example usage:
+  >>> y_true = [1, 2]
+  >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
+  >>> gamma = 2
+  >>> focal_loss = SparseFocalLoss(gamma, 3)
+  >>> focal_loss(y_true, y_pred).numpy()
+  0.9326
+
+  >>> # Calling with 'sample_weight'.
+  >>> focal_loss(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()
+  0.6528
+  """
+
+  def __init__(
+      self, gamma, num_classes, class_weight: Optional[Sequence[float]] = None
+  ):
+    """Initializes SparseFocalLoss.
+
+    Args:
+      gamma: Focal loss gamma, as described in class docs.
+      num_classes: Number of classes.
+      class_weight: A weight to apply to the loss, one for each class. The
+        weight is applied for each input where the ground truth label matches.
+    """
+    super().__init__(gamma, class_weight=class_weight)
+    self._num_classes = num_classes
+
+  def __call__(
+      self,
+      y_true: tf.Tensor,
+      y_pred: tf.Tensor,
+      sample_weight: Optional[tf.Tensor] = None,
+  ) -> tf.Tensor:
+    y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
+    y_true_one_hot = tf.one_hot(y_true, self._num_classes)
+    return super().__call__(y_true_one_hot, y_pred, sample_weight=sample_weight)
+
+
 @dataclasses.dataclass
 class PerceptualLossWeight:
   """The weight for each perceptual loss.
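As a usage sketch (not part of this commit), the new loss drops into a standard Keras compile step. The import path matches the BUILD dependency added below; the tiny model here is a hypothetical stand-in:

    import tensorflow as tf

    from mediapipe.model_maker.python.core.utils import loss_functions

    # Toy 3-class model; SparseFocalLoss takes integer class ids directly,
    # so the labels do not need to be one-hot encoded.
    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(3, activation="softmax", input_shape=(4,))]
    )
    model.compile(
        optimizer="adam",
        loss=loss_functions.SparseFocalLoss(gamma=2.0, num_classes=3),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )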
@@ -101,6 +101,23 @@ class FocalLossTest(tf.test.TestCase, parameterized.TestCase):
     self.assertNear(loss, expected_loss, 1e-4)
 
 
+class SparseFocalLossTest(tf.test.TestCase):
+
+  def test_sparse_focal_loss_matches_focal_loss(self):
+    num_classes = 2
+    y_pred = tf.constant([[0.8, 0.2], [0.3, 0.7]])
+    y_true = tf.constant([1, 0])
+    y_true_one_hot = tf.one_hot(y_true, num_classes)
+    for gamma in [0.0, 0.5, 1.0]:
+      expected_loss_fn = loss_functions.FocalLoss(gamma=gamma)
+      loss_fn = loss_functions.SparseFocalLoss(
+          gamma=gamma, num_classes=num_classes
+      )
+      expected_loss = expected_loss_fn(y_true_one_hot, y_pred)
+      loss = loss_fn(y_true, y_pred)
+      self.assertNear(loss, expected_loss, 1e-4)
+
+
 class MockPerceptualLoss(loss_functions.PerceptualLoss):
   """A mock class with implementation of abstract methods for testing."""
 
@@ -131,6 +131,7 @@ py_library(
         ":text_classifier_options",
         "//mediapipe/model_maker/python/core/data:dataset",
         "//mediapipe/model_maker/python/core/tasks:classifier",
+        "//mediapipe/model_maker/python/core/utils:loss_functions",
         "//mediapipe/model_maker/python/core/utils:metrics",
         "//mediapipe/model_maker/python/core/utils:model_util",
         "//mediapipe/model_maker/python/core/utils:quantization",
@@ -154,6 +155,7 @@ py_test(
     ],
     deps = [
         ":text_classifier_import",
+        "//mediapipe/model_maker/python/core/utils:loss_functions",
         "//mediapipe/tasks/python/test:test_utils",
     ],
 )
@@ -15,7 +15,7 @@
 
 import dataclasses
 import enum
-from typing import Union
+from typing import Sequence, Union
 
 from mediapipe.model_maker.python.core import hyperparameters as hp
 
@@ -39,16 +39,34 @@ class BertHParams(hp.BaseHParams):
 
   Attributes:
     learning_rate: Learning rate to use for gradient descent training.
-    batch_size: Batch size for training.
-    epochs: Number of training iterations over the dataset.
-    optimizer: Optimizer to use for training. Only supported values are "adamw"
-      and "lamb".
+    end_learning_rate: End learning rate for linear decay. Defaults to 0.
+    batch_size: Batch size for training. Defaults to 48.
+    epochs: Number of training iterations over the dataset. Defaults to 2.
+    optimizer: Optimizer to use for training. Supported values are defined in
+      the BertOptimizer enum: ADAMW and LAMB.
+    weight_decay: Weight decay of the optimizer. Defaults to 0.01.
+    desired_precisions: If specified, adds a RecallAtPrecision metric per
+      desired_precisions[i] entry which tracks the recall given the constraint
+      on precision. Only supported for binary classification.
+    desired_recalls: If specified, adds a PrecisionAtRecall metric per
+      desired_recalls[i] entry which tracks the precision given the constraint
+      on recall. Only supported for binary classification.
+    gamma: Gamma parameter for focal loss. To use cross entropy loss, set this
+      value to 0. Defaults to 2.0.
   """
 
   learning_rate: float = 3e-5
+  end_learning_rate: float = 0.0
+
   batch_size: int = 48
   epochs: int = 2
   optimizer: BertOptimizer = BertOptimizer.ADAMW
+  weight_decay: float = 0.01
+
+  desired_precisions: Sequence[float] = dataclasses.field(default_factory=list)
+  desired_recalls: Sequence[float] = dataclasses.field(default_factory=list)
+
+  gamma: float = 2.0
 
 
 HParams = Union[BertHParams, AverageWordEmbeddingHParams]
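A sketch of configuring these new hyperparameters from the public API (values are illustrative; `BertHParams` and `TextClassifierOptions` are re-exported by the text_classifier package, as the tests below show):

    from mediapipe.model_maker.python.text import text_classifier

    hparams = text_classifier.BertHParams(
        learning_rate=3e-5,
        epochs=2,
        batch_size=48,
        gamma=2.0,                 # focal loss; 0 falls back to cross entropy
        desired_precisions=[0.9],  # adds recall_at_precision_0.9 (binary only)
        desired_recalls=[0.2],     # adds precision_at_recall_0.2 (binary only)
    )
    options = text_classifier.TextClassifierOptions(
        supported_model=text_classifier.SupportedModels.EXBERT_CLASSIFIER,
        hparams=hparams,
    )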
@@ -79,11 +79,6 @@ mobilebert_classifier_spec = functools.partial(
         epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
     ),
     name='MobileBert',
-    tflite_input_name={
-        'ids': 'serving_default_input_1:0',
-        'segment_ids': 'serving_default_input_2:0',
-        'mask': 'serving_default_input_3:0',
-    },
 )
 
 exbert_classifier_spec = functools.partial(
@@ -93,11 +88,6 @@ exbert_classifier_spec = functools.partial(
         epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
     ),
     name='ExBert',
-    tflite_input_name={
-        'ids': 'serving_default_input_1:0',
-        'segment_ids': 'serving_default_input_2:0',
-        'mask': 'serving_default_input_3:0',
-    },
 )
 
 
@@ -46,11 +46,13 @@ class ModelSpecTest(tf.test.TestCase):
     self.assertTrue(os.path.exists(model_spec_obj.downloaded_files.get_path()))
     self.assertTrue(model_spec_obj.do_lower_case)
     self.assertEqual(
-        model_spec_obj.tflite_input_name, {
-            'ids': 'serving_default_input_1:0',
-            'mask': 'serving_default_input_3:0',
-            'segment_ids': 'serving_default_input_2:0'
-        })
+        model_spec_obj.tflite_input_name,
+        {
+            'ids': 'serving_default_input_word_ids:0',
+            'mask': 'serving_default_input_mask:0',
+            'segment_ids': 'serving_default_input_type_ids:0',
+        },
+    )
     self.assertEqual(
         model_spec_obj.model_options,
         classifier_model_options.BertModelOptions(
@@ -16,8 +16,8 @@
           }
         },
         {
-          "name": "mask",
-          "description": "Mask with 1 for real tokens and 0 for padding tokens.",
+          "name": "segment_ids",
+          "description": "0 for the first sequence, 1 for the second sequence if exists.",
           "content": {
             "content_properties_type": "FeatureProperties",
             "content_properties": {
@@ -27,8 +27,8 @@
           }
         },
         {
-          "name": "segment_ids",
-          "description": "0 for the first sequence, 1 for the second sequence if exists.",
+          "name": "mask",
+          "description": "Mask with 1 for real tokens and 0 for padding tokens.",
           "content": {
             "content_properties_type": "FeatureProperties",
             "content_properties": {
@@ -24,6 +24,7 @@ import tensorflow_hub as hub
 
 from mediapipe.model_maker.python.core.data import dataset as ds
 from mediapipe.model_maker.python.core.tasks import classifier
+from mediapipe.model_maker.python.core.utils import loss_functions
 from mediapipe.model_maker.python.core.utils import metrics
 from mediapipe.model_maker.python.core.utils import model_util
 from mediapipe.model_maker.python.core.utils import quantization
@@ -116,17 +117,14 @@ class TextClassifier(classifier.Classifier):
         options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER
         or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER
     ):
-      text_classifier = (
-          _BertClassifier.create_bert_classifier(train_data, validation_data,
-                                                 options,
-                                                 train_data.label_names))
+      text_classifier = _BertClassifier.create_bert_classifier(
+          train_data, validation_data, options
+      )
     elif (options.supported_model ==
           ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
-      text_classifier = (
-          _AverageWordEmbeddingClassifier
-          .create_average_word_embedding_classifier(train_data, validation_data,
-                                                    options,
-                                                    train_data.label_names))
+      text_classifier = _AverageWordEmbeddingClassifier.create_average_word_embedding_classifier(
+          train_data, validation_data, options
+      )
     else:
       raise ValueError(f"Unknown model {options.supported_model}")
 
@@ -166,28 +164,8 @@ class TextClassifier(classifier.Classifier):
     processed_data = self._text_preprocessor.preprocess(data)
     dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)
 
-    additional_metrics = []
-    if desired_precisions and len(data.label_names) == 2:
-      for precision in desired_precisions:
-        additional_metrics.append(
-            metrics.BinarySparseRecallAtPrecision(
-                precision, name=f"recall_at_precision_{precision}"
-            )
-        )
-    if desired_recalls and len(data.label_names) == 2:
-      for recall in desired_recalls:
-        additional_metrics.append(
-            metrics.BinarySparsePrecisionAtRecall(
-                recall, name=f"precision_at_recall_{recall}"
-            )
-        )
-    metric_functions = self._metric_functions + additional_metrics
-    self._model.compile(
-        optimizer=self._optimizer,
-        loss=self._loss_function,
-        metrics=metric_functions,
-    )
-    return self._model.evaluate(dataset)
+    with self._hparams.get_strategy().scope():
+      return self._model.evaluate(dataset)
 
   def export_model(
       self,
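With the metric targets moved into the hyperparameters, evaluate() no longer recompiles the model (the recompile previously left stale metric state) and simply runs Keras evaluation under the configured distribution strategy, so it lands on GPU/TPU when one is available. The call pattern becomes (hedged sketch; per the test comments below, `metrics[1]` is accuracy):

    # evaluate() now takes only the data; precision/recall targets come from
    # the BertHParams used to build the classifier.
    metrics = bert_classifier.evaluate(validation_data)
    accuracy = metrics[1]  # [loss, accuracy, precision, recall, ...]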
@@ -255,16 +233,17 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
 
   @classmethod
   def create_average_word_embedding_classifier(
-      cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
+      cls,
+      train_data: text_ds.Dataset,
+      validation_data: text_ds.Dataset,
       options: text_classifier_options.TextClassifierOptions,
-      label_names: Sequence[str]) -> "_AverageWordEmbeddingClassifier":
+  ) -> "_AverageWordEmbeddingClassifier":
     """Creates, trains, and returns an Average Word Embedding classifier.
 
     Args:
       train_data: Training data.
       validation_data: Validation data.
       options: Options for creating and training the text classifier.
-      label_names: Label names used in the data.
 
     Returns:
       An Average Word Embedding classifier.
@@ -370,28 +349,25 @@ class _BertClassifier(TextClassifier):
     self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
     self._model_options = model_options
     with self._hparams.get_strategy().scope():
-      self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
-      self._metric_functions = [
-          tf.keras.metrics.SparseCategoricalAccuracy(
-              "test_accuracy", dtype=tf.float32
-          ),
-          metrics.SparsePrecision(name="precision", dtype=tf.float32),
-          metrics.SparseRecall(name="recall", dtype=tf.float32),
-      ]
-    self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
+      self._loss_function = loss_functions.SparseFocalLoss(
+          self._hparams.gamma, self._num_classes
+      )
+      self._metric_functions = self._create_metrics()
+      self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
 
   @classmethod
   def create_bert_classifier(
-      cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
+      cls,
+      train_data: text_ds.Dataset,
+      validation_data: text_ds.Dataset,
       options: text_classifier_options.TextClassifierOptions,
-      label_names: Sequence[str]) -> "_BertClassifier":
+  ) -> "_BertClassifier":
     """Creates, trains, and returns a BERT-based classifier.
 
     Args:
       train_data: Training data.
       validation_data: Validation data.
       options: Options for creating and training the text classifier.
-      label_names: Label names used in the data.
 
     Returns:
       A BERT-based classifier.
@@ -437,8 +413,57 @@ class _BertClassifier(TextClassifier):
         uri=self._model_spec.downloaded_files.get_path(),
         model_name=self._model_spec.name,
     )
-    return (self._text_preprocessor.preprocess(train_data),
-            self._text_preprocessor.preprocess(validation_data))
+    return (
+        self._text_preprocessor.preprocess(train_data),
+        self._text_preprocessor.preprocess(validation_data),
+    )
+
+  def _create_metrics(self):
+    """Creates metrics for training and evaluation.
+
+    The default metrics are accuracy, precision, and recall.
+
+    For binary classification tasks only (num_classes=2):
+      Users can configure PrecisionAtRecall and RecallAtPrecision metrics using
+      the desired_precisions and desired_recalls fields in BertHParams.
+
+    Returns:
+      A list of tf.keras.Metric objects which can be used with model.compile.
+    """
+    metric_functions = [
+        tf.keras.metrics.SparseCategoricalAccuracy(
+            "accuracy", dtype=tf.float32
+        ),
+        metrics.SparsePrecision(name="precision", dtype=tf.float32),
+        metrics.SparseRecall(name="recall", dtype=tf.float32),
+    ]
+    if self._num_classes == 2:
+      if self._hparams.desired_precisions:
+        for desired_precision in self._hparams.desired_precisions:
+          metric_functions.append(
+              metrics.BinarySparseRecallAtPrecision(
+                  desired_precision,
+                  name=f"recall_at_precision_{desired_precision}",
+                  num_thresholds=1000,
+              )
+          )
+      if self._hparams.desired_recalls:
+        for desired_recall in self._hparams.desired_recalls:
+          metric_functions.append(
+              metrics.BinarySparsePrecisionAtRecall(
+                  desired_recall,
+                  name=f"precision_at_recall_{desired_recall}",
+                  num_thresholds=1000,
+              )
+          )
+    else:
+      if self._hparams.desired_precisions or self._hparams.desired_recalls:
+        raise ValueError(
+            "desired_recalls and desired_precisions parameters are binary"
+            " metrics and not supported for num_classes > 2. Found"
+            f" num_classes: {self._num_classes}"
+        )
+    return metric_functions
 
   def _create_model(self):
     """Creates a BERT-based classifier model.
@@ -448,11 +473,20 @@ class _BertClassifier(TextClassifier):
     """
     encoder_inputs = dict(
         input_word_ids=tf.keras.layers.Input(
-            shape=(self._model_options.seq_len,), dtype=tf.int32),
+            shape=(self._model_options.seq_len,),
+            dtype=tf.int32,
+            name="input_word_ids",
+        ),
         input_mask=tf.keras.layers.Input(
-            shape=(self._model_options.seq_len,), dtype=tf.int32),
+            shape=(self._model_options.seq_len,),
+            dtype=tf.int32,
+            name="input_mask",
+        ),
         input_type_ids=tf.keras.layers.Input(
-            shape=(self._model_options.seq_len,), dtype=tf.int32),
+            shape=(self._model_options.seq_len,),
+            dtype=tf.int32,
+            name="input_type_ids",
+        ),
     )
     encoder = hub.KerasLayer(
         self._model_spec.downloaded_files.get_path(),
@@ -494,16 +528,21 @@ class _BertClassifier(TextClassifier):
     lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
         initial_learning_rate=initial_lr,
         decay_steps=total_steps,
-        end_learning_rate=0.0,
-        power=1.0)
+        end_learning_rate=self._hparams.end_learning_rate,
+        power=1.0,
+    )
     if warmup_steps:
       lr_schedule = model_util.WarmUp(
           initial_learning_rate=initial_lr,
           decay_schedule_fn=lr_schedule,
-          warmup_steps=warmup_steps)
+          warmup_steps=warmup_steps,
+      )
     if self._hparams.optimizer == hp.BertOptimizer.ADAMW:
       self._optimizer = tf.keras.optimizers.experimental.AdamW(
-          lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0
+          lr_schedule,
+          weight_decay=self._hparams.weight_decay,
+          epsilon=1e-6,
+          global_clipnorm=1.0,
       )
       self._optimizer.exclude_from_weight_decay(
           var_names=["LayerNorm", "layer_norm", "bias"]
@@ -511,7 +550,7 @@ class _BertClassifier(TextClassifier):
     elif self._hparams.optimizer == hp.BertOptimizer.LAMB:
       self._optimizer = tfa_optimizers.LAMB(
           lr_schedule,
-          weight_decay_rate=0.01,
+          weight_decay_rate=self._hparams.weight_decay,
          epsilon=1e-6,
           exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
           global_clipnorm=1.0,
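The explicit `name=` arguments above pin the exported tensor names; without them Keras auto-assigns `input_1`, `input_2`, ... in layer-creation order, which is why model_spec previously hard-coded flaky `serving_default_input_1:0`-style names. A minimal standalone illustration (hypothetical sequence length):

    import tensorflow as tf

    # Named inputs yield stable serving-signature keys across exports.
    ids = tf.keras.layers.Input(shape=(128,), dtype=tf.int32, name="input_word_ids")
    mask = tf.keras.layers.Input(shape=(128,), dtype=tf.int32, name="input_mask")
    type_ids = tf.keras.layers.Input(shape=(128,), dtype=tf.int32, name="input_type_ids")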
@@ -16,17 +16,17 @@ import csv
 import filecmp
 import os
 import tempfile
-import unittest
 from unittest import mock as unittest_mock
 
+from absl.testing import parameterized
 import tensorflow as tf
 
+from mediapipe.model_maker.python.core.utils import loss_functions
 from mediapipe.model_maker.python.text import text_classifier
 from mediapipe.tasks.python.test import test_utils
 
 
-@unittest.skip('b/275624089')
-class TextClassifierTest(tf.test.TestCase):
+class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
 
   _AVERAGE_WORD_EMBEDDING_JSON_FILE = (
       test_utils.get_test_data_path('average_word_embedding_metadata.json'))
@@ -78,8 +78,8 @@ class TextClassifierTest(tf.test.TestCase):
         text_classifier.TextClassifier.create(train_data, validation_data,
                                               options))
 
-    _, accuracy = average_word_embedding_classifier.evaluate(validation_data)
-    self.assertGreaterEqual(accuracy, 0.0)
+    metrics = average_word_embedding_classifier.evaluate(validation_data)
+    self.assertGreaterEqual(metrics[1], 0.0)  # metrics[1] is accuracy
 
     # Test export_model
     average_word_embedding_classifier.export_model()
@@ -98,12 +98,25 @@ class TextClassifierTest(tf.test.TestCase):
         filecmp.cmp(
             output_metadata_file,
             self._AVERAGE_WORD_EMBEDDING_JSON_FILE,
-            shallow=False))
+            shallow=False,
+        )
+    )
 
-  def test_create_and_train_bert(self):
+  @parameterized.named_parameters(
+      # Skipping mobilebert b/c OSS test timeout/flakiness: b/275624089
+      # dict(
+      #     testcase_name='mobilebert',
+      #     supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
+      # ),
+      dict(
+          testcase_name='exbert',
+          supported_model=text_classifier.SupportedModels.EXBERT_CLASSIFIER,
+      ),
+  )
+  def test_create_and_train_bert(self, supported_model):
     train_data, validation_data = self._get_data()
     options = text_classifier.TextClassifierOptions(
-        supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
+        supported_model=supported_model,
         model_options=text_classifier.BertModelOptions(
             do_fine_tuning=False, seq_len=2
         ),
@@ -117,8 +130,8 @@ class TextClassifierTest(tf.test.TestCase):
     bert_classifier = text_classifier.TextClassifier.create(
         train_data, validation_data, options)
 
-    _, accuracy = bert_classifier.evaluate(validation_data)
-    self.assertGreaterEqual(accuracy, 0.0)
+    metrics = bert_classifier.evaluate(validation_data)
+    self.assertGreaterEqual(metrics[1], 0.0)  # metrics[1] is accuracy
 
     # Test export_model
     bert_classifier.export_model()
@@ -142,45 +155,93 @@ class TextClassifierTest(tf.test.TestCase):
     )
 
   def test_label_mismatch(self):
-    options = (
-        text_classifier.TextClassifierOptions(
-            supported_model=(
-                text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER)))
+    options = text_classifier.TextClassifierOptions(
+        supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER)
+    )
     train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
-    train_data = text_classifier.Dataset(train_tf_dataset, 1, ['foo'])
+    train_data = text_classifier.Dataset(train_tf_dataset, ['foo'], 1)
     validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
-    validation_data = text_classifier.Dataset(validation_tf_dataset, 1, ['bar'])
+    validation_data = text_classifier.Dataset(validation_tf_dataset, ['bar'], 1)
     with self.assertRaisesRegex(
         ValueError,
-        'Training data label names .* not equal to validation data label names'
+        'Training data label names .* not equal to validation data label names',
     ):
-      text_classifier.TextClassifier.create(train_data, validation_data,
-                                            options)
+      text_classifier.TextClassifier.create(
+          train_data, validation_data, options
+      )
 
   def test_options_mismatch(self):
     train_data, validation_data = self._get_data()
 
-    avg_options = (
-        text_classifier.TextClassifierOptions(
-            supported_model=(
-                text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER),
-            model_options=text_classifier.AverageWordEmbeddingModelOptions()))
-    with self.assertRaisesRegex(
-        ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
-        ' SupportedModels.MOBILEBERT_CLASSIFIER'):
-      text_classifier.TextClassifier.create(train_data, validation_data,
-                                            avg_options)
+    avg_options = text_classifier.TextClassifierOptions(
+        supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER),
+        model_options=text_classifier.AverageWordEmbeddingModelOptions(),
+    )
+    with self.assertRaisesWithLiteralMatch(
+        ValueError,
+        'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
+        ' SupportedModels.EXBERT_CLASSIFIER',
+    ):
+      text_classifier.TextClassifier.create(
+          train_data, validation_data, avg_options
+      )
 
-    bert_options = (
-        text_classifier.TextClassifierOptions(
-            supported_model=(text_classifier.SupportedModels
-                             .AVERAGE_WORD_EMBEDDING_CLASSIFIER),
-            model_options=text_classifier.BertModelOptions()))
-    with self.assertRaisesRegex(
-        ValueError, 'Expected MOBILEBERT_CLASSIFIER, got'
-        ' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):
-      text_classifier.TextClassifier.create(train_data, validation_data,
-                                            bert_options)
+    bert_options = text_classifier.TextClassifierOptions(
+        supported_model=(
+            text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
+        ),
+        model_options=text_classifier.BertModelOptions(),
+    )
+    with self.assertRaisesWithLiteralMatch(
+        ValueError,
+        'Expected a Bert Classifier(MobileBERT or EXBERT), got'
+        ' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER',
+    ):
+      text_classifier.TextClassifier.create(
+          train_data, validation_data, bert_options
+      )
+
+  def test_bert_loss_and_metrics_creation(self):
+    train_data, validation_data = self._get_data()
+    supported_model = text_classifier.SupportedModels.EXBERT_CLASSIFIER
+    hparams = text_classifier.BertHParams(
+        desired_recalls=[0.2],
+        desired_precisions=[0.9],
+        epochs=1,
+        batch_size=1,
+        learning_rate=3e-5,
+        distribution_strategy='off',
+        gamma=3.5,
+    )
+    options = text_classifier.TextClassifierOptions(
+        supported_model=supported_model, hparams=hparams
+    )
+    bert_classifier = text_classifier.TextClassifier.create(
+        train_data, validation_data, options
+    )
+    loss_fn = bert_classifier._loss_function
+    self.assertIsInstance(loss_fn, loss_functions.SparseFocalLoss)
+    self.assertEqual(loss_fn._gamma, 3.5)
+    self.assertEqual(loss_fn._num_classes, 2)
+    metric_names = [m.name for m in bert_classifier._metric_functions]
+    expected_metric_names = [
+        'accuracy',
+        'recall',
+        'precision',
+        'precision_at_recall_0.2',
+        'recall_at_precision_0.9',
+    ]
+    self.assertCountEqual(metric_names, expected_metric_names)
+
+    # Non-binary data
+    tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
+    data = text_classifier.Dataset(tf_dataset, ['foo', 'bar', 'baz'], 1)
+    with self.assertRaisesWithLiteralMatch(
+        ValueError,
+        'desired_recalls and desired_precisions parameters are binary metrics'
+        ' and not supported for num_classes > 2. Found num_classes: 3',
+    ):
+      text_classifier.TextClassifier.create(data, data, options)
 
 
 if __name__ == '__main__':