Support ExBert training and option to select between AdamW and LAMB optimizers for BertClassifier
PiperOrigin-RevId: 543905014
This commit is contained in:
parent
bed624f3b6
commit
1ee55d1f1b
|
@ -31,11 +31,11 @@ py_library(
|
|||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":dataset",
|
||||
":hyperparameters",
|
||||
":model_options",
|
||||
":model_spec",
|
||||
":text_classifier",
|
||||
":text_classifier_options",
|
||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -45,12 +45,18 @@ py_library(
|
|||
deps = ["//mediapipe/model_maker/python/text/core:bert_model_options"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "hyperparameters",
|
||||
srcs = ["hyperparameters.py"],
|
||||
deps = ["//mediapipe/model_maker/python/core:hyperparameters"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "model_spec",
|
||||
srcs = ["model_spec.py"],
|
||||
deps = [
|
||||
":hyperparameters",
|
||||
":model_options",
|
||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||
"//mediapipe/model_maker/python/core/utils:file_util",
|
||||
"//mediapipe/model_maker/python/text/core:bert_model_spec",
|
||||
],
|
||||
|
@ -61,9 +67,9 @@ py_test(
|
|||
srcs = ["model_spec_test.py"],
|
||||
tags = ["requires-net:external"],
|
||||
deps = [
|
||||
":hyperparameters",
|
||||
":model_options",
|
||||
":model_spec",
|
||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -100,9 +106,9 @@ py_library(
|
|||
name = "text_classifier_options",
|
||||
srcs = ["text_classifier_options.py"],
|
||||
deps = [
|
||||
":hyperparameters",
|
||||
":model_options",
|
||||
":model_spec",
|
||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -111,11 +117,11 @@ py_library(
|
|||
srcs = ["text_classifier.py"],
|
||||
deps = [
|
||||
":dataset",
|
||||
":hyperparameters",
|
||||
":model_options",
|
||||
":model_spec",
|
||||
":preprocessor",
|
||||
":text_classifier_options",
|
||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||
"//mediapipe/model_maker/python/core/data:dataset",
|
||||
"//mediapipe/model_maker/python/core/tasks:classifier",
|
||||
"//mediapipe/model_maker/python/core/utils:metrics",
|
||||
|
|
|
@ -13,19 +13,23 @@
|
|||
# limitations under the License.
|
||||
"""MediaPipe Public Python API for Text Classifier."""
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters
|
||||
from mediapipe.model_maker.python.text.text_classifier import dataset
|
||||
from mediapipe.model_maker.python.text.text_classifier import hyperparameters
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_options
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier
|
||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
||||
|
||||
HParams = hyperparameters.BaseHParams
|
||||
|
||||
AverageWordEmbeddingHParams = hyperparameters.AverageWordEmbeddingHParams
|
||||
AverageWordEmbeddingModelOptions = (
|
||||
model_options.AverageWordEmbeddingModelOptions
|
||||
)
|
||||
BertOptimizer = hyperparameters.BertOptimizer
|
||||
BertHParams = hyperparameters.BertHParams
|
||||
BertModelOptions = model_options.BertModelOptions
|
||||
CSVParams = dataset.CSVParameters
|
||||
Dataset = dataset.Dataset
|
||||
AverageWordEmbeddingModelOptions = (
|
||||
model_options.AverageWordEmbeddingModelOptions)
|
||||
BertModelOptions = model_options.BertModelOptions
|
||||
SupportedModels = model_spec.SupportedModels
|
||||
TextClassifier = text_classifier.TextClassifier
|
||||
TextClassifierOptions = text_classifier_options.TextClassifierOptions
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Hyperparameters for training object detection models."""
|
||||
|
||||
import dataclasses
|
||||
import enum
|
||||
from typing import Union
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AverageWordEmbeddingHParams(hp.BaseHParams):
|
||||
"""The hyperparameters for an AverageWordEmbeddingClassifier."""
|
||||
|
||||
|
||||
@enum.unique
|
||||
class BertOptimizer(enum.Enum):
|
||||
"""Supported Optimizers for Bert Text Classifier."""
|
||||
|
||||
ADAMW = "adamw"
|
||||
LAMB = "lamb"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BertHParams(hp.BaseHParams):
|
||||
"""The hyperparameters for a Bert Classifier.
|
||||
|
||||
Attributes:
|
||||
learning_rate: Learning rate to use for gradient descent training.
|
||||
batch_size: Batch size for training.
|
||||
epochs: Number of training iterations over the dataset.
|
||||
optimizer: Optimizer to use for training. Only supported values are "adamw"
|
||||
and "lamb".
|
||||
"""
|
||||
|
||||
learning_rate: float = 3e-5
|
||||
batch_size: int = 48
|
||||
epochs: int = 2
|
||||
optimizer: BertOptimizer = BertOptimizer.ADAMW
|
||||
|
||||
|
||||
HParams = Union[BertHParams, AverageWordEmbeddingHParams]
|
|
@ -17,13 +17,11 @@ import dataclasses
|
|||
import enum
|
||||
import functools
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.core.utils import file_util
|
||||
from mediapipe.model_maker.python.text.core import bert_model_spec
|
||||
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||
|
||||
# BERT-based text classifier spec inherited from BertModelSpec
|
||||
BertClassifierSpec = bert_model_spec.BertModelSpec
|
||||
|
||||
MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
|
||||
'text_classifier/mobilebert_tiny',
|
||||
|
@ -31,6 +29,12 @@ MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
|
|||
is_folder=True,
|
||||
)
|
||||
|
||||
EXBERT_FILES = file_util.DownloadedFiles(
|
||||
'text_classifier/exbert',
|
||||
'https://storage.googleapis.com/mediapipe-assets/exbert.tar.gz',
|
||||
is_folder=True,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AverageWordEmbeddingClassifierSpec:
|
||||
|
@ -43,27 +47,53 @@ class AverageWordEmbeddingClassifierSpec:
|
|||
"""
|
||||
|
||||
# `learning_rate` is unused for the average word embedding model
|
||||
hparams: hp.BaseHParams = hp.BaseHParams(
|
||||
epochs=10, batch_size=32, learning_rate=0)
|
||||
hparams: hp.AverageWordEmbeddingHParams = hp.AverageWordEmbeddingHParams(
|
||||
epochs=10, batch_size=32, learning_rate=0
|
||||
)
|
||||
model_options: mo.AverageWordEmbeddingModelOptions = (
|
||||
mo.AverageWordEmbeddingModelOptions())
|
||||
name: str = 'AverageWordEmbedding'
|
||||
|
||||
|
||||
average_word_embedding_classifier_spec = functools.partial(
|
||||
AverageWordEmbeddingClassifierSpec)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BertClassifierSpec(bert_model_spec.BertModelSpec):
|
||||
"""Specification for a Bert classifier model.
|
||||
|
||||
Only overrides the hparams attribute since the rest of the attributes are
|
||||
inherited from the BertModelSpec.
|
||||
"""
|
||||
|
||||
hparams: hp.BertHParams = hp.BertHParams()
|
||||
|
||||
|
||||
mobilebert_classifier_spec = functools.partial(
|
||||
BertClassifierSpec,
|
||||
downloaded_files=MOBILEBERT_TINY_FILES,
|
||||
hparams=hp.BaseHParams(
|
||||
hparams=hp.BertHParams(
|
||||
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
|
||||
),
|
||||
name='MobileBert',
|
||||
tflite_input_name={
|
||||
'ids': 'serving_default_input_1:0',
|
||||
'mask': 'serving_default_input_3:0',
|
||||
'segment_ids': 'serving_default_input_2:0',
|
||||
'mask': 'serving_default_input_3:0',
|
||||
},
|
||||
)
|
||||
|
||||
exbert_classifier_spec = functools.partial(
|
||||
BertClassifierSpec,
|
||||
downloaded_files=EXBERT_FILES,
|
||||
hparams=hp.BertHParams(
|
||||
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
|
||||
),
|
||||
name='ExBert',
|
||||
tflite_input_name={
|
||||
'ids': 'serving_default_input_1:0',
|
||||
'segment_ids': 'serving_default_input_2:0',
|
||||
'mask': 'serving_default_input_3:0',
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -73,3 +103,4 @@ class SupportedModels(enum.Enum):
|
|||
"""Predefined text classifier model specs supported by Model Maker."""
|
||||
AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec
|
||||
MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec
|
||||
EXBERT_CLASSIFIER = exbert_classifier_spec
|
||||
|
|
|
@ -19,7 +19,7 @@ from unittest import mock as unittest_mock
|
|||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_options as classifier_model_options
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||
|
||||
|
@ -57,11 +57,13 @@ class ModelSpecTest(tf.test.TestCase):
|
|||
seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
|
||||
self.assertEqual(
|
||||
model_spec_obj.hparams,
|
||||
hp.BaseHParams(
|
||||
hp.BertHParams(
|
||||
epochs=3,
|
||||
batch_size=48,
|
||||
learning_rate=3e-5,
|
||||
distribution_strategy='off'))
|
||||
distribution_strategy='off',
|
||||
),
|
||||
)
|
||||
|
||||
def test_predefined_average_word_embedding_spec(self):
|
||||
model_spec_obj = (
|
||||
|
@ -78,7 +80,7 @@ class ModelSpecTest(tf.test.TestCase):
|
|||
dropout_rate=0.2))
|
||||
self.assertEqual(
|
||||
model_spec_obj.hparams,
|
||||
hp.BaseHParams(
|
||||
hp.AverageWordEmbeddingHParams(
|
||||
epochs=10,
|
||||
batch_size=32,
|
||||
learning_rate=0,
|
||||
|
@ -101,7 +103,7 @@ class ModelSpecTest(tf.test.TestCase):
|
|||
custom_bert_classifier_options)
|
||||
|
||||
def test_custom_average_word_embedding_spec(self):
|
||||
custom_hparams = hp.BaseHParams(
|
||||
custom_hparams = hp.AverageWordEmbeddingHParams(
|
||||
learning_rate=0.4,
|
||||
batch_size=64,
|
||||
epochs=10,
|
||||
|
@ -110,7 +112,8 @@ class ModelSpecTest(tf.test.TestCase):
|
|||
export_dir='foo/bar',
|
||||
distribution_strategy='mirrored',
|
||||
num_gpus=3,
|
||||
tpu='tpu/address')
|
||||
tpu='tpu/address',
|
||||
)
|
||||
custom_average_word_embedding_model_options = (
|
||||
classifier_model_options.AverageWordEmbeddingModelOptions(
|
||||
seq_len=512,
|
||||
|
|
|
@ -19,15 +19,16 @@ import tempfile
|
|||
from typing import Any, Optional, Sequence, Tuple
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow_addons import optimizers as tfa_optimizers
|
||||
import tensorflow_hub as hub
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.core.data import dataset as ds
|
||||
from mediapipe.model_maker.python.core.tasks import classifier
|
||||
from mediapipe.model_maker.python.core.utils import metrics
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
|
||||
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
||||
|
@ -55,22 +56,26 @@ def _validate(options: text_classifier_options.TextClassifierOptions):
|
|||
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
|
||||
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
|
||||
f" got {options.supported_model}")
|
||||
if (isinstance(options.model_options, mo.BertModelOptions) and
|
||||
(options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
|
||||
if isinstance(options.model_options, mo.BertModelOptions) and (
|
||||
options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER
|
||||
and options.supported_model != ms.SupportedModels.EXBERT_CLASSIFIER
|
||||
):
|
||||
raise ValueError(
|
||||
f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
|
||||
"Expected a Bert Classifier(MobileBERT or EXBERT), got "
|
||||
f"{options.supported_model}"
|
||||
)
|
||||
|
||||
|
||||
class TextClassifier(classifier.Classifier):
|
||||
"""API for creating and training a text classification model."""
|
||||
|
||||
def __init__(self, model_spec: Any, hparams: hp.BaseHParams,
|
||||
label_names: Sequence[str]):
|
||||
def __init__(
|
||||
self, model_spec: Any, label_names: Sequence[str], shuffle: bool
|
||||
):
|
||||
super().__init__(
|
||||
model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
|
||||
model_spec=model_spec, label_names=label_names, shuffle=shuffle
|
||||
)
|
||||
self._model_spec = model_spec
|
||||
self._hparams = hparams
|
||||
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
|
||||
self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None
|
||||
|
||||
@classmethod
|
||||
|
@ -107,7 +112,10 @@ class TextClassifier(classifier.Classifier):
|
|||
if options.hparams is None:
|
||||
options.hparams = options.supported_model.value().hparams
|
||||
|
||||
if options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER:
|
||||
if (
|
||||
options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER
|
||||
or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER
|
||||
):
|
||||
text_classifier = (
|
||||
_BertClassifier.create_bert_classifier(train_data, validation_data,
|
||||
options,
|
||||
|
@ -225,11 +233,17 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
|
|||
|
||||
_DELIM_REGEX_PATTERN = r"[^\w\']+"
|
||||
|
||||
def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
|
||||
def __init__(
|
||||
self,
|
||||
model_spec: ms.AverageWordEmbeddingClassifierSpec,
|
||||
model_options: mo.AverageWordEmbeddingModelOptions,
|
||||
hparams: hp.BaseHParams, label_names: Sequence[str]):
|
||||
super().__init__(model_spec, hparams, label_names)
|
||||
hparams: hp.AverageWordEmbeddingHParams,
|
||||
label_names: Sequence[str],
|
||||
):
|
||||
super().__init__(model_spec, label_names, hparams.shuffle)
|
||||
self._model_options = model_options
|
||||
self._hparams = hparams
|
||||
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
|
||||
self._loss_function = "sparse_categorical_crossentropy"
|
||||
self._metric_functions = [
|
||||
"accuracy",
|
||||
|
@ -344,10 +358,16 @@ class _BertClassifier(TextClassifier):
|
|||
|
||||
_INITIALIZER_RANGE = 0.02
|
||||
|
||||
def __init__(self, model_spec: ms.BertClassifierSpec,
|
||||
model_options: mo.BertModelOptions, hparams: hp.BaseHParams,
|
||||
label_names: Sequence[str]):
|
||||
super().__init__(model_spec, hparams, label_names)
|
||||
def __init__(
|
||||
self,
|
||||
model_spec: ms.BertClassifierSpec,
|
||||
model_options: mo.BertModelOptions,
|
||||
hparams: hp.BertHParams,
|
||||
label_names: Sequence[str],
|
||||
):
|
||||
super().__init__(model_spec, label_names, hparams.shuffle)
|
||||
self._hparams = hparams
|
||||
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
|
||||
self._model_options = model_options
|
||||
with self._hparams.get_strategy().scope():
|
||||
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||
|
@ -480,11 +500,26 @@ class _BertClassifier(TextClassifier):
|
|||
initial_learning_rate=initial_lr,
|
||||
decay_schedule_fn=lr_schedule,
|
||||
warmup_steps=warmup_steps)
|
||||
|
||||
if self._hparams.optimizer == hp.BertOptimizer.ADAMW:
|
||||
self._optimizer = tf.keras.optimizers.experimental.AdamW(
|
||||
lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0)
|
||||
lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0
|
||||
)
|
||||
self._optimizer.exclude_from_weight_decay(
|
||||
var_names=["LayerNorm", "layer_norm", "bias"])
|
||||
var_names=["LayerNorm", "layer_norm", "bias"]
|
||||
)
|
||||
elif self._hparams.optimizer == hp.BertOptimizer.LAMB:
|
||||
self._optimizer = tfa_optimizers.LAMB(
|
||||
lr_schedule,
|
||||
weight_decay_rate=0.01,
|
||||
epsilon=1e-6,
|
||||
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
|
||||
global_clipnorm=1.0,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"BertHParams.optimizer must be set to ADAM or "
|
||||
f"LAMB. Got {self._hparams.optimizer}."
|
||||
)
|
||||
|
||||
def _save_vocab(self, vocab_filepath: str):
|
||||
tf.io.gfile.copy(
|
||||
|
|
|
@ -66,14 +66,16 @@ def run(data_dir,
|
|||
quantization_config = None
|
||||
if (supported_model ==
|
||||
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
|
||||
hparams = text_classifier.HParams(
|
||||
epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir)
|
||||
hparams = text_classifier.AverageWordEmbeddingHParams(
|
||||
epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir
|
||||
)
|
||||
# Warning: This takes extremely long to run on CPU
|
||||
elif (
|
||||
supported_model == text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER):
|
||||
quantization_config = quantization.QuantizationConfig.for_dynamic()
|
||||
hparams = text_classifier.HParams(
|
||||
epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir)
|
||||
hparams = text_classifier.BertHParams(
|
||||
epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir
|
||||
)
|
||||
|
||||
# Fine-tunes the model.
|
||||
options = text_classifier.TextClassifierOptions(
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
import dataclasses
|
||||
from typing import Optional
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||
|
||||
|
@ -34,5 +34,5 @@ class TextClassifierOptions:
|
|||
architecture of the `supported_model`.
|
||||
"""
|
||||
supported_model: ms.SupportedModels
|
||||
hparams: Optional[hp.BaseHParams] = None
|
||||
hparams: Optional[hp.HParams] = None
|
||||
model_options: Optional[mo.TextClassifierModelOptions] = None
|
||||
|
|
|
@ -66,12 +66,14 @@ class TextClassifierTest(tf.test.TestCase):
|
|||
|
||||
def test_create_and_train_average_word_embedding_model(self):
|
||||
train_data, validation_data = self._get_data()
|
||||
options = (
|
||||
text_classifier.TextClassifierOptions(
|
||||
supported_model=(text_classifier.SupportedModels
|
||||
.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
|
||||
hparams=text_classifier.HParams(
|
||||
epochs=1, batch_size=1, learning_rate=0)))
|
||||
options = text_classifier.TextClassifierOptions(
|
||||
supported_model=(
|
||||
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
|
||||
),
|
||||
hparams=text_classifier.AverageWordEmbeddingHParams(
|
||||
epochs=1, batch_size=1, learning_rate=0
|
||||
),
|
||||
)
|
||||
average_word_embedding_classifier = (
|
||||
text_classifier.TextClassifier.create(train_data, validation_data,
|
||||
options))
|
||||
|
@ -103,12 +105,15 @@ class TextClassifierTest(tf.test.TestCase):
|
|||
options = text_classifier.TextClassifierOptions(
|
||||
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
|
||||
model_options=text_classifier.BertModelOptions(
|
||||
do_fine_tuning=False, seq_len=2),
|
||||
hparams=text_classifier.HParams(
|
||||
do_fine_tuning=False, seq_len=2
|
||||
),
|
||||
hparams=text_classifier.BertHParams(
|
||||
epochs=1,
|
||||
batch_size=1,
|
||||
learning_rate=3e-5,
|
||||
distribution_strategy='off'))
|
||||
distribution_strategy='off',
|
||||
),
|
||||
)
|
||||
bert_classifier = text_classifier.TextClassifier.create(
|
||||
train_data, validation_data, options)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user