Support ExBert training and option to select between AdamW and LAMB optimizers for BertClassifier

PiperOrigin-RevId: 543905014
This commit is contained in:
MediaPipe Team 2023-06-27 18:02:59 -07:00 committed by Copybara-Service
parent bed624f3b6
commit 1ee55d1f1b
9 changed files with 202 additions and 62 deletions

View File

@ -31,11 +31,11 @@ py_library(
visibility = ["//visibility:public"],
deps = [
":dataset",
":hyperparameters",
":model_options",
":model_spec",
":text_classifier",
":text_classifier_options",
"//mediapipe/model_maker/python/core:hyperparameters",
],
)
@ -45,12 +45,18 @@ py_library(
deps = ["//mediapipe/model_maker/python/text/core:bert_model_options"],
)
py_library(
name = "hyperparameters",
srcs = ["hyperparameters.py"],
deps = ["//mediapipe/model_maker/python/core:hyperparameters"],
)
py_library(
name = "model_spec",
srcs = ["model_spec.py"],
deps = [
":hyperparameters",
":model_options",
"//mediapipe/model_maker/python/core:hyperparameters",
"//mediapipe/model_maker/python/core/utils:file_util",
"//mediapipe/model_maker/python/text/core:bert_model_spec",
],
@ -61,9 +67,9 @@ py_test(
srcs = ["model_spec_test.py"],
tags = ["requires-net:external"],
deps = [
":hyperparameters",
":model_options",
":model_spec",
"//mediapipe/model_maker/python/core:hyperparameters",
],
)
@ -100,9 +106,9 @@ py_library(
name = "text_classifier_options",
srcs = ["text_classifier_options.py"],
deps = [
":hyperparameters",
":model_options",
":model_spec",
"//mediapipe/model_maker/python/core:hyperparameters",
],
)
@ -111,11 +117,11 @@ py_library(
srcs = ["text_classifier.py"],
deps = [
":dataset",
":hyperparameters",
":model_options",
":model_spec",
":preprocessor",
":text_classifier_options",
"//mediapipe/model_maker/python/core:hyperparameters",
"//mediapipe/model_maker/python/core/data:dataset",
"//mediapipe/model_maker/python/core/tasks:classifier",
"//mediapipe/model_maker/python/core/utils:metrics",

View File

@ -13,19 +13,23 @@
# limitations under the License.
"""MediaPipe Public Python API for Text Classifier."""
from mediapipe.model_maker.python.core import hyperparameters
from mediapipe.model_maker.python.text.text_classifier import dataset
from mediapipe.model_maker.python.text.text_classifier import hyperparameters
from mediapipe.model_maker.python.text.text_classifier import model_options
from mediapipe.model_maker.python.text.text_classifier import model_spec
from mediapipe.model_maker.python.text.text_classifier import text_classifier
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
HParams = hyperparameters.BaseHParams
AverageWordEmbeddingHParams = hyperparameters.AverageWordEmbeddingHParams
AverageWordEmbeddingModelOptions = (
model_options.AverageWordEmbeddingModelOptions
)
BertOptimizer = hyperparameters.BertOptimizer
BertHParams = hyperparameters.BertHParams
BertModelOptions = model_options.BertModelOptions
CSVParams = dataset.CSVParameters
Dataset = dataset.Dataset
AverageWordEmbeddingModelOptions = (
model_options.AverageWordEmbeddingModelOptions)
BertModelOptions = model_options.BertModelOptions
SupportedModels = model_spec.SupportedModels
TextClassifier = text_classifier.TextClassifier
TextClassifierOptions = text_classifier_options.TextClassifierOptions

View File

@ -0,0 +1,54 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hyperparameters for training object detection models."""
import dataclasses
import enum
from typing import Union
from mediapipe.model_maker.python.core import hyperparameters as hp
@dataclasses.dataclass
class AverageWordEmbeddingHParams(hp.BaseHParams):
"""The hyperparameters for an AverageWordEmbeddingClassifier."""
@enum.unique
class BertOptimizer(enum.Enum):
"""Supported Optimizers for Bert Text Classifier."""
ADAMW = "adamw"
LAMB = "lamb"
@dataclasses.dataclass
class BertHParams(hp.BaseHParams):
"""The hyperparameters for a Bert Classifier.
Attributes:
learning_rate: Learning rate to use for gradient descent training.
batch_size: Batch size for training.
epochs: Number of training iterations over the dataset.
optimizer: Optimizer to use for training. Only supported values are "adamw"
and "lamb".
"""
learning_rate: float = 3e-5
batch_size: int = 48
epochs: int = 2
optimizer: BertOptimizer = BertOptimizer.ADAMW
HParams = Union[BertHParams, AverageWordEmbeddingHParams]

View File

@ -17,13 +17,11 @@ import dataclasses
import enum
import functools
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.core.utils import file_util
from mediapipe.model_maker.python.text.core import bert_model_spec
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
# BERT-based text classifier spec inherited from BertModelSpec
BertClassifierSpec = bert_model_spec.BertModelSpec
MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
'text_classifier/mobilebert_tiny',
@ -31,6 +29,12 @@ MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
is_folder=True,
)
EXBERT_FILES = file_util.DownloadedFiles(
'text_classifier/exbert',
'https://storage.googleapis.com/mediapipe-assets/exbert.tar.gz',
is_folder=True,
)
@dataclasses.dataclass
class AverageWordEmbeddingClassifierSpec:
@ -43,27 +47,53 @@ class AverageWordEmbeddingClassifierSpec:
"""
# `learning_rate` is unused for the average word embedding model
hparams: hp.BaseHParams = hp.BaseHParams(
epochs=10, batch_size=32, learning_rate=0)
hparams: hp.AverageWordEmbeddingHParams = hp.AverageWordEmbeddingHParams(
epochs=10, batch_size=32, learning_rate=0
)
model_options: mo.AverageWordEmbeddingModelOptions = (
mo.AverageWordEmbeddingModelOptions())
name: str = 'AverageWordEmbedding'
average_word_embedding_classifier_spec = functools.partial(
AverageWordEmbeddingClassifierSpec)
@dataclasses.dataclass
class BertClassifierSpec(bert_model_spec.BertModelSpec):
"""Specification for a Bert classifier model.
Only overrides the hparams attribute since the rest of the attributes are
inherited from the BertModelSpec.
"""
hparams: hp.BertHParams = hp.BertHParams()
mobilebert_classifier_spec = functools.partial(
BertClassifierSpec,
downloaded_files=MOBILEBERT_TINY_FILES,
hparams=hp.BaseHParams(
hparams=hp.BertHParams(
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
),
name='MobileBert',
tflite_input_name={
'ids': 'serving_default_input_1:0',
'mask': 'serving_default_input_3:0',
'segment_ids': 'serving_default_input_2:0',
'mask': 'serving_default_input_3:0',
},
)
exbert_classifier_spec = functools.partial(
BertClassifierSpec,
downloaded_files=EXBERT_FILES,
hparams=hp.BertHParams(
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
),
name='ExBert',
tflite_input_name={
'ids': 'serving_default_input_1:0',
'segment_ids': 'serving_default_input_2:0',
'mask': 'serving_default_input_3:0',
},
)
@ -73,3 +103,4 @@ class SupportedModels(enum.Enum):
"""Predefined text classifier model specs supported by Model Maker."""
AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec
MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec
EXBERT_CLASSIFIER = exbert_classifier_spec

View File

@ -19,7 +19,7 @@ from unittest import mock as unittest_mock
import tensorflow as tf
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import model_options as classifier_model_options
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
@ -57,11 +57,13 @@ class ModelSpecTest(tf.test.TestCase):
seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
self.assertEqual(
model_spec_obj.hparams,
hp.BaseHParams(
hp.BertHParams(
epochs=3,
batch_size=48,
learning_rate=3e-5,
distribution_strategy='off'))
distribution_strategy='off',
),
)
def test_predefined_average_word_embedding_spec(self):
model_spec_obj = (
@ -78,7 +80,7 @@ class ModelSpecTest(tf.test.TestCase):
dropout_rate=0.2))
self.assertEqual(
model_spec_obj.hparams,
hp.BaseHParams(
hp.AverageWordEmbeddingHParams(
epochs=10,
batch_size=32,
learning_rate=0,
@ -101,7 +103,7 @@ class ModelSpecTest(tf.test.TestCase):
custom_bert_classifier_options)
def test_custom_average_word_embedding_spec(self):
custom_hparams = hp.BaseHParams(
custom_hparams = hp.AverageWordEmbeddingHParams(
learning_rate=0.4,
batch_size=64,
epochs=10,
@ -110,7 +112,8 @@ class ModelSpecTest(tf.test.TestCase):
export_dir='foo/bar',
distribution_strategy='mirrored',
num_gpus=3,
tpu='tpu/address')
tpu='tpu/address',
)
custom_average_word_embedding_model_options = (
classifier_model_options.AverageWordEmbeddingModelOptions(
seq_len=512,

View File

@ -19,15 +19,16 @@ import tempfile
from typing import Any, Optional, Sequence, Tuple
import tensorflow as tf
from tensorflow_addons import optimizers as tfa_optimizers
import tensorflow_hub as hub
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.core.data import dataset as ds
from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import metrics
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
from mediapipe.model_maker.python.text.text_classifier import preprocessor
@ -55,22 +56,26 @@ def _validate(options: text_classifier_options.TextClassifierOptions):
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
f" got {options.supported_model}")
if (isinstance(options.model_options, mo.BertModelOptions) and
(options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
if isinstance(options.model_options, mo.BertModelOptions) and (
options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER
and options.supported_model != ms.SupportedModels.EXBERT_CLASSIFIER
):
raise ValueError(
f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
"Expected a Bert Classifier(MobileBERT or EXBERT), got "
f"{options.supported_model}"
)
class TextClassifier(classifier.Classifier):
"""API for creating and training a text classification model."""
def __init__(self, model_spec: Any, hparams: hp.BaseHParams,
label_names: Sequence[str]):
def __init__(
self, model_spec: Any, label_names: Sequence[str], shuffle: bool
):
super().__init__(
model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
model_spec=model_spec, label_names=label_names, shuffle=shuffle
)
self._model_spec = model_spec
self._hparams = hparams
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None
@classmethod
@ -107,7 +112,10 @@ class TextClassifier(classifier.Classifier):
if options.hparams is None:
options.hparams = options.supported_model.value().hparams
if options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER:
if (
options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER
or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER
):
text_classifier = (
_BertClassifier.create_bert_classifier(train_data, validation_data,
options,
@ -225,11 +233,17 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
_DELIM_REGEX_PATTERN = r"[^\w\']+"
def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
def __init__(
self,
model_spec: ms.AverageWordEmbeddingClassifierSpec,
model_options: mo.AverageWordEmbeddingModelOptions,
hparams: hp.BaseHParams, label_names: Sequence[str]):
super().__init__(model_spec, hparams, label_names)
hparams: hp.AverageWordEmbeddingHParams,
label_names: Sequence[str],
):
super().__init__(model_spec, label_names, hparams.shuffle)
self._model_options = model_options
self._hparams = hparams
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
self._loss_function = "sparse_categorical_crossentropy"
self._metric_functions = [
"accuracy",
@ -344,10 +358,16 @@ class _BertClassifier(TextClassifier):
_INITIALIZER_RANGE = 0.02
def __init__(self, model_spec: ms.BertClassifierSpec,
model_options: mo.BertModelOptions, hparams: hp.BaseHParams,
label_names: Sequence[str]):
super().__init__(model_spec, hparams, label_names)
def __init__(
self,
model_spec: ms.BertClassifierSpec,
model_options: mo.BertModelOptions,
hparams: hp.BertHParams,
label_names: Sequence[str],
):
super().__init__(model_spec, label_names, hparams.shuffle)
self._hparams = hparams
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
self._model_options = model_options
with self._hparams.get_strategy().scope():
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
@ -480,11 +500,26 @@ class _BertClassifier(TextClassifier):
initial_learning_rate=initial_lr,
decay_schedule_fn=lr_schedule,
warmup_steps=warmup_steps)
if self._hparams.optimizer == hp.BertOptimizer.ADAMW:
self._optimizer = tf.keras.optimizers.experimental.AdamW(
lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0)
lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0
)
self._optimizer.exclude_from_weight_decay(
var_names=["LayerNorm", "layer_norm", "bias"])
var_names=["LayerNorm", "layer_norm", "bias"]
)
elif self._hparams.optimizer == hp.BertOptimizer.LAMB:
self._optimizer = tfa_optimizers.LAMB(
lr_schedule,
weight_decay_rate=0.01,
epsilon=1e-6,
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
global_clipnorm=1.0,
)
else:
raise ValueError(
"BertHParams.optimizer must be set to ADAM or "
f"LAMB. Got {self._hparams.optimizer}."
)
def _save_vocab(self, vocab_filepath: str):
tf.io.gfile.copy(

View File

@ -66,14 +66,16 @@ def run(data_dir,
quantization_config = None
if (supported_model ==
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
hparams = text_classifier.HParams(
epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir)
hparams = text_classifier.AverageWordEmbeddingHParams(
epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir
)
# Warning: This takes extremely long to run on CPU
elif (
supported_model == text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER):
quantization_config = quantization.QuantizationConfig.for_dynamic()
hparams = text_classifier.HParams(
epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir)
hparams = text_classifier.BertHParams(
epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir
)
# Fine-tunes the model.
options = text_classifier.TextClassifierOptions(

View File

@ -16,7 +16,7 @@
import dataclasses
from typing import Optional
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
@ -34,5 +34,5 @@ class TextClassifierOptions:
architecture of the `supported_model`.
"""
supported_model: ms.SupportedModels
hparams: Optional[hp.BaseHParams] = None
hparams: Optional[hp.HParams] = None
model_options: Optional[mo.TextClassifierModelOptions] = None

View File

@ -66,12 +66,14 @@ class TextClassifierTest(tf.test.TestCase):
def test_create_and_train_average_word_embedding_model(self):
train_data, validation_data = self._get_data()
options = (
text_classifier.TextClassifierOptions(
supported_model=(text_classifier.SupportedModels
.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
hparams=text_classifier.HParams(
epochs=1, batch_size=1, learning_rate=0)))
options = text_classifier.TextClassifierOptions(
supported_model=(
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
),
hparams=text_classifier.AverageWordEmbeddingHParams(
epochs=1, batch_size=1, learning_rate=0
),
)
average_word_embedding_classifier = (
text_classifier.TextClassifier.create(train_data, validation_data,
options))
@ -103,12 +105,15 @@ class TextClassifierTest(tf.test.TestCase):
options = text_classifier.TextClassifierOptions(
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
model_options=text_classifier.BertModelOptions(
do_fine_tuning=False, seq_len=2),
hparams=text_classifier.HParams(
do_fine_tuning=False, seq_len=2
),
hparams=text_classifier.BertHParams(
epochs=1,
batch_size=1,
learning_rate=3e-5,
distribution_strategy='off'))
distribution_strategy='off',
),
)
bert_classifier = text_classifier.TextClassifier.create(
train_data, validation_data, options)