Support ExBert training and option to select between AdamW and LAMB optimizers for BertClassifier
PiperOrigin-RevId: 543905014
This commit is contained in:
		
							parent
							
								
									bed624f3b6
								
							
						
					
					
						commit
						1ee55d1f1b
					
				| 
						 | 
				
			
			@ -31,11 +31,11 @@ py_library(
 | 
			
		|||
    visibility = ["//visibility:public"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":dataset",
 | 
			
		||||
        ":hyperparameters",
 | 
			
		||||
        ":model_options",
 | 
			
		||||
        ":model_spec",
 | 
			
		||||
        ":text_classifier",
 | 
			
		||||
        ":text_classifier_options",
 | 
			
		||||
        "//mediapipe/model_maker/python/core:hyperparameters",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -45,12 +45,18 @@ py_library(
 | 
			
		|||
    deps = ["//mediapipe/model_maker/python/text/core:bert_model_options"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "hyperparameters",
 | 
			
		||||
    srcs = ["hyperparameters.py"],
 | 
			
		||||
    deps = ["//mediapipe/model_maker/python/core:hyperparameters"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "model_spec",
 | 
			
		||||
    srcs = ["model_spec.py"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":hyperparameters",
 | 
			
		||||
        ":model_options",
 | 
			
		||||
        "//mediapipe/model_maker/python/core:hyperparameters",
 | 
			
		||||
        "//mediapipe/model_maker/python/core/utils:file_util",
 | 
			
		||||
        "//mediapipe/model_maker/python/text/core:bert_model_spec",
 | 
			
		||||
    ],
 | 
			
		||||
| 
						 | 
				
			
			@ -61,9 +67,9 @@ py_test(
 | 
			
		|||
    srcs = ["model_spec_test.py"],
 | 
			
		||||
    tags = ["requires-net:external"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":hyperparameters",
 | 
			
		||||
        ":model_options",
 | 
			
		||||
        ":model_spec",
 | 
			
		||||
        "//mediapipe/model_maker/python/core:hyperparameters",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -100,9 +106,9 @@ py_library(
 | 
			
		|||
    name = "text_classifier_options",
 | 
			
		||||
    srcs = ["text_classifier_options.py"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":hyperparameters",
 | 
			
		||||
        ":model_options",
 | 
			
		||||
        ":model_spec",
 | 
			
		||||
        "//mediapipe/model_maker/python/core:hyperparameters",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -111,11 +117,11 @@ py_library(
 | 
			
		|||
    srcs = ["text_classifier.py"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":dataset",
 | 
			
		||||
        ":hyperparameters",
 | 
			
		||||
        ":model_options",
 | 
			
		||||
        ":model_spec",
 | 
			
		||||
        ":preprocessor",
 | 
			
		||||
        ":text_classifier_options",
 | 
			
		||||
        "//mediapipe/model_maker/python/core:hyperparameters",
 | 
			
		||||
        "//mediapipe/model_maker/python/core/data:dataset",
 | 
			
		||||
        "//mediapipe/model_maker/python/core/tasks:classifier",
 | 
			
		||||
        "//mediapipe/model_maker/python/core/utils:metrics",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,19 +13,23 @@
 | 
			
		|||
# limitations under the License.
 | 
			
		||||
"""MediaPipe Public Python API for Text Classifier."""
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.core import hyperparameters
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import dataset
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import hyperparameters
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import model_options
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import model_spec
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
 | 
			
		||||
 | 
			
		||||
HParams = hyperparameters.BaseHParams
 | 
			
		||||
 | 
			
		||||
AverageWordEmbeddingHParams = hyperparameters.AverageWordEmbeddingHParams
 | 
			
		||||
AverageWordEmbeddingModelOptions = (
 | 
			
		||||
    model_options.AverageWordEmbeddingModelOptions
 | 
			
		||||
)
 | 
			
		||||
BertOptimizer = hyperparameters.BertOptimizer
 | 
			
		||||
BertHParams = hyperparameters.BertHParams
 | 
			
		||||
BertModelOptions = model_options.BertModelOptions
 | 
			
		||||
CSVParams = dataset.CSVParameters
 | 
			
		||||
Dataset = dataset.Dataset
 | 
			
		||||
AverageWordEmbeddingModelOptions = (
 | 
			
		||||
    model_options.AverageWordEmbeddingModelOptions)
 | 
			
		||||
BertModelOptions = model_options.BertModelOptions
 | 
			
		||||
SupportedModels = model_spec.SupportedModels
 | 
			
		||||
TextClassifier = text_classifier.TextClassifier
 | 
			
		||||
TextClassifierOptions = text_classifier_options.TextClassifierOptions
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,54 @@
 | 
			
		|||
# Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Hyperparameters for training object detection models."""
 | 
			
		||||
 | 
			
		||||
import dataclasses
 | 
			
		||||
import enum
 | 
			
		||||
from typing import Union
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.core import hyperparameters as hp
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
class AverageWordEmbeddingHParams(hp.BaseHParams):
 | 
			
		||||
  """The hyperparameters for an AverageWordEmbeddingClassifier."""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@enum.unique
 | 
			
		||||
class BertOptimizer(enum.Enum):
 | 
			
		||||
  """Supported Optimizers for Bert Text Classifier."""
 | 
			
		||||
 | 
			
		||||
  ADAMW = "adamw"
 | 
			
		||||
  LAMB = "lamb"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
class BertHParams(hp.BaseHParams):
 | 
			
		||||
  """The hyperparameters for a Bert Classifier.
 | 
			
		||||
 | 
			
		||||
  Attributes:
 | 
			
		||||
    learning_rate: Learning rate to use for gradient descent training.
 | 
			
		||||
    batch_size: Batch size for training.
 | 
			
		||||
    epochs: Number of training iterations over the dataset.
 | 
			
		||||
    optimizer: Optimizer to use for training. Only supported values are "adamw"
 | 
			
		||||
      and "lamb".
 | 
			
		||||
  """
 | 
			
		||||
 | 
			
		||||
  learning_rate: float = 3e-5
 | 
			
		||||
  batch_size: int = 48
 | 
			
		||||
  epochs: int = 2
 | 
			
		||||
  optimizer: BertOptimizer = BertOptimizer.ADAMW
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
HParams = Union[BertHParams, AverageWordEmbeddingHParams]
 | 
			
		||||
| 
						 | 
				
			
			@ -17,13 +17,11 @@ import dataclasses
 | 
			
		|||
import enum
 | 
			
		||||
import functools
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.core import hyperparameters as hp
 | 
			
		||||
from mediapipe.model_maker.python.core.utils import file_util
 | 
			
		||||
from mediapipe.model_maker.python.text.core import bert_model_spec
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
 | 
			
		||||
 | 
			
		||||
# BERT-based text classifier spec inherited from BertModelSpec
 | 
			
		||||
BertClassifierSpec = bert_model_spec.BertModelSpec
 | 
			
		||||
 | 
			
		||||
MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
 | 
			
		||||
    'text_classifier/mobilebert_tiny',
 | 
			
		||||
| 
						 | 
				
			
			@ -31,6 +29,12 @@ MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
 | 
			
		|||
    is_folder=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
EXBERT_FILES = file_util.DownloadedFiles(
 | 
			
		||||
    'text_classifier/exbert',
 | 
			
		||||
    'https://storage.googleapis.com/mediapipe-assets/exbert.tar.gz',
 | 
			
		||||
    is_folder=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
class AverageWordEmbeddingClassifierSpec:
 | 
			
		||||
| 
						 | 
				
			
			@ -43,27 +47,53 @@ class AverageWordEmbeddingClassifierSpec:
 | 
			
		|||
  """
 | 
			
		||||
 | 
			
		||||
  # `learning_rate` is unused for the average word embedding model
 | 
			
		||||
  hparams: hp.BaseHParams = hp.BaseHParams(
 | 
			
		||||
      epochs=10, batch_size=32, learning_rate=0)
 | 
			
		||||
  hparams: hp.AverageWordEmbeddingHParams = hp.AverageWordEmbeddingHParams(
 | 
			
		||||
      epochs=10, batch_size=32, learning_rate=0
 | 
			
		||||
  )
 | 
			
		||||
  model_options: mo.AverageWordEmbeddingModelOptions = (
 | 
			
		||||
      mo.AverageWordEmbeddingModelOptions())
 | 
			
		||||
  name: str = 'AverageWordEmbedding'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
average_word_embedding_classifier_spec = functools.partial(
 | 
			
		||||
    AverageWordEmbeddingClassifierSpec)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
class BertClassifierSpec(bert_model_spec.BertModelSpec):
 | 
			
		||||
  """Specification for a Bert classifier model.
 | 
			
		||||
 | 
			
		||||
  Only overrides the hparams attribute since the rest of the attributes are
 | 
			
		||||
  inherited from the BertModelSpec.
 | 
			
		||||
  """
 | 
			
		||||
 | 
			
		||||
  hparams: hp.BertHParams = hp.BertHParams()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
mobilebert_classifier_spec = functools.partial(
 | 
			
		||||
    BertClassifierSpec,
 | 
			
		||||
    downloaded_files=MOBILEBERT_TINY_FILES,
 | 
			
		||||
    hparams=hp.BaseHParams(
 | 
			
		||||
    hparams=hp.BertHParams(
 | 
			
		||||
        epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
 | 
			
		||||
    ),
 | 
			
		||||
    name='MobileBert',
 | 
			
		||||
    tflite_input_name={
 | 
			
		||||
        'ids': 'serving_default_input_1:0',
 | 
			
		||||
        'mask': 'serving_default_input_3:0',
 | 
			
		||||
        'segment_ids': 'serving_default_input_2:0',
 | 
			
		||||
        'mask': 'serving_default_input_3:0',
 | 
			
		||||
    },
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
exbert_classifier_spec = functools.partial(
 | 
			
		||||
    BertClassifierSpec,
 | 
			
		||||
    downloaded_files=EXBERT_FILES,
 | 
			
		||||
    hparams=hp.BertHParams(
 | 
			
		||||
        epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
 | 
			
		||||
    ),
 | 
			
		||||
    name='ExBert',
 | 
			
		||||
    tflite_input_name={
 | 
			
		||||
        'ids': 'serving_default_input_1:0',
 | 
			
		||||
        'segment_ids': 'serving_default_input_2:0',
 | 
			
		||||
        'mask': 'serving_default_input_3:0',
 | 
			
		||||
    },
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -73,3 +103,4 @@ class SupportedModels(enum.Enum):
 | 
			
		|||
  """Predefined text classifier model specs supported by Model Maker."""
 | 
			
		||||
  AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec
 | 
			
		||||
  MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec
 | 
			
		||||
  EXBERT_CLASSIFIER = exbert_classifier_spec
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -19,7 +19,7 @@ from unittest import mock as unittest_mock
 | 
			
		|||
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.core import hyperparameters as hp
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import model_options as classifier_model_options
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -57,11 +57,13 @@ class ModelSpecTest(tf.test.TestCase):
 | 
			
		|||
            seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
 | 
			
		||||
    self.assertEqual(
 | 
			
		||||
        model_spec_obj.hparams,
 | 
			
		||||
        hp.BaseHParams(
 | 
			
		||||
        hp.BertHParams(
 | 
			
		||||
            epochs=3,
 | 
			
		||||
            batch_size=48,
 | 
			
		||||
            learning_rate=3e-5,
 | 
			
		||||
            distribution_strategy='off'))
 | 
			
		||||
            distribution_strategy='off',
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
  def test_predefined_average_word_embedding_spec(self):
 | 
			
		||||
    model_spec_obj = (
 | 
			
		||||
| 
						 | 
				
			
			@ -78,7 +80,7 @@ class ModelSpecTest(tf.test.TestCase):
 | 
			
		|||
            dropout_rate=0.2))
 | 
			
		||||
    self.assertEqual(
 | 
			
		||||
        model_spec_obj.hparams,
 | 
			
		||||
        hp.BaseHParams(
 | 
			
		||||
        hp.AverageWordEmbeddingHParams(
 | 
			
		||||
            epochs=10,
 | 
			
		||||
            batch_size=32,
 | 
			
		||||
            learning_rate=0,
 | 
			
		||||
| 
						 | 
				
			
			@ -101,7 +103,7 @@ class ModelSpecTest(tf.test.TestCase):
 | 
			
		|||
                     custom_bert_classifier_options)
 | 
			
		||||
 | 
			
		||||
  def test_custom_average_word_embedding_spec(self):
 | 
			
		||||
    custom_hparams = hp.BaseHParams(
 | 
			
		||||
    custom_hparams = hp.AverageWordEmbeddingHParams(
 | 
			
		||||
        learning_rate=0.4,
 | 
			
		||||
        batch_size=64,
 | 
			
		||||
        epochs=10,
 | 
			
		||||
| 
						 | 
				
			
			@ -110,7 +112,8 @@ class ModelSpecTest(tf.test.TestCase):
 | 
			
		|||
        export_dir='foo/bar',
 | 
			
		||||
        distribution_strategy='mirrored',
 | 
			
		||||
        num_gpus=3,
 | 
			
		||||
        tpu='tpu/address')
 | 
			
		||||
        tpu='tpu/address',
 | 
			
		||||
    )
 | 
			
		||||
    custom_average_word_embedding_model_options = (
 | 
			
		||||
        classifier_model_options.AverageWordEmbeddingModelOptions(
 | 
			
		||||
            seq_len=512,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -19,15 +19,16 @@ import tempfile
 | 
			
		|||
from typing import Any, Optional, Sequence, Tuple
 | 
			
		||||
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
from tensorflow_addons import optimizers as tfa_optimizers
 | 
			
		||||
import tensorflow_hub as hub
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.core import hyperparameters as hp
 | 
			
		||||
from mediapipe.model_maker.python.core.data import dataset as ds
 | 
			
		||||
from mediapipe.model_maker.python.core.tasks import classifier
 | 
			
		||||
from mediapipe.model_maker.python.core.utils import metrics
 | 
			
		||||
from mediapipe.model_maker.python.core.utils import model_util
 | 
			
		||||
from mediapipe.model_maker.python.core.utils import quantization
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import preprocessor
 | 
			
		||||
| 
						 | 
				
			
			@ -55,22 +56,26 @@ def _validate(options: text_classifier_options.TextClassifierOptions):
 | 
			
		|||
       ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
 | 
			
		||||
    raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
 | 
			
		||||
                     f" got {options.supported_model}")
 | 
			
		||||
  if (isinstance(options.model_options, mo.BertModelOptions) and
 | 
			
		||||
      (options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
 | 
			
		||||
  if isinstance(options.model_options, mo.BertModelOptions) and (
 | 
			
		||||
      options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER
 | 
			
		||||
      and options.supported_model != ms.SupportedModels.EXBERT_CLASSIFIER
 | 
			
		||||
  ):
 | 
			
		||||
    raise ValueError(
 | 
			
		||||
        f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
 | 
			
		||||
        "Expected a Bert Classifier(MobileBERT or EXBERT), got "
 | 
			
		||||
        f"{options.supported_model}"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TextClassifier(classifier.Classifier):
 | 
			
		||||
  """API for creating and training a text classification model."""
 | 
			
		||||
 | 
			
		||||
  def __init__(self, model_spec: Any, hparams: hp.BaseHParams,
 | 
			
		||||
               label_names: Sequence[str]):
 | 
			
		||||
  def __init__(
 | 
			
		||||
      self, model_spec: Any, label_names: Sequence[str], shuffle: bool
 | 
			
		||||
  ):
 | 
			
		||||
    super().__init__(
 | 
			
		||||
        model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
 | 
			
		||||
        model_spec=model_spec, label_names=label_names, shuffle=shuffle
 | 
			
		||||
    )
 | 
			
		||||
    self._model_spec = model_spec
 | 
			
		||||
    self._hparams = hparams
 | 
			
		||||
    self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
 | 
			
		||||
    self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None
 | 
			
		||||
 | 
			
		||||
  @classmethod
 | 
			
		||||
| 
						 | 
				
			
			@ -107,7 +112,10 @@ class TextClassifier(classifier.Classifier):
 | 
			
		|||
    if options.hparams is None:
 | 
			
		||||
      options.hparams = options.supported_model.value().hparams
 | 
			
		||||
 | 
			
		||||
    if options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER:
 | 
			
		||||
    if (
 | 
			
		||||
        options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER
 | 
			
		||||
        or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER
 | 
			
		||||
    ):
 | 
			
		||||
      text_classifier = (
 | 
			
		||||
          _BertClassifier.create_bert_classifier(train_data, validation_data,
 | 
			
		||||
                                                 options,
 | 
			
		||||
| 
						 | 
				
			
			@ -225,11 +233,17 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
 | 
			
		|||
 | 
			
		||||
  _DELIM_REGEX_PATTERN = r"[^\w\']+"
 | 
			
		||||
 | 
			
		||||
  def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
 | 
			
		||||
  def __init__(
 | 
			
		||||
      self,
 | 
			
		||||
      model_spec: ms.AverageWordEmbeddingClassifierSpec,
 | 
			
		||||
      model_options: mo.AverageWordEmbeddingModelOptions,
 | 
			
		||||
               hparams: hp.BaseHParams, label_names: Sequence[str]):
 | 
			
		||||
    super().__init__(model_spec, hparams, label_names)
 | 
			
		||||
      hparams: hp.AverageWordEmbeddingHParams,
 | 
			
		||||
      label_names: Sequence[str],
 | 
			
		||||
  ):
 | 
			
		||||
    super().__init__(model_spec, label_names, hparams.shuffle)
 | 
			
		||||
    self._model_options = model_options
 | 
			
		||||
    self._hparams = hparams
 | 
			
		||||
    self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
 | 
			
		||||
    self._loss_function = "sparse_categorical_crossentropy"
 | 
			
		||||
    self._metric_functions = [
 | 
			
		||||
        "accuracy",
 | 
			
		||||
| 
						 | 
				
			
			@ -344,10 +358,16 @@ class _BertClassifier(TextClassifier):
 | 
			
		|||
 | 
			
		||||
  _INITIALIZER_RANGE = 0.02
 | 
			
		||||
 | 
			
		||||
  def __init__(self, model_spec: ms.BertClassifierSpec,
 | 
			
		||||
               model_options: mo.BertModelOptions, hparams: hp.BaseHParams,
 | 
			
		||||
               label_names: Sequence[str]):
 | 
			
		||||
    super().__init__(model_spec, hparams, label_names)
 | 
			
		||||
  def __init__(
 | 
			
		||||
      self,
 | 
			
		||||
      model_spec: ms.BertClassifierSpec,
 | 
			
		||||
      model_options: mo.BertModelOptions,
 | 
			
		||||
      hparams: hp.BertHParams,
 | 
			
		||||
      label_names: Sequence[str],
 | 
			
		||||
  ):
 | 
			
		||||
    super().__init__(model_spec, label_names, hparams.shuffle)
 | 
			
		||||
    self._hparams = hparams
 | 
			
		||||
    self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
 | 
			
		||||
    self._model_options = model_options
 | 
			
		||||
    with self._hparams.get_strategy().scope():
 | 
			
		||||
      self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
 | 
			
		||||
| 
						 | 
				
			
			@ -480,11 +500,26 @@ class _BertClassifier(TextClassifier):
 | 
			
		|||
          initial_learning_rate=initial_lr,
 | 
			
		||||
          decay_schedule_fn=lr_schedule,
 | 
			
		||||
          warmup_steps=warmup_steps)
 | 
			
		||||
 | 
			
		||||
    if self._hparams.optimizer == hp.BertOptimizer.ADAMW:
 | 
			
		||||
      self._optimizer = tf.keras.optimizers.experimental.AdamW(
 | 
			
		||||
        lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0)
 | 
			
		||||
          lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0
 | 
			
		||||
      )
 | 
			
		||||
      self._optimizer.exclude_from_weight_decay(
 | 
			
		||||
        var_names=["LayerNorm", "layer_norm", "bias"])
 | 
			
		||||
          var_names=["LayerNorm", "layer_norm", "bias"]
 | 
			
		||||
      )
 | 
			
		||||
    elif self._hparams.optimizer == hp.BertOptimizer.LAMB:
 | 
			
		||||
      self._optimizer = tfa_optimizers.LAMB(
 | 
			
		||||
          lr_schedule,
 | 
			
		||||
          weight_decay_rate=0.01,
 | 
			
		||||
          epsilon=1e-6,
 | 
			
		||||
          exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
 | 
			
		||||
          global_clipnorm=1.0,
 | 
			
		||||
      )
 | 
			
		||||
    else:
 | 
			
		||||
      raise ValueError(
 | 
			
		||||
          "BertHParams.optimizer must be set to ADAM or "
 | 
			
		||||
          f"LAMB. Got {self._hparams.optimizer}."
 | 
			
		||||
      )
 | 
			
		||||
 | 
			
		||||
  def _save_vocab(self, vocab_filepath: str):
 | 
			
		||||
    tf.io.gfile.copy(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -66,14 +66,16 @@ def run(data_dir,
 | 
			
		|||
  quantization_config = None
 | 
			
		||||
  if (supported_model ==
 | 
			
		||||
      text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
 | 
			
		||||
    hparams = text_classifier.HParams(
 | 
			
		||||
        epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir)
 | 
			
		||||
    hparams = text_classifier.AverageWordEmbeddingHParams(
 | 
			
		||||
        epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir
 | 
			
		||||
    )
 | 
			
		||||
  # Warning: This takes extremely long to run on CPU
 | 
			
		||||
  elif (
 | 
			
		||||
      supported_model == text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER):
 | 
			
		||||
    quantization_config = quantization.QuantizationConfig.for_dynamic()
 | 
			
		||||
    hparams = text_classifier.HParams(
 | 
			
		||||
        epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir)
 | 
			
		||||
    hparams = text_classifier.BertHParams(
 | 
			
		||||
        epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
  # Fine-tunes the model.
 | 
			
		||||
  options = text_classifier.TextClassifierOptions(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -16,7 +16,7 @@
 | 
			
		|||
import dataclasses
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.core import hyperparameters as hp
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -34,5 +34,5 @@ class TextClassifierOptions:
 | 
			
		|||
      architecture of the `supported_model`.
 | 
			
		||||
  """
 | 
			
		||||
  supported_model: ms.SupportedModels
 | 
			
		||||
  hparams: Optional[hp.BaseHParams] = None
 | 
			
		||||
  hparams: Optional[hp.HParams] = None
 | 
			
		||||
  model_options: Optional[mo.TextClassifierModelOptions] = None
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -66,12 +66,14 @@ class TextClassifierTest(tf.test.TestCase):
 | 
			
		|||
 | 
			
		||||
  def test_create_and_train_average_word_embedding_model(self):
 | 
			
		||||
    train_data, validation_data = self._get_data()
 | 
			
		||||
    options = (
 | 
			
		||||
        text_classifier.TextClassifierOptions(
 | 
			
		||||
            supported_model=(text_classifier.SupportedModels
 | 
			
		||||
                             .AVERAGE_WORD_EMBEDDING_CLASSIFIER),
 | 
			
		||||
            hparams=text_classifier.HParams(
 | 
			
		||||
                epochs=1, batch_size=1, learning_rate=0)))
 | 
			
		||||
    options = text_classifier.TextClassifierOptions(
 | 
			
		||||
        supported_model=(
 | 
			
		||||
            text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
 | 
			
		||||
        ),
 | 
			
		||||
        hparams=text_classifier.AverageWordEmbeddingHParams(
 | 
			
		||||
            epochs=1, batch_size=1, learning_rate=0
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
    average_word_embedding_classifier = (
 | 
			
		||||
        text_classifier.TextClassifier.create(train_data, validation_data,
 | 
			
		||||
                                              options))
 | 
			
		||||
| 
						 | 
				
			
			@ -103,12 +105,15 @@ class TextClassifierTest(tf.test.TestCase):
 | 
			
		|||
    options = text_classifier.TextClassifierOptions(
 | 
			
		||||
        supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
 | 
			
		||||
        model_options=text_classifier.BertModelOptions(
 | 
			
		||||
            do_fine_tuning=False, seq_len=2),
 | 
			
		||||
        hparams=text_classifier.HParams(
 | 
			
		||||
            do_fine_tuning=False, seq_len=2
 | 
			
		||||
        ),
 | 
			
		||||
        hparams=text_classifier.BertHParams(
 | 
			
		||||
            epochs=1,
 | 
			
		||||
            batch_size=1,
 | 
			
		||||
            learning_rate=3e-5,
 | 
			
		||||
            distribution_strategy='off'))
 | 
			
		||||
            distribution_strategy='off',
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
    bert_classifier = text_classifier.TextClassifier.create(
 | 
			
		||||
        train_data, validation_data, options)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user