1. Move evaluation onto GPU/TPU hardware if available.
2. Move desired_precision and desired_recall from evaluate to hyperparameters
   so recall@precision metrics are reported for both training and evaluation.
   This also fixes a bug where recompiling the model with the previously
   initialized metric objects would not properly reset the metric states.
3. Remove redundant label_names from the create_... class methods in
   text_classifier. This information is already provided by the datasets.
4. Change the loss function to FocalLoss.
5. Re-enable the text_classifier unit tests using ExBert.
6. Add input names to avoid flaky auto-assigned input names.

PiperOrigin-RevId: 550992146
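Items 2 and 4 surface as new BertHParams fields. A minimal sketch (not part of the diff) of the intended usage, mirroring the updated tests; train_data and validation_data are assumed to be prepared text_classifier.Dataset objects:

    from mediapipe.model_maker.python.text import text_classifier

    hparams = text_classifier.BertHParams(
        epochs=1,
        batch_size=1,
        learning_rate=3e-5,
        distribution_strategy='off',
        gamma=3.5,                 # focal loss gamma; 0 falls back to cross entropy
        desired_precisions=[0.9],  # adds recall_at_precision_0.9 (binary tasks only)
        desired_recalls=[0.2],     # adds precision_at_recall_0.2 (binary tasks only)
    )
    options = text_classifier.TextClassifierOptions(
        supported_model=text_classifier.SupportedModels.EXBERT_CLASSIFIER,
        hparams=hparams,
    )
    bert_classifier = text_classifier.TextClassifier.create(
        train_data, validation_data, options
    )
    results = bert_classifier.evaluate(validation_data)  # loss first, then metrics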
This commit is contained in:
parent 85c3fed70a
commit bd7888cc0c
--- a/mediapipe/model_maker/python/core/utils/loss_functions.py
+++ b/mediapipe/model_maker/python/core/utils/loss_functions.py
@@ -59,7 +59,7 @@ class FocalLoss(tf.keras.losses.Loss):
   """

   def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
-    """Constructor.
+    """Initializes FocalLoss.

     Args:
       gamma: Focal loss gamma, as described in class docs.
@@ -115,6 +115,51 @@ class FocalLoss(tf.keras.losses.Loss):
     return tf.reduce_sum(losses) / batch_size


+class SparseFocalLoss(FocalLoss):
+  """Sparse implementation of Focal Loss.
+
+  This is the same as FocalLoss, except the labels are expected to be class
+  ids instead of one-hot encoded vectors. See the FocalLoss class
+  documentation in this same file for more details.
+
+  Example usage:
+  >>> y_true = [1, 2]
+  >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
+  >>> gamma = 2
+  >>> focal_loss = SparseFocalLoss(gamma, 3)
+  >>> focal_loss(y_true, y_pred).numpy()
+  0.9326
+
+  >>> # Calling with 'sample_weight'.
+  >>> focal_loss(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()
+  0.6528
+  """
+
+  def __init__(
+      self, gamma, num_classes, class_weight: Optional[Sequence[float]] = None
+  ):
+    """Initializes SparseFocalLoss.
+
+    Args:
+      gamma: Focal loss gamma, as described in class docs.
+      num_classes: Number of classes.
+      class_weight: A weight to apply to the loss, one for each class. The
+        weight is applied for each input where the ground truth label matches.
+    """
+    super().__init__(gamma, class_weight=class_weight)
+    self._num_classes = num_classes
+
+  def __call__(
+      self,
+      y_true: tf.Tensor,
+      y_pred: tf.Tensor,
+      sample_weight: Optional[tf.Tensor] = None,
+  ) -> tf.Tensor:
+    y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
+    y_true_one_hot = tf.one_hot(y_true, self._num_classes)
+    return super().__call__(y_true_one_hot, y_pred, sample_weight=sample_weight)
+
+
 @dataclasses.dataclass
 class PerceptualLossWeight:
   """The weight for each perceptual loss.
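As a sanity check on the SparseFocalLoss docstring example above: the value can be reproduced with plain NumPy, assuming the standard focal loss form FL = -(1 - p_t)^gamma * log(p_t) summed over examples and divided by the batch size (sketch, not part of the diff):

    import numpy as np

    y_true = np.array([1, 2])
    y_pred = np.array([[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]])
    gamma = 2

    p_t = y_pred[np.arange(len(y_true)), y_true]    # probability of the true class
    losses = -((1.0 - p_t) ** gamma) * np.log(p_t)  # modulated cross entropy
    print(losses.sum() / len(y_true))               # ~0.9326, matching the docstring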
--- a/mediapipe/model_maker/python/core/utils/loss_functions_test.py
+++ b/mediapipe/model_maker/python/core/utils/loss_functions_test.py
@@ -101,6 +101,23 @@ class FocalLossTest(tf.test.TestCase, parameterized.TestCase):
     self.assertNear(loss, expected_loss, 1e-4)


+class SparseFocalLossTest(tf.test.TestCase):
+
+  def test_sparse_focal_loss_matches_focal_loss(self):
+    num_classes = 2
+    y_pred = tf.constant([[0.8, 0.2], [0.3, 0.7]])
+    y_true = tf.constant([1, 0])
+    y_true_one_hot = tf.one_hot(y_true, num_classes)
+    for gamma in [0.0, 0.5, 1.0]:
+      expected_loss_fn = loss_functions.FocalLoss(gamma=gamma)
+      loss_fn = loss_functions.SparseFocalLoss(
+          gamma=gamma, num_classes=num_classes
+      )
+      expected_loss = expected_loss_fn(y_true_one_hot, y_pred)
+      loss = loss_fn(y_true, y_pred)
+      self.assertNear(loss, expected_loss, 1e-4)
+
+
 class MockPerceptualLoss(loss_functions.PerceptualLoss):
   """A mock class with implementation of abstract methods for testing."""

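The equivalence the test above exercises comes down to label format: SparseFocalLoss one-hot encodes the class ids and delegates to FocalLoss. A sketch of that conversion (not part of the diff):

    import tensorflow as tf

    y_true = tf.constant([1, 0])
    # [[0. 1.], [1. 0.]] -- the dense labels FocalLoss expects
    print(tf.one_hot(y_true, 2).numpy())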
--- a/mediapipe/model_maker/python/text/text_classifier/BUILD
+++ b/mediapipe/model_maker/python/text/text_classifier/BUILD
@@ -131,6 +131,7 @@ py_library(
         ":text_classifier_options",
         "//mediapipe/model_maker/python/core/data:dataset",
         "//mediapipe/model_maker/python/core/tasks:classifier",
+        "//mediapipe/model_maker/python/core/utils:loss_functions",
         "//mediapipe/model_maker/python/core/utils:metrics",
         "//mediapipe/model_maker/python/core/utils:model_util",
         "//mediapipe/model_maker/python/core/utils:quantization",
@@ -154,6 +155,7 @@ py_test(
     ],
     deps = [
         ":text_classifier_import",
+        "//mediapipe/model_maker/python/core/utils:loss_functions",
         "//mediapipe/tasks/python/test:test_utils",
     ],
 )
--- a/mediapipe/model_maker/python/text/text_classifier/hyperparameters.py
+++ b/mediapipe/model_maker/python/text/text_classifier/hyperparameters.py
@@ -15,7 +15,7 @@

 import dataclasses
 import enum
-from typing import Union
+from typing import Sequence, Union

 from mediapipe.model_maker.python.core import hyperparameters as hp

@@ -39,16 +39,34 @@ class BertHParams(hp.BaseHParams):

   Attributes:
     learning_rate: Learning rate to use for gradient descent training.
-    batch_size: Batch size for training.
-    epochs: Number of training iterations over the dataset.
-    optimizer: Optimizer to use for training. Only supported values are "adamw"
-      and "lamb".
+    end_learning_rate: End learning rate for linear decay. Defaults to 0.
+    batch_size: Batch size for training. Defaults to 48.
+    epochs: Number of training iterations over the dataset. Defaults to 2.
+    optimizer: Optimizer to use for training. Supported values are defined in
+      the BertOptimizer enum: ADAMW and LAMB.
+    weight_decay: Weight decay of the optimizer. Defaults to 0.01.
+    desired_precisions: If specified, adds a RecallAtPrecision metric per
+      desired_precisions[i] entry which tracks the recall given the constraint
+      on precision. Only supported for binary classification.
+    desired_recalls: If specified, adds a PrecisionAtRecall metric per
+      desired_recalls[i] entry which tracks the precision given the constraint
+      on recall. Only supported for binary classification.
+    gamma: Gamma parameter for focal loss. To use cross entropy loss, set this
+      value to 0. Defaults to 2.0.
   """

   learning_rate: float = 3e-5
+  end_learning_rate: float = 0.0

   batch_size: int = 48
   epochs: int = 2
   optimizer: BertOptimizer = BertOptimizer.ADAMW
+  weight_decay: float = 0.01
+
+  desired_precisions: Sequence[float] = dataclasses.field(default_factory=list)
+  desired_recalls: Sequence[float] = dataclasses.field(default_factory=list)
+
+  gamma: float = 2.0


 HParams = Union[BertHParams, AverageWordEmbeddingHParams]
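A quick check of the gamma docstring claim above: with gamma=0 the modulating factor (1 - p_t)^0 is 1, so the per-example focal loss reduces to plain cross entropy -log(p_t). Sketch, assuming no class_weight:

    import numpy as np

    p_t = np.array([0.95, 0.1])  # true-class probabilities
    focal_gamma0 = -((1.0 - p_t) ** 0) * np.log(p_t)
    cross_entropy = -np.log(p_t)
    assert np.allclose(focal_gamma0, cross_entropy)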
--- a/mediapipe/model_maker/python/text/text_classifier/model_spec.py
+++ b/mediapipe/model_maker/python/text/text_classifier/model_spec.py
@@ -79,11 +79,6 @@ mobilebert_classifier_spec = functools.partial(
         epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
     ),
     name='MobileBert',
-    tflite_input_name={
-        'ids': 'serving_default_input_1:0',
-        'segment_ids': 'serving_default_input_2:0',
-        'mask': 'serving_default_input_3:0',
-    },
 )

 exbert_classifier_spec = functools.partial(
@@ -93,11 +88,6 @@ exbert_classifier_spec = functools.partial(
         epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
     ),
     name='ExBert',
-    tflite_input_name={
-        'ids': 'serving_default_input_1:0',
-        'segment_ids': 'serving_default_input_2:0',
-        'mask': 'serving_default_input_3:0',
-    },
 )

--- a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py
+++ b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py
@@ -46,11 +46,13 @@ class ModelSpecTest(tf.test.TestCase):
     self.assertTrue(os.path.exists(model_spec_obj.downloaded_files.get_path()))
     self.assertTrue(model_spec_obj.do_lower_case)
     self.assertEqual(
-        model_spec_obj.tflite_input_name, {
-            'ids': 'serving_default_input_1:0',
-            'mask': 'serving_default_input_3:0',
-            'segment_ids': 'serving_default_input_2:0'
-        })
+        model_spec_obj.tflite_input_name,
+        {
+            'ids': 'serving_default_input_word_ids:0',
+            'mask': 'serving_default_input_mask:0',
+            'segment_ids': 'serving_default_input_type_ids:0',
+        },
+    )
     self.assertEqual(
         model_spec_obj.model_options,
         classifier_model_options.BertModelOptions(
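The stable tensor names asserted here come from the explicit name= arguments now passed to the Keras Input layers (see the _create_model change below); unnamed inputs get auto-assigned input_1/input_2/input_3 names, which is what made the old names flaky. A hypothetical standalone illustration:

    import tensorflow as tf

    seq_len = 4  # assumed; the real model uses model_options.seq_len
    inputs = dict(
        input_word_ids=tf.keras.layers.Input(
            shape=(seq_len,), dtype=tf.int32, name="input_word_ids"
        ),
        input_mask=tf.keras.layers.Input(
            shape=(seq_len,), dtype=tf.int32, name="input_mask"
        ),
        input_type_ids=tf.keras.layers.Input(
            shape=(seq_len,), dtype=tf.int32, name="input_type_ids"
        ),
    )
    for tensor in inputs.values():
        # Names derive from the explicit name= arguments, not creation order
        # (exact suffix such as ':0' may vary by TF version).
        print(tensor.name)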
--- a/mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json
+++ b/mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json
@@ -16,8 +16,8 @@
           }
         },
         {
-          "name": "mask",
-          "description": "Mask with 1 for real tokens and 0 for padding tokens.",
+          "name": "segment_ids",
+          "description": "0 for the first sequence, 1 for the second sequence if exists.",
           "content": {
             "content_properties_type": "FeatureProperties",
             "content_properties": {
@@ -27,8 +27,8 @@
           }
         },
         {
-          "name": "segment_ids",
-          "description": "0 for the first sequence, 1 for the second sequence if exists.",
+          "name": "mask",
+          "description": "Mask with 1 for real tokens and 0 for padding tokens.",
           "content": {
             "content_properties_type": "FeatureProperties",
             "content_properties": {
--- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py
+++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py
@@ -24,6 +24,7 @@ import tensorflow_hub as hub

 from mediapipe.model_maker.python.core.data import dataset as ds
 from mediapipe.model_maker.python.core.tasks import classifier
+from mediapipe.model_maker.python.core.utils import loss_functions
 from mediapipe.model_maker.python.core.utils import metrics
 from mediapipe.model_maker.python.core.utils import model_util
 from mediapipe.model_maker.python.core.utils import quantization
@@ -116,17 +117,14 @@ class TextClassifier(classifier.Classifier):
         options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER
         or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER
     ):
-      text_classifier = (
-          _BertClassifier.create_bert_classifier(train_data, validation_data,
-                                                 options,
-                                                 train_data.label_names))
+      text_classifier = _BertClassifier.create_bert_classifier(
+          train_data, validation_data, options
+      )
     elif (options.supported_model ==
           ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
-      text_classifier = (
-          _AverageWordEmbeddingClassifier
-          .create_average_word_embedding_classifier(train_data, validation_data,
-                                                    options,
-                                                    train_data.label_names))
+      text_classifier = _AverageWordEmbeddingClassifier.create_average_word_embedding_classifier(
+          train_data, validation_data, options
+      )
     else:
       raise ValueError(f"Unknown model {options.supported_model}")

@@ -166,27 +164,7 @@ class TextClassifier(classifier.Classifier):
     processed_data = self._text_preprocessor.preprocess(data)
     dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)

-    additional_metrics = []
-    if desired_precisions and len(data.label_names) == 2:
-      for precision in desired_precisions:
-        additional_metrics.append(
-            metrics.BinarySparseRecallAtPrecision(
-                precision, name=f"recall_at_precision_{precision}"
-            )
-        )
-    if desired_recalls and len(data.label_names) == 2:
-      for recall in desired_recalls:
-        additional_metrics.append(
-            metrics.BinarySparsePrecisionAtRecall(
-                recall, name=f"precision_at_recall_{recall}"
-            )
-        )
-    metric_functions = self._metric_functions + additional_metrics
-    self._model.compile(
-        optimizer=self._optimizer,
-        loss=self._loss_function,
-        metrics=metric_functions,
-    )
-    return self._model.evaluate(dataset)
+    with self._hparams.get_strategy().scope():
+      return self._model.evaluate(dataset)

   def export_model(
@@ -255,16 +233,17 @@ class _AverageWordEmbeddingClassifier(TextClassifier):

   @classmethod
   def create_average_word_embedding_classifier(
-      cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
+      cls,
+      train_data: text_ds.Dataset,
+      validation_data: text_ds.Dataset,
       options: text_classifier_options.TextClassifierOptions,
-      label_names: Sequence[str]) -> "_AverageWordEmbeddingClassifier":
+  ) -> "_AverageWordEmbeddingClassifier":
     """Creates, trains, and returns an Average Word Embedding classifier.

     Args:
       train_data: Training data.
       validation_data: Validation data.
       options: Options for creating and training the text classifier.
-      label_names: Label names used in the data.

     Returns:
       An Average Word Embedding classifier.
@@ -370,28 +349,25 @@ class _BertClassifier(TextClassifier):
     self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
     self._model_options = model_options
     with self._hparams.get_strategy().scope():
-      self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
-      self._metric_functions = [
-          tf.keras.metrics.SparseCategoricalAccuracy(
-              "test_accuracy", dtype=tf.float32
-          ),
-          metrics.SparsePrecision(name="precision", dtype=tf.float32),
-          metrics.SparseRecall(name="recall", dtype=tf.float32),
-      ]
+      self._loss_function = loss_functions.SparseFocalLoss(
+          self._hparams.gamma, self._num_classes
+      )
+      self._metric_functions = self._create_metrics()
       self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None

   @classmethod
   def create_bert_classifier(
-      cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
+      cls,
+      train_data: text_ds.Dataset,
+      validation_data: text_ds.Dataset,
       options: text_classifier_options.TextClassifierOptions,
-      label_names: Sequence[str]) -> "_BertClassifier":
+  ) -> "_BertClassifier":
     """Creates, trains, and returns a BERT-based classifier.

     Args:
       train_data: Training data.
       validation_data: Validation data.
       options: Options for creating and training the text classifier.
-      label_names: Label names used in the data.

     Returns:
       A BERT-based classifier.
@@ -437,8 +413,57 @@ class _BertClassifier(TextClassifier):
         uri=self._model_spec.downloaded_files.get_path(),
         model_name=self._model_spec.name,
     )
-    return (self._text_preprocessor.preprocess(train_data),
-            self._text_preprocessor.preprocess(validation_data))
+    return (
+        self._text_preprocessor.preprocess(train_data),
+        self._text_preprocessor.preprocess(validation_data),
+    )
+
+  def _create_metrics(self):
+    """Creates metrics for training and evaluation.
+
+    The default metrics are accuracy, precision, and recall.
+
+    For binary classification tasks only (num_classes=2):
+      Users can configure PrecisionAtRecall and RecallAtPrecision metrics using
+      the desired_precisions and desired_recalls fields in BertHParams.
+
+    Returns:
+      A list of tf.keras.Metric instances which can be used with model.compile.
+    """
+    metric_functions = [
+        tf.keras.metrics.SparseCategoricalAccuracy(
+            "accuracy", dtype=tf.float32
+        ),
+        metrics.SparsePrecision(name="precision", dtype=tf.float32),
+        metrics.SparseRecall(name="recall", dtype=tf.float32),
+    ]
+    if self._num_classes == 2:
+      if self._hparams.desired_precisions:
+        for desired_precision in self._hparams.desired_precisions:
+          metric_functions.append(
+              metrics.BinarySparseRecallAtPrecision(
+                  desired_precision,
+                  name=f"recall_at_precision_{desired_precision}",
+                  num_thresholds=1000,
+              )
+          )
+      if self._hparams.desired_recalls:
+        for desired_recall in self._hparams.desired_recalls:
+          metric_functions.append(
+              metrics.BinarySparsePrecisionAtRecall(
+                  desired_recall,
+                  name=f"precision_at_recall_{desired_recall}",
+                  num_thresholds=1000,
+              )
+          )
+    else:
+      if self._hparams.desired_precisions or self._hparams.desired_recalls:
+        raise ValueError(
+            "desired_recalls and desired_precisions parameters are binary"
+            " metrics and not supported for num_classes > 2. Found"
+            f" num_classes: {self._num_classes}"
+        )
+    return metric_functions

   def _create_model(self):
     """Creates a BERT-based classifier model.
@@ -448,11 +473,20 @@
     """
     encoder_inputs = dict(
         input_word_ids=tf.keras.layers.Input(
-            shape=(self._model_options.seq_len,), dtype=tf.int32),
+            shape=(self._model_options.seq_len,),
+            dtype=tf.int32,
+            name="input_word_ids",
+        ),
         input_mask=tf.keras.layers.Input(
-            shape=(self._model_options.seq_len,), dtype=tf.int32),
+            shape=(self._model_options.seq_len,),
+            dtype=tf.int32,
+            name="input_mask",
+        ),
         input_type_ids=tf.keras.layers.Input(
-            shape=(self._model_options.seq_len,), dtype=tf.int32),
+            shape=(self._model_options.seq_len,),
+            dtype=tf.int32,
+            name="input_type_ids",
+        ),
     )
     encoder = hub.KerasLayer(
         self._model_spec.downloaded_files.get_path(),
@@ -494,16 +528,21 @@
     lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
         initial_learning_rate=initial_lr,
         decay_steps=total_steps,
-        end_learning_rate=0.0,
-        power=1.0)
+        end_learning_rate=self._hparams.end_learning_rate,
+        power=1.0,
+    )
     if warmup_steps:
       lr_schedule = model_util.WarmUp(
           initial_learning_rate=initial_lr,
           decay_schedule_fn=lr_schedule,
-          warmup_steps=warmup_steps)
+          warmup_steps=warmup_steps,
+      )
     if self._hparams.optimizer == hp.BertOptimizer.ADAMW:
       self._optimizer = tf.keras.optimizers.experimental.AdamW(
-          lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0
+          lr_schedule,
+          weight_decay=self._hparams.weight_decay,
+          epsilon=1e-6,
+          global_clipnorm=1.0,
       )
       self._optimizer.exclude_from_weight_decay(
           var_names=["LayerNorm", "layer_norm", "bias"]
@@ -511,7 +550,7 @@
     elif self._hparams.optimizer == hp.BertOptimizer.LAMB:
       self._optimizer = tfa_optimizers.LAMB(
           lr_schedule,
-          weight_decay_rate=0.01,
+          weight_decay_rate=self._hparams.weight_decay,
           epsilon=1e-6,
           exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
           global_clipnorm=1.0,
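With the recompile removed, evaluate() runs the already-compiled model under the hparams distribution strategy and returns the flat Keras results list: loss first, then metrics in the order _create_metrics() built them. A sketch of reading it (index positions mirror the updated tests below; continuing from the sketch after the commit message):

    results = bert_classifier.evaluate(validation_data)
    loss, accuracy = results[0], results[1]  # results[1] is accuracy, per the tests
    print(f"loss={loss:.4f} accuracy={accuracy:.4f}")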
--- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py
+++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py
@@ -16,17 +16,17 @@ import csv
 import filecmp
 import os
 import tempfile
-import unittest
 from unittest import mock as unittest_mock

+from absl.testing import parameterized
 import tensorflow as tf

+from mediapipe.model_maker.python.core.utils import loss_functions
 from mediapipe.model_maker.python.text import text_classifier
 from mediapipe.tasks.python.test import test_utils


-@unittest.skip('b/275624089')
-class TextClassifierTest(tf.test.TestCase):
+class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):

   _AVERAGE_WORD_EMBEDDING_JSON_FILE = (
       test_utils.get_test_data_path('average_word_embedding_metadata.json'))
@@ -78,8 +78,8 @@ class TextClassifierTest(tf.test.TestCase):
         text_classifier.TextClassifier.create(train_data, validation_data,
                                               options))

-    _, accuracy = average_word_embedding_classifier.evaluate(validation_data)
-    self.assertGreaterEqual(accuracy, 0.0)
+    metrics = average_word_embedding_classifier.evaluate(validation_data)
+    self.assertGreaterEqual(metrics[1], 0.0)  # metrics[1] is accuracy

     # Test export_model
     average_word_embedding_classifier.export_model()
@@ -98,12 +98,25 @@ class TextClassifierTest(tf.test.TestCase):
         filecmp.cmp(
             output_metadata_file,
             self._AVERAGE_WORD_EMBEDDING_JSON_FILE,
-            shallow=False))
+            shallow=False,
+        )
+    )

-  def test_create_and_train_bert(self):
+  @parameterized.named_parameters(
+      # Skipping mobilebert b/c OSS test timeout/flakiness: b/275624089
+      # dict(
+      #     testcase_name='mobilebert',
+      #     supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
+      # ),
+      dict(
+          testcase_name='exbert',
+          supported_model=text_classifier.SupportedModels.EXBERT_CLASSIFIER,
+      ),
+  )
+  def test_create_and_train_bert(self, supported_model):
     train_data, validation_data = self._get_data()
     options = text_classifier.TextClassifierOptions(
-        supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
+        supported_model=supported_model,
         model_options=text_classifier.BertModelOptions(
             do_fine_tuning=False, seq_len=2
         ),
@@ -117,8 +130,8 @@ class TextClassifierTest(tf.test.TestCase):
     bert_classifier = text_classifier.TextClassifier.create(
         train_data, validation_data, options)

-    _, accuracy = bert_classifier.evaluate(validation_data)
-    self.assertGreaterEqual(accuracy, 0.0)
+    metrics = bert_classifier.evaluate(validation_data)
+    self.assertGreaterEqual(metrics[1], 0.0)  # metrics[1] is accuracy

     # Test export_model
     bert_classifier.export_model()
@@ -142,45 +155,93 @@ class TextClassifierTest(tf.test.TestCase):
     )

   def test_label_mismatch(self):
-    options = (
-        text_classifier.TextClassifierOptions(
-            supported_model=(
-                text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER)))
+    options = text_classifier.TextClassifierOptions(
+        supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER)
+    )
     train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
-    train_data = text_classifier.Dataset(train_tf_dataset, 1, ['foo'])
+    train_data = text_classifier.Dataset(train_tf_dataset, ['foo'], 1)
     validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
-    validation_data = text_classifier.Dataset(validation_tf_dataset, 1, ['bar'])
+    validation_data = text_classifier.Dataset(validation_tf_dataset, ['bar'], 1)
     with self.assertRaisesRegex(
         ValueError,
-        'Training data label names .* not equal to validation data label names'
+        'Training data label names .* not equal to validation data label names',
     ):
-      text_classifier.TextClassifier.create(train_data, validation_data,
-                                            options)
+      text_classifier.TextClassifier.create(
+          train_data, validation_data, options
+      )

   def test_options_mismatch(self):
     train_data, validation_data = self._get_data()

-    avg_options = (
-        text_classifier.TextClassifierOptions(
-            supported_model=(
-                text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER),
-            model_options=text_classifier.AverageWordEmbeddingModelOptions()))
-    with self.assertRaisesRegex(
-        ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
-        ' SupportedModels.MOBILEBERT_CLASSIFIER'):
-      text_classifier.TextClassifier.create(train_data, validation_data,
-                                            avg_options)
+    avg_options = text_classifier.TextClassifierOptions(
+        supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER),
+        model_options=text_classifier.AverageWordEmbeddingModelOptions(),
+    )
+    with self.assertRaisesWithLiteralMatch(
+        ValueError,
+        'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
+        ' SupportedModels.EXBERT_CLASSIFIER',
+    ):
+      text_classifier.TextClassifier.create(
+          train_data, validation_data, avg_options
+      )

-    bert_options = (
-        text_classifier.TextClassifierOptions(
-            supported_model=(text_classifier.SupportedModels
-                             .AVERAGE_WORD_EMBEDDING_CLASSIFIER),
-            model_options=text_classifier.BertModelOptions()))
-    with self.assertRaisesRegex(
-        ValueError, 'Expected MOBILEBERT_CLASSIFIER, got'
-        ' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):
-      text_classifier.TextClassifier.create(train_data, validation_data,
-                                            bert_options)
+    bert_options = text_classifier.TextClassifierOptions(
+        supported_model=(
+            text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
+        ),
+        model_options=text_classifier.BertModelOptions(),
+    )
+    with self.assertRaisesWithLiteralMatch(
+        ValueError,
+        'Expected a Bert Classifier(MobileBERT or EXBERT), got'
+        ' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER',
+    ):
+      text_classifier.TextClassifier.create(
+          train_data, validation_data, bert_options
+      )
+
+  def test_bert_loss_and_metrics_creation(self):
+    train_data, validation_data = self._get_data()
+    supported_model = text_classifier.SupportedModels.EXBERT_CLASSIFIER
+    hparams = text_classifier.BertHParams(
+        desired_recalls=[0.2],
+        desired_precisions=[0.9],
+        epochs=1,
+        batch_size=1,
+        learning_rate=3e-5,
+        distribution_strategy='off',
+        gamma=3.5,
+    )
+    options = text_classifier.TextClassifierOptions(
+        supported_model=supported_model, hparams=hparams
+    )
+    bert_classifier = text_classifier.TextClassifier.create(
+        train_data, validation_data, options
+    )
+    loss_fn = bert_classifier._loss_function
+    self.assertIsInstance(loss_fn, loss_functions.SparseFocalLoss)
+    self.assertEqual(loss_fn._gamma, 3.5)
+    self.assertEqual(loss_fn._num_classes, 2)
+    metric_names = [m.name for m in bert_classifier._metric_functions]
+    expected_metric_names = [
+        'accuracy',
+        'recall',
+        'precision',
+        'precision_at_recall_0.2',
+        'recall_at_precision_0.9',
+    ]
+    self.assertCountEqual(metric_names, expected_metric_names)
+
+    # Non-binary data
+    tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
+    data = text_classifier.Dataset(tf_dataset, ['foo', 'bar', 'baz'], 1)
+    with self.assertRaisesWithLiteralMatch(
+        ValueError,
+        'desired_recalls and desired_precisions parameters are binary metrics'
+        ' and not supported for num_classes > 2. Found num_classes: 3',
+    ):
+      text_classifier.TextClassifier.create(data, data, options)


 if __name__ == '__main__':