Support ExBert training and option to select between AdamW and LAMB optimizers for BertClassifier

PiperOrigin-RevId: 543905014
This commit is contained in:
MediaPipe Team 2023-06-27 18:02:59 -07:00 committed by Copybara-Service
parent bed624f3b6
commit 1ee55d1f1b
9 changed files with 202 additions and 62 deletions

View File

@ -31,11 +31,11 @@ py_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":dataset", ":dataset",
":hyperparameters",
":model_options", ":model_options",
":model_spec", ":model_spec",
":text_classifier", ":text_classifier",
":text_classifier_options", ":text_classifier_options",
"//mediapipe/model_maker/python/core:hyperparameters",
], ],
) )
@ -45,12 +45,18 @@ py_library(
deps = ["//mediapipe/model_maker/python/text/core:bert_model_options"], deps = ["//mediapipe/model_maker/python/text/core:bert_model_options"],
) )
py_library(
name = "hyperparameters",
srcs = ["hyperparameters.py"],
deps = ["//mediapipe/model_maker/python/core:hyperparameters"],
)
py_library( py_library(
name = "model_spec", name = "model_spec",
srcs = ["model_spec.py"], srcs = ["model_spec.py"],
deps = [ deps = [
":hyperparameters",
":model_options", ":model_options",
"//mediapipe/model_maker/python/core:hyperparameters",
"//mediapipe/model_maker/python/core/utils:file_util", "//mediapipe/model_maker/python/core/utils:file_util",
"//mediapipe/model_maker/python/text/core:bert_model_spec", "//mediapipe/model_maker/python/text/core:bert_model_spec",
], ],
@ -61,9 +67,9 @@ py_test(
srcs = ["model_spec_test.py"], srcs = ["model_spec_test.py"],
tags = ["requires-net:external"], tags = ["requires-net:external"],
deps = [ deps = [
":hyperparameters",
":model_options", ":model_options",
":model_spec", ":model_spec",
"//mediapipe/model_maker/python/core:hyperparameters",
], ],
) )
@ -100,9 +106,9 @@ py_library(
name = "text_classifier_options", name = "text_classifier_options",
srcs = ["text_classifier_options.py"], srcs = ["text_classifier_options.py"],
deps = [ deps = [
":hyperparameters",
":model_options", ":model_options",
":model_spec", ":model_spec",
"//mediapipe/model_maker/python/core:hyperparameters",
], ],
) )
@ -111,11 +117,11 @@ py_library(
srcs = ["text_classifier.py"], srcs = ["text_classifier.py"],
deps = [ deps = [
":dataset", ":dataset",
":hyperparameters",
":model_options", ":model_options",
":model_spec", ":model_spec",
":preprocessor", ":preprocessor",
":text_classifier_options", ":text_classifier_options",
"//mediapipe/model_maker/python/core:hyperparameters",
"//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/data:dataset",
"//mediapipe/model_maker/python/core/tasks:classifier", "//mediapipe/model_maker/python/core/tasks:classifier",
"//mediapipe/model_maker/python/core/utils:metrics", "//mediapipe/model_maker/python/core/utils:metrics",

View File

@ -13,19 +13,23 @@
# limitations under the License. # limitations under the License.
"""MediaPipe Public Python API for Text Classifier.""" """MediaPipe Public Python API for Text Classifier."""
from mediapipe.model_maker.python.core import hyperparameters
from mediapipe.model_maker.python.text.text_classifier import dataset from mediapipe.model_maker.python.text.text_classifier import dataset
from mediapipe.model_maker.python.text.text_classifier import hyperparameters
from mediapipe.model_maker.python.text.text_classifier import model_options from mediapipe.model_maker.python.text.text_classifier import model_options
from mediapipe.model_maker.python.text.text_classifier import model_spec from mediapipe.model_maker.python.text.text_classifier import model_spec
from mediapipe.model_maker.python.text.text_classifier import text_classifier from mediapipe.model_maker.python.text.text_classifier import text_classifier
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
HParams = hyperparameters.BaseHParams
AverageWordEmbeddingHParams = hyperparameters.AverageWordEmbeddingHParams
AverageWordEmbeddingModelOptions = (
model_options.AverageWordEmbeddingModelOptions
)
BertOptimizer = hyperparameters.BertOptimizer
BertHParams = hyperparameters.BertHParams
BertModelOptions = model_options.BertModelOptions
CSVParams = dataset.CSVParameters CSVParams = dataset.CSVParameters
Dataset = dataset.Dataset Dataset = dataset.Dataset
AverageWordEmbeddingModelOptions = (
model_options.AverageWordEmbeddingModelOptions)
BertModelOptions = model_options.BertModelOptions
SupportedModels = model_spec.SupportedModels SupportedModels = model_spec.SupportedModels
TextClassifier = text_classifier.TextClassifier TextClassifier = text_classifier.TextClassifier
TextClassifierOptions = text_classifier_options.TextClassifierOptions TextClassifierOptions = text_classifier_options.TextClassifierOptions

View File

@ -0,0 +1,54 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hyperparameters for training object detection models."""
import dataclasses
import enum
from typing import Union
from mediapipe.model_maker.python.core import hyperparameters as hp
@dataclasses.dataclass
class AverageWordEmbeddingHParams(hp.BaseHParams):
"""The hyperparameters for an AverageWordEmbeddingClassifier."""
@enum.unique
class BertOptimizer(enum.Enum):
"""Supported Optimizers for Bert Text Classifier."""
ADAMW = "adamw"
LAMB = "lamb"
@dataclasses.dataclass
class BertHParams(hp.BaseHParams):
"""The hyperparameters for a Bert Classifier.
Attributes:
learning_rate: Learning rate to use for gradient descent training.
batch_size: Batch size for training.
epochs: Number of training iterations over the dataset.
optimizer: Optimizer to use for training. Only supported values are "adamw"
and "lamb".
"""
learning_rate: float = 3e-5
batch_size: int = 48
epochs: int = 2
optimizer: BertOptimizer = BertOptimizer.ADAMW
HParams = Union[BertHParams, AverageWordEmbeddingHParams]

View File

@ -17,13 +17,11 @@ import dataclasses
import enum import enum
import functools import functools
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.core.utils import file_util from mediapipe.model_maker.python.core.utils import file_util
from mediapipe.model_maker.python.text.core import bert_model_spec from mediapipe.model_maker.python.text.core import bert_model_spec
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import model_options as mo from mediapipe.model_maker.python.text.text_classifier import model_options as mo
# BERT-based text classifier spec inherited from BertModelSpec
BertClassifierSpec = bert_model_spec.BertModelSpec
MOBILEBERT_TINY_FILES = file_util.DownloadedFiles( MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
'text_classifier/mobilebert_tiny', 'text_classifier/mobilebert_tiny',
@ -31,6 +29,12 @@ MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
is_folder=True, is_folder=True,
) )
EXBERT_FILES = file_util.DownloadedFiles(
'text_classifier/exbert',
'https://storage.googleapis.com/mediapipe-assets/exbert.tar.gz',
is_folder=True,
)
@dataclasses.dataclass @dataclasses.dataclass
class AverageWordEmbeddingClassifierSpec: class AverageWordEmbeddingClassifierSpec:
@ -43,27 +47,53 @@ class AverageWordEmbeddingClassifierSpec:
""" """
# `learning_rate` is unused for the average word embedding model # `learning_rate` is unused for the average word embedding model
hparams: hp.BaseHParams = hp.BaseHParams( hparams: hp.AverageWordEmbeddingHParams = hp.AverageWordEmbeddingHParams(
epochs=10, batch_size=32, learning_rate=0) epochs=10, batch_size=32, learning_rate=0
)
model_options: mo.AverageWordEmbeddingModelOptions = ( model_options: mo.AverageWordEmbeddingModelOptions = (
mo.AverageWordEmbeddingModelOptions()) mo.AverageWordEmbeddingModelOptions())
name: str = 'AverageWordEmbedding' name: str = 'AverageWordEmbedding'
average_word_embedding_classifier_spec = functools.partial( average_word_embedding_classifier_spec = functools.partial(
AverageWordEmbeddingClassifierSpec) AverageWordEmbeddingClassifierSpec)
@dataclasses.dataclass
class BertClassifierSpec(bert_model_spec.BertModelSpec):
"""Specification for a Bert classifier model.
Only overrides the hparams attribute since the rest of the attributes are
inherited from the BertModelSpec.
"""
hparams: hp.BertHParams = hp.BertHParams()
mobilebert_classifier_spec = functools.partial( mobilebert_classifier_spec = functools.partial(
BertClassifierSpec, BertClassifierSpec,
downloaded_files=MOBILEBERT_TINY_FILES, downloaded_files=MOBILEBERT_TINY_FILES,
hparams=hp.BaseHParams( hparams=hp.BertHParams(
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off' epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
), ),
name='MobileBert', name='MobileBert',
tflite_input_name={ tflite_input_name={
'ids': 'serving_default_input_1:0', 'ids': 'serving_default_input_1:0',
'mask': 'serving_default_input_3:0',
'segment_ids': 'serving_default_input_2:0', 'segment_ids': 'serving_default_input_2:0',
'mask': 'serving_default_input_3:0',
},
)
exbert_classifier_spec = functools.partial(
BertClassifierSpec,
downloaded_files=EXBERT_FILES,
hparams=hp.BertHParams(
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
),
name='ExBert',
tflite_input_name={
'ids': 'serving_default_input_1:0',
'segment_ids': 'serving_default_input_2:0',
'mask': 'serving_default_input_3:0',
}, },
) )
@ -73,3 +103,4 @@ class SupportedModels(enum.Enum):
"""Predefined text classifier model specs supported by Model Maker.""" """Predefined text classifier model specs supported by Model Maker."""
AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec
MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec
EXBERT_CLASSIFIER = exbert_classifier_spec

View File

@ -19,7 +19,7 @@ from unittest import mock as unittest_mock
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.core import hyperparameters as hp from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import model_options as classifier_model_options from mediapipe.model_maker.python.text.text_classifier import model_options as classifier_model_options
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
@ -57,11 +57,13 @@ class ModelSpecTest(tf.test.TestCase):
seq_len=128, do_fine_tuning=True, dropout_rate=0.1)) seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
self.assertEqual( self.assertEqual(
model_spec_obj.hparams, model_spec_obj.hparams,
hp.BaseHParams( hp.BertHParams(
epochs=3, epochs=3,
batch_size=48, batch_size=48,
learning_rate=3e-5, learning_rate=3e-5,
distribution_strategy='off')) distribution_strategy='off',
),
)
def test_predefined_average_word_embedding_spec(self): def test_predefined_average_word_embedding_spec(self):
model_spec_obj = ( model_spec_obj = (
@ -78,7 +80,7 @@ class ModelSpecTest(tf.test.TestCase):
dropout_rate=0.2)) dropout_rate=0.2))
self.assertEqual( self.assertEqual(
model_spec_obj.hparams, model_spec_obj.hparams,
hp.BaseHParams( hp.AverageWordEmbeddingHParams(
epochs=10, epochs=10,
batch_size=32, batch_size=32,
learning_rate=0, learning_rate=0,
@ -101,7 +103,7 @@ class ModelSpecTest(tf.test.TestCase):
custom_bert_classifier_options) custom_bert_classifier_options)
def test_custom_average_word_embedding_spec(self): def test_custom_average_word_embedding_spec(self):
custom_hparams = hp.BaseHParams( custom_hparams = hp.AverageWordEmbeddingHParams(
learning_rate=0.4, learning_rate=0.4,
batch_size=64, batch_size=64,
epochs=10, epochs=10,
@ -110,7 +112,8 @@ class ModelSpecTest(tf.test.TestCase):
export_dir='foo/bar', export_dir='foo/bar',
distribution_strategy='mirrored', distribution_strategy='mirrored',
num_gpus=3, num_gpus=3,
tpu='tpu/address') tpu='tpu/address',
)
custom_average_word_embedding_model_options = ( custom_average_word_embedding_model_options = (
classifier_model_options.AverageWordEmbeddingModelOptions( classifier_model_options.AverageWordEmbeddingModelOptions(
seq_len=512, seq_len=512,

View File

@ -19,15 +19,16 @@ import tempfile
from typing import Any, Optional, Sequence, Tuple from typing import Any, Optional, Sequence, Tuple
import tensorflow as tf import tensorflow as tf
from tensorflow_addons import optimizers as tfa_optimizers
import tensorflow_hub as hub import tensorflow_hub as hub
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.core.data import dataset as ds from mediapipe.model_maker.python.core.data import dataset as ds
from mediapipe.model_maker.python.core.tasks import classifier from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import metrics from mediapipe.model_maker.python.core.utils import metrics
from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import model_options as mo from mediapipe.model_maker.python.text.text_classifier import model_options as mo
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
from mediapipe.model_maker.python.text.text_classifier import preprocessor from mediapipe.model_maker.python.text.text_classifier import preprocessor
@ -55,22 +56,26 @@ def _validate(options: text_classifier_options.TextClassifierOptions):
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)): ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER," raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
f" got {options.supported_model}") f" got {options.supported_model}")
if (isinstance(options.model_options, mo.BertModelOptions) and if isinstance(options.model_options, mo.BertModelOptions) and (
(options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)): options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER
and options.supported_model != ms.SupportedModels.EXBERT_CLASSIFIER
):
raise ValueError( raise ValueError(
f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}") "Expected a Bert Classifier(MobileBERT or EXBERT), got "
f"{options.supported_model}"
)
class TextClassifier(classifier.Classifier): class TextClassifier(classifier.Classifier):
"""API for creating and training a text classification model.""" """API for creating and training a text classification model."""
def __init__(self, model_spec: Any, hparams: hp.BaseHParams, def __init__(
label_names: Sequence[str]): self, model_spec: Any, label_names: Sequence[str], shuffle: bool
):
super().__init__( super().__init__(
model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle) model_spec=model_spec, label_names=label_names, shuffle=shuffle
)
self._model_spec = model_spec self._model_spec = model_spec
self._hparams = hparams
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None
@classmethod @classmethod
@ -107,7 +112,10 @@ class TextClassifier(classifier.Classifier):
if options.hparams is None: if options.hparams is None:
options.hparams = options.supported_model.value().hparams options.hparams = options.supported_model.value().hparams
if options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER: if (
options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER
or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER
):
text_classifier = ( text_classifier = (
_BertClassifier.create_bert_classifier(train_data, validation_data, _BertClassifier.create_bert_classifier(train_data, validation_data,
options, options,
@ -225,11 +233,17 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
_DELIM_REGEX_PATTERN = r"[^\w\']+" _DELIM_REGEX_PATTERN = r"[^\w\']+"
def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec, def __init__(
self,
model_spec: ms.AverageWordEmbeddingClassifierSpec,
model_options: mo.AverageWordEmbeddingModelOptions, model_options: mo.AverageWordEmbeddingModelOptions,
hparams: hp.BaseHParams, label_names: Sequence[str]): hparams: hp.AverageWordEmbeddingHParams,
super().__init__(model_spec, hparams, label_names) label_names: Sequence[str],
):
super().__init__(model_spec, label_names, hparams.shuffle)
self._model_options = model_options self._model_options = model_options
self._hparams = hparams
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
self._loss_function = "sparse_categorical_crossentropy" self._loss_function = "sparse_categorical_crossentropy"
self._metric_functions = [ self._metric_functions = [
"accuracy", "accuracy",
@ -344,10 +358,16 @@ class _BertClassifier(TextClassifier):
_INITIALIZER_RANGE = 0.02 _INITIALIZER_RANGE = 0.02
def __init__(self, model_spec: ms.BertClassifierSpec, def __init__(
model_options: mo.BertModelOptions, hparams: hp.BaseHParams, self,
label_names: Sequence[str]): model_spec: ms.BertClassifierSpec,
super().__init__(model_spec, hparams, label_names) model_options: mo.BertModelOptions,
hparams: hp.BertHParams,
label_names: Sequence[str],
):
super().__init__(model_spec, label_names, hparams.shuffle)
self._hparams = hparams
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
self._model_options = model_options self._model_options = model_options
with self._hparams.get_strategy().scope(): with self._hparams.get_strategy().scope():
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy() self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
@ -480,11 +500,26 @@ class _BertClassifier(TextClassifier):
initial_learning_rate=initial_lr, initial_learning_rate=initial_lr,
decay_schedule_fn=lr_schedule, decay_schedule_fn=lr_schedule,
warmup_steps=warmup_steps) warmup_steps=warmup_steps)
if self._hparams.optimizer == hp.BertOptimizer.ADAMW:
self._optimizer = tf.keras.optimizers.experimental.AdamW( self._optimizer = tf.keras.optimizers.experimental.AdamW(
lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0) lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0
)
self._optimizer.exclude_from_weight_decay( self._optimizer.exclude_from_weight_decay(
var_names=["LayerNorm", "layer_norm", "bias"]) var_names=["LayerNorm", "layer_norm", "bias"]
)
elif self._hparams.optimizer == hp.BertOptimizer.LAMB:
self._optimizer = tfa_optimizers.LAMB(
lr_schedule,
weight_decay_rate=0.01,
epsilon=1e-6,
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
global_clipnorm=1.0,
)
else:
raise ValueError(
"BertHParams.optimizer must be set to ADAM or "
f"LAMB. Got {self._hparams.optimizer}."
)
def _save_vocab(self, vocab_filepath: str): def _save_vocab(self, vocab_filepath: str):
tf.io.gfile.copy( tf.io.gfile.copy(

View File

@ -66,14 +66,16 @@ def run(data_dir,
quantization_config = None quantization_config = None
if (supported_model == if (supported_model ==
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER): text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
hparams = text_classifier.HParams( hparams = text_classifier.AverageWordEmbeddingHParams(
epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir) epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir
)
# Warning: This takes extremely long to run on CPU # Warning: This takes extremely long to run on CPU
elif ( elif (
supported_model == text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER): supported_model == text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER):
quantization_config = quantization.QuantizationConfig.for_dynamic() quantization_config = quantization.QuantizationConfig.for_dynamic()
hparams = text_classifier.HParams( hparams = text_classifier.BertHParams(
epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir) epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir
)
# Fine-tunes the model. # Fine-tunes the model.
options = text_classifier.TextClassifierOptions( options = text_classifier.TextClassifierOptions(

View File

@ -16,7 +16,7 @@
import dataclasses import dataclasses
from typing import Optional from typing import Optional
from mediapipe.model_maker.python.core import hyperparameters as hp from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import model_options as mo from mediapipe.model_maker.python.text.text_classifier import model_options as mo
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
@ -34,5 +34,5 @@ class TextClassifierOptions:
architecture of the `supported_model`. architecture of the `supported_model`.
""" """
supported_model: ms.SupportedModels supported_model: ms.SupportedModels
hparams: Optional[hp.BaseHParams] = None hparams: Optional[hp.HParams] = None
model_options: Optional[mo.TextClassifierModelOptions] = None model_options: Optional[mo.TextClassifierModelOptions] = None

View File

@ -66,12 +66,14 @@ class TextClassifierTest(tf.test.TestCase):
def test_create_and_train_average_word_embedding_model(self): def test_create_and_train_average_word_embedding_model(self):
train_data, validation_data = self._get_data() train_data, validation_data = self._get_data()
options = ( options = text_classifier.TextClassifierOptions(
text_classifier.TextClassifierOptions( supported_model=(
supported_model=(text_classifier.SupportedModels text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
.AVERAGE_WORD_EMBEDDING_CLASSIFIER), ),
hparams=text_classifier.HParams( hparams=text_classifier.AverageWordEmbeddingHParams(
epochs=1, batch_size=1, learning_rate=0))) epochs=1, batch_size=1, learning_rate=0
),
)
average_word_embedding_classifier = ( average_word_embedding_classifier = (
text_classifier.TextClassifier.create(train_data, validation_data, text_classifier.TextClassifier.create(train_data, validation_data,
options)) options))
@ -103,12 +105,15 @@ class TextClassifierTest(tf.test.TestCase):
options = text_classifier.TextClassifierOptions( options = text_classifier.TextClassifierOptions(
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER, supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
model_options=text_classifier.BertModelOptions( model_options=text_classifier.BertModelOptions(
do_fine_tuning=False, seq_len=2), do_fine_tuning=False, seq_len=2
hparams=text_classifier.HParams( ),
hparams=text_classifier.BertHParams(
epochs=1, epochs=1,
batch_size=1, batch_size=1,
learning_rate=3e-5, learning_rate=3e-5,
distribution_strategy='off')) distribution_strategy='off',
),
)
bert_classifier = text_classifier.TextClassifier.create( bert_classifier = text_classifier.TextClassifier.create(
train_data, validation_data, options) train_data, validation_data, options)