Support ExBert training and option to select between AdamW and LAMB optimizers for BertClassifier
PiperOrigin-RevId: 543905014
This commit is contained in:
parent
bed624f3b6
commit
1ee55d1f1b
|
@ -31,11 +31,11 @@ py_library(
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
":dataset",
|
":dataset",
|
||||||
|
":hyperparameters",
|
||||||
":model_options",
|
":model_options",
|
||||||
":model_spec",
|
":model_spec",
|
||||||
":text_classifier",
|
":text_classifier",
|
||||||
":text_classifier_options",
|
":text_classifier_options",
|
||||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -45,12 +45,18 @@ py_library(
|
||||||
deps = ["//mediapipe/model_maker/python/text/core:bert_model_options"],
|
deps = ["//mediapipe/model_maker/python/text/core:bert_model_options"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "hyperparameters",
|
||||||
|
srcs = ["hyperparameters.py"],
|
||||||
|
deps = ["//mediapipe/model_maker/python/core:hyperparameters"],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "model_spec",
|
name = "model_spec",
|
||||||
srcs = ["model_spec.py"],
|
srcs = ["model_spec.py"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":hyperparameters",
|
||||||
":model_options",
|
":model_options",
|
||||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
|
||||||
"//mediapipe/model_maker/python/core/utils:file_util",
|
"//mediapipe/model_maker/python/core/utils:file_util",
|
||||||
"//mediapipe/model_maker/python/text/core:bert_model_spec",
|
"//mediapipe/model_maker/python/text/core:bert_model_spec",
|
||||||
],
|
],
|
||||||
|
@ -61,9 +67,9 @@ py_test(
|
||||||
srcs = ["model_spec_test.py"],
|
srcs = ["model_spec_test.py"],
|
||||||
tags = ["requires-net:external"],
|
tags = ["requires-net:external"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":hyperparameters",
|
||||||
":model_options",
|
":model_options",
|
||||||
":model_spec",
|
":model_spec",
|
||||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -100,9 +106,9 @@ py_library(
|
||||||
name = "text_classifier_options",
|
name = "text_classifier_options",
|
||||||
srcs = ["text_classifier_options.py"],
|
srcs = ["text_classifier_options.py"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":hyperparameters",
|
||||||
":model_options",
|
":model_options",
|
||||||
":model_spec",
|
":model_spec",
|
||||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -111,11 +117,11 @@ py_library(
|
||||||
srcs = ["text_classifier.py"],
|
srcs = ["text_classifier.py"],
|
||||||
deps = [
|
deps = [
|
||||||
":dataset",
|
":dataset",
|
||||||
|
":hyperparameters",
|
||||||
":model_options",
|
":model_options",
|
||||||
":model_spec",
|
":model_spec",
|
||||||
":preprocessor",
|
":preprocessor",
|
||||||
":text_classifier_options",
|
":text_classifier_options",
|
||||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
|
||||||
"//mediapipe/model_maker/python/core/data:dataset",
|
"//mediapipe/model_maker/python/core/data:dataset",
|
||||||
"//mediapipe/model_maker/python/core/tasks:classifier",
|
"//mediapipe/model_maker/python/core/tasks:classifier",
|
||||||
"//mediapipe/model_maker/python/core/utils:metrics",
|
"//mediapipe/model_maker/python/core/utils:metrics",
|
||||||
|
|
|
@ -13,19 +13,23 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""MediaPipe Public Python API for Text Classifier."""
|
"""MediaPipe Public Python API for Text Classifier."""
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core import hyperparameters
|
|
||||||
from mediapipe.model_maker.python.text.text_classifier import dataset
|
from mediapipe.model_maker.python.text.text_classifier import dataset
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import hyperparameters
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_options
|
from mediapipe.model_maker.python.text.text_classifier import model_options
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
||||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier
|
from mediapipe.model_maker.python.text.text_classifier import text_classifier
|
||||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
||||||
|
|
||||||
HParams = hyperparameters.BaseHParams
|
|
||||||
|
AverageWordEmbeddingHParams = hyperparameters.AverageWordEmbeddingHParams
|
||||||
|
AverageWordEmbeddingModelOptions = (
|
||||||
|
model_options.AverageWordEmbeddingModelOptions
|
||||||
|
)
|
||||||
|
BertOptimizer = hyperparameters.BertOptimizer
|
||||||
|
BertHParams = hyperparameters.BertHParams
|
||||||
|
BertModelOptions = model_options.BertModelOptions
|
||||||
CSVParams = dataset.CSVParameters
|
CSVParams = dataset.CSVParameters
|
||||||
Dataset = dataset.Dataset
|
Dataset = dataset.Dataset
|
||||||
AverageWordEmbeddingModelOptions = (
|
|
||||||
model_options.AverageWordEmbeddingModelOptions)
|
|
||||||
BertModelOptions = model_options.BertModelOptions
|
|
||||||
SupportedModels = model_spec.SupportedModels
|
SupportedModels = model_spec.SupportedModels
|
||||||
TextClassifier = text_classifier.TextClassifier
|
TextClassifier = text_classifier.TextClassifier
|
||||||
TextClassifierOptions = text_classifier_options.TextClassifierOptions
|
TextClassifierOptions = text_classifier_options.TextClassifierOptions
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Hyperparameters for training text classification models."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import enum
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class AverageWordEmbeddingHParams(hp.BaseHParams):
|
||||||
|
"""The hyperparameters for an AverageWordEmbeddingClassifier."""
|
||||||
|
|
||||||
|
|
||||||
|
@enum.unique
|
||||||
|
class BertOptimizer(enum.Enum):
|
||||||
|
"""Supported Optimizers for Bert Text Classifier."""
|
||||||
|
|
||||||
|
ADAMW = "adamw"
|
||||||
|
LAMB = "lamb"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class BertHParams(hp.BaseHParams):
|
||||||
|
"""The hyperparameters for a Bert Classifier.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
learning_rate: Learning rate to use for gradient descent training.
|
||||||
|
batch_size: Batch size for training.
|
||||||
|
epochs: Number of training iterations over the dataset.
|
||||||
|
optimizer: Optimizer to use for training, given as a BertOptimizer enum
|
||||||
|
member. Supported values are BertOptimizer.ADAMW and BertOptimizer.LAMB.
|
||||||
|
"""
|
||||||
|
|
||||||
|
learning_rate: float = 3e-5
|
||||||
|
batch_size: int = 48
|
||||||
|
epochs: int = 2
|
||||||
|
optimizer: BertOptimizer = BertOptimizer.ADAMW
|
||||||
|
|
||||||
|
|
||||||
|
HParams = Union[BertHParams, AverageWordEmbeddingHParams]
|
|
@ -17,13 +17,11 @@ import dataclasses
|
||||||
import enum
|
import enum
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
|
||||||
from mediapipe.model_maker.python.core.utils import file_util
|
from mediapipe.model_maker.python.core.utils import file_util
|
||||||
from mediapipe.model_maker.python.text.core import bert_model_spec
|
from mediapipe.model_maker.python.text.core import bert_model_spec
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||||
|
|
||||||
# BERT-based text classifier spec inherited from BertModelSpec
|
|
||||||
BertClassifierSpec = bert_model_spec.BertModelSpec
|
|
||||||
|
|
||||||
MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
|
MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
|
||||||
'text_classifier/mobilebert_tiny',
|
'text_classifier/mobilebert_tiny',
|
||||||
|
@ -31,6 +29,12 @@ MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
|
||||||
is_folder=True,
|
is_folder=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
EXBERT_FILES = file_util.DownloadedFiles(
|
||||||
|
'text_classifier/exbert',
|
||||||
|
'https://storage.googleapis.com/mediapipe-assets/exbert.tar.gz',
|
||||||
|
is_folder=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class AverageWordEmbeddingClassifierSpec:
|
class AverageWordEmbeddingClassifierSpec:
|
||||||
|
@ -43,27 +47,53 @@ class AverageWordEmbeddingClassifierSpec:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# `learning_rate` is unused for the average word embedding model
|
# `learning_rate` is unused for the average word embedding model
|
||||||
hparams: hp.BaseHParams = hp.BaseHParams(
|
hparams: hp.AverageWordEmbeddingHParams = hp.AverageWordEmbeddingHParams(
|
||||||
epochs=10, batch_size=32, learning_rate=0)
|
epochs=10, batch_size=32, learning_rate=0
|
||||||
|
)
|
||||||
model_options: mo.AverageWordEmbeddingModelOptions = (
|
model_options: mo.AverageWordEmbeddingModelOptions = (
|
||||||
mo.AverageWordEmbeddingModelOptions())
|
mo.AverageWordEmbeddingModelOptions())
|
||||||
name: str = 'AverageWordEmbedding'
|
name: str = 'AverageWordEmbedding'
|
||||||
|
|
||||||
|
|
||||||
average_word_embedding_classifier_spec = functools.partial(
|
average_word_embedding_classifier_spec = functools.partial(
|
||||||
AverageWordEmbeddingClassifierSpec)
|
AverageWordEmbeddingClassifierSpec)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class BertClassifierSpec(bert_model_spec.BertModelSpec):
|
||||||
|
"""Specification for a Bert classifier model.
|
||||||
|
|
||||||
|
Only overrides the hparams attribute since the rest of the attributes are
|
||||||
|
inherited from the BertModelSpec.
|
||||||
|
"""
|
||||||
|
|
||||||
|
hparams: hp.BertHParams = hp.BertHParams()
|
||||||
|
|
||||||
|
|
||||||
mobilebert_classifier_spec = functools.partial(
|
mobilebert_classifier_spec = functools.partial(
|
||||||
BertClassifierSpec,
|
BertClassifierSpec,
|
||||||
downloaded_files=MOBILEBERT_TINY_FILES,
|
downloaded_files=MOBILEBERT_TINY_FILES,
|
||||||
hparams=hp.BaseHParams(
|
hparams=hp.BertHParams(
|
||||||
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
|
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
|
||||||
),
|
),
|
||||||
name='MobileBert',
|
name='MobileBert',
|
||||||
tflite_input_name={
|
tflite_input_name={
|
||||||
'ids': 'serving_default_input_1:0',
|
'ids': 'serving_default_input_1:0',
|
||||||
'mask': 'serving_default_input_3:0',
|
|
||||||
'segment_ids': 'serving_default_input_2:0',
|
'segment_ids': 'serving_default_input_2:0',
|
||||||
|
'mask': 'serving_default_input_3:0',
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
exbert_classifier_spec = functools.partial(
|
||||||
|
BertClassifierSpec,
|
||||||
|
downloaded_files=EXBERT_FILES,
|
||||||
|
hparams=hp.BertHParams(
|
||||||
|
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
|
||||||
|
),
|
||||||
|
name='ExBert',
|
||||||
|
tflite_input_name={
|
||||||
|
'ids': 'serving_default_input_1:0',
|
||||||
|
'segment_ids': 'serving_default_input_2:0',
|
||||||
|
'mask': 'serving_default_input_3:0',
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -73,3 +103,4 @@ class SupportedModels(enum.Enum):
|
||||||
"""Predefined text classifier model specs supported by Model Maker."""
|
"""Predefined text classifier model specs supported by Model Maker."""
|
||||||
AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec
|
AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec
|
||||||
MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec
|
MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec
|
||||||
|
EXBERT_CLASSIFIER = exbert_classifier_spec
|
||||||
|
|
|
@ -19,7 +19,7 @@ from unittest import mock as unittest_mock
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_options as classifier_model_options
|
from mediapipe.model_maker.python.text.text_classifier import model_options as classifier_model_options
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||||
|
|
||||||
|
@ -57,11 +57,13 @@ class ModelSpecTest(tf.test.TestCase):
|
||||||
seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
|
seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
model_spec_obj.hparams,
|
model_spec_obj.hparams,
|
||||||
hp.BaseHParams(
|
hp.BertHParams(
|
||||||
epochs=3,
|
epochs=3,
|
||||||
batch_size=48,
|
batch_size=48,
|
||||||
learning_rate=3e-5,
|
learning_rate=3e-5,
|
||||||
distribution_strategy='off'))
|
distribution_strategy='off',
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def test_predefined_average_word_embedding_spec(self):
|
def test_predefined_average_word_embedding_spec(self):
|
||||||
model_spec_obj = (
|
model_spec_obj = (
|
||||||
|
@ -78,7 +80,7 @@ class ModelSpecTest(tf.test.TestCase):
|
||||||
dropout_rate=0.2))
|
dropout_rate=0.2))
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
model_spec_obj.hparams,
|
model_spec_obj.hparams,
|
||||||
hp.BaseHParams(
|
hp.AverageWordEmbeddingHParams(
|
||||||
epochs=10,
|
epochs=10,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
learning_rate=0,
|
learning_rate=0,
|
||||||
|
@ -101,7 +103,7 @@ class ModelSpecTest(tf.test.TestCase):
|
||||||
custom_bert_classifier_options)
|
custom_bert_classifier_options)
|
||||||
|
|
||||||
def test_custom_average_word_embedding_spec(self):
|
def test_custom_average_word_embedding_spec(self):
|
||||||
custom_hparams = hp.BaseHParams(
|
custom_hparams = hp.AverageWordEmbeddingHParams(
|
||||||
learning_rate=0.4,
|
learning_rate=0.4,
|
||||||
batch_size=64,
|
batch_size=64,
|
||||||
epochs=10,
|
epochs=10,
|
||||||
|
@ -110,7 +112,8 @@ class ModelSpecTest(tf.test.TestCase):
|
||||||
export_dir='foo/bar',
|
export_dir='foo/bar',
|
||||||
distribution_strategy='mirrored',
|
distribution_strategy='mirrored',
|
||||||
num_gpus=3,
|
num_gpus=3,
|
||||||
tpu='tpu/address')
|
tpu='tpu/address',
|
||||||
|
)
|
||||||
custom_average_word_embedding_model_options = (
|
custom_average_word_embedding_model_options = (
|
||||||
classifier_model_options.AverageWordEmbeddingModelOptions(
|
classifier_model_options.AverageWordEmbeddingModelOptions(
|
||||||
seq_len=512,
|
seq_len=512,
|
||||||
|
|
|
@ -19,15 +19,16 @@ import tempfile
|
||||||
from typing import Any, Optional, Sequence, Tuple
|
from typing import Any, Optional, Sequence, Tuple
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from tensorflow_addons import optimizers as tfa_optimizers
|
||||||
import tensorflow_hub as hub
|
import tensorflow_hub as hub
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
|
||||||
from mediapipe.model_maker.python.core.data import dataset as ds
|
from mediapipe.model_maker.python.core.data import dataset as ds
|
||||||
from mediapipe.model_maker.python.core.tasks import classifier
|
from mediapipe.model_maker.python.core.tasks import classifier
|
||||||
from mediapipe.model_maker.python.core.utils import metrics
|
from mediapipe.model_maker.python.core.utils import metrics
|
||||||
from mediapipe.model_maker.python.core.utils import model_util
|
from mediapipe.model_maker.python.core.utils import model_util
|
||||||
from mediapipe.model_maker.python.core.utils import quantization
|
from mediapipe.model_maker.python.core.utils import quantization
|
||||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
|
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||||
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
||||||
|
@ -55,22 +56,26 @@ def _validate(options: text_classifier_options.TextClassifierOptions):
|
||||||
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
|
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
|
||||||
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
|
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
|
||||||
f" got {options.supported_model}")
|
f" got {options.supported_model}")
|
||||||
if (isinstance(options.model_options, mo.BertModelOptions) and
|
if isinstance(options.model_options, mo.BertModelOptions) and (
|
||||||
(options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
|
options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER
|
||||||
|
and options.supported_model != ms.SupportedModels.EXBERT_CLASSIFIER
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
|
"Expected a Bert Classifier(MobileBERT or EXBERT), got "
|
||||||
|
f"{options.supported_model}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TextClassifier(classifier.Classifier):
|
class TextClassifier(classifier.Classifier):
|
||||||
"""API for creating and training a text classification model."""
|
"""API for creating and training a text classification model."""
|
||||||
|
|
||||||
def __init__(self, model_spec: Any, hparams: hp.BaseHParams,
|
def __init__(
|
||||||
label_names: Sequence[str]):
|
self, model_spec: Any, label_names: Sequence[str], shuffle: bool
|
||||||
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
|
model_spec=model_spec, label_names=label_names, shuffle=shuffle
|
||||||
|
)
|
||||||
self._model_spec = model_spec
|
self._model_spec = model_spec
|
||||||
self._hparams = hparams
|
|
||||||
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
|
|
||||||
self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None
|
self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -107,7 +112,10 @@ class TextClassifier(classifier.Classifier):
|
||||||
if options.hparams is None:
|
if options.hparams is None:
|
||||||
options.hparams = options.supported_model.value().hparams
|
options.hparams = options.supported_model.value().hparams
|
||||||
|
|
||||||
if options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER:
|
if (
|
||||||
|
options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER
|
||||||
|
or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER
|
||||||
|
):
|
||||||
text_classifier = (
|
text_classifier = (
|
||||||
_BertClassifier.create_bert_classifier(train_data, validation_data,
|
_BertClassifier.create_bert_classifier(train_data, validation_data,
|
||||||
options,
|
options,
|
||||||
|
@ -225,11 +233,17 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
|
||||||
|
|
||||||
_DELIM_REGEX_PATTERN = r"[^\w\']+"
|
_DELIM_REGEX_PATTERN = r"[^\w\']+"
|
||||||
|
|
||||||
def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
|
def __init__(
|
||||||
model_options: mo.AverageWordEmbeddingModelOptions,
|
self,
|
||||||
hparams: hp.BaseHParams, label_names: Sequence[str]):
|
model_spec: ms.AverageWordEmbeddingClassifierSpec,
|
||||||
super().__init__(model_spec, hparams, label_names)
|
model_options: mo.AverageWordEmbeddingModelOptions,
|
||||||
|
hparams: hp.AverageWordEmbeddingHParams,
|
||||||
|
label_names: Sequence[str],
|
||||||
|
):
|
||||||
|
super().__init__(model_spec, label_names, hparams.shuffle)
|
||||||
self._model_options = model_options
|
self._model_options = model_options
|
||||||
|
self._hparams = hparams
|
||||||
|
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
|
||||||
self._loss_function = "sparse_categorical_crossentropy"
|
self._loss_function = "sparse_categorical_crossentropy"
|
||||||
self._metric_functions = [
|
self._metric_functions = [
|
||||||
"accuracy",
|
"accuracy",
|
||||||
|
@ -344,10 +358,16 @@ class _BertClassifier(TextClassifier):
|
||||||
|
|
||||||
_INITIALIZER_RANGE = 0.02
|
_INITIALIZER_RANGE = 0.02
|
||||||
|
|
||||||
def __init__(self, model_spec: ms.BertClassifierSpec,
|
def __init__(
|
||||||
model_options: mo.BertModelOptions, hparams: hp.BaseHParams,
|
self,
|
||||||
label_names: Sequence[str]):
|
model_spec: ms.BertClassifierSpec,
|
||||||
super().__init__(model_spec, hparams, label_names)
|
model_options: mo.BertModelOptions,
|
||||||
|
hparams: hp.BertHParams,
|
||||||
|
label_names: Sequence[str],
|
||||||
|
):
|
||||||
|
super().__init__(model_spec, label_names, hparams.shuffle)
|
||||||
|
self._hparams = hparams
|
||||||
|
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
|
||||||
self._model_options = model_options
|
self._model_options = model_options
|
||||||
with self._hparams.get_strategy().scope():
|
with self._hparams.get_strategy().scope():
|
||||||
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
|
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||||
|
@ -480,11 +500,26 @@ class _BertClassifier(TextClassifier):
|
||||||
initial_learning_rate=initial_lr,
|
initial_learning_rate=initial_lr,
|
||||||
decay_schedule_fn=lr_schedule,
|
decay_schedule_fn=lr_schedule,
|
||||||
warmup_steps=warmup_steps)
|
warmup_steps=warmup_steps)
|
||||||
|
if self._hparams.optimizer == hp.BertOptimizer.ADAMW:
|
||||||
self._optimizer = tf.keras.optimizers.experimental.AdamW(
|
self._optimizer = tf.keras.optimizers.experimental.AdamW(
|
||||||
lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0)
|
lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0
|
||||||
self._optimizer.exclude_from_weight_decay(
|
)
|
||||||
var_names=["LayerNorm", "layer_norm", "bias"])
|
self._optimizer.exclude_from_weight_decay(
|
||||||
|
var_names=["LayerNorm", "layer_norm", "bias"]
|
||||||
|
)
|
||||||
|
elif self._hparams.optimizer == hp.BertOptimizer.LAMB:
|
||||||
|
self._optimizer = tfa_optimizers.LAMB(
|
||||||
|
lr_schedule,
|
||||||
|
weight_decay_rate=0.01,
|
||||||
|
epsilon=1e-6,
|
||||||
|
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
|
||||||
|
global_clipnorm=1.0,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"BertHParams.optimizer must be set to ADAM or "
|
||||||
|
f"LAMB. Got {self._hparams.optimizer}."
|
||||||
|
)
|
||||||
|
|
||||||
def _save_vocab(self, vocab_filepath: str):
|
def _save_vocab(self, vocab_filepath: str):
|
||||||
tf.io.gfile.copy(
|
tf.io.gfile.copy(
|
||||||
|
|
|
@ -66,14 +66,16 @@ def run(data_dir,
|
||||||
quantization_config = None
|
quantization_config = None
|
||||||
if (supported_model ==
|
if (supported_model ==
|
||||||
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
|
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
|
||||||
hparams = text_classifier.HParams(
|
hparams = text_classifier.AverageWordEmbeddingHParams(
|
||||||
epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir)
|
epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir
|
||||||
|
)
|
||||||
# Warning: This takes extremely long to run on CPU
|
# Warning: This takes extremely long to run on CPU
|
||||||
elif (
|
elif (
|
||||||
supported_model == text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER):
|
supported_model == text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER):
|
||||||
quantization_config = quantization.QuantizationConfig.for_dynamic()
|
quantization_config = quantization.QuantizationConfig.for_dynamic()
|
||||||
hparams = text_classifier.HParams(
|
hparams = text_classifier.BertHParams(
|
||||||
epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir)
|
epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir
|
||||||
|
)
|
||||||
|
|
||||||
# Fine-tunes the model.
|
# Fine-tunes the model.
|
||||||
options = text_classifier.TextClassifierOptions(
|
options = text_classifier.TextClassifierOptions(
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||||
|
|
||||||
|
@ -34,5 +34,5 @@ class TextClassifierOptions:
|
||||||
architecture of the `supported_model`.
|
architecture of the `supported_model`.
|
||||||
"""
|
"""
|
||||||
supported_model: ms.SupportedModels
|
supported_model: ms.SupportedModels
|
||||||
hparams: Optional[hp.BaseHParams] = None
|
hparams: Optional[hp.HParams] = None
|
||||||
model_options: Optional[mo.TextClassifierModelOptions] = None
|
model_options: Optional[mo.TextClassifierModelOptions] = None
|
||||||
|
|
|
@ -66,12 +66,14 @@ class TextClassifierTest(tf.test.TestCase):
|
||||||
|
|
||||||
def test_create_and_train_average_word_embedding_model(self):
|
def test_create_and_train_average_word_embedding_model(self):
|
||||||
train_data, validation_data = self._get_data()
|
train_data, validation_data = self._get_data()
|
||||||
options = (
|
options = text_classifier.TextClassifierOptions(
|
||||||
text_classifier.TextClassifierOptions(
|
supported_model=(
|
||||||
supported_model=(text_classifier.SupportedModels
|
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
|
||||||
.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
|
),
|
||||||
hparams=text_classifier.HParams(
|
hparams=text_classifier.AverageWordEmbeddingHParams(
|
||||||
epochs=1, batch_size=1, learning_rate=0)))
|
epochs=1, batch_size=1, learning_rate=0
|
||||||
|
),
|
||||||
|
)
|
||||||
average_word_embedding_classifier = (
|
average_word_embedding_classifier = (
|
||||||
text_classifier.TextClassifier.create(train_data, validation_data,
|
text_classifier.TextClassifier.create(train_data, validation_data,
|
||||||
options))
|
options))
|
||||||
|
@ -103,12 +105,15 @@ class TextClassifierTest(tf.test.TestCase):
|
||||||
options = text_classifier.TextClassifierOptions(
|
options = text_classifier.TextClassifierOptions(
|
||||||
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
|
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
|
||||||
model_options=text_classifier.BertModelOptions(
|
model_options=text_classifier.BertModelOptions(
|
||||||
do_fine_tuning=False, seq_len=2),
|
do_fine_tuning=False, seq_len=2
|
||||||
hparams=text_classifier.HParams(
|
),
|
||||||
|
hparams=text_classifier.BertHParams(
|
||||||
epochs=1,
|
epochs=1,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
learning_rate=3e-5,
|
learning_rate=3e-5,
|
||||||
distribution_strategy='off'))
|
distribution_strategy='off',
|
||||||
|
),
|
||||||
|
)
|
||||||
bert_classifier = text_classifier.TextClassifier.create(
|
bert_classifier = text_classifier.TextClassifier.create(
|
||||||
train_data, validation_data, options)
|
train_data, validation_data, options)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user