Internal change

PiperOrigin-RevId: 547346939
This commit is contained in:
Yilei Yang 2023-07-11 17:48:46 -07:00 committed by Copybara-Service
parent f2f49b9fc8
commit 917af2ce6b
2 changed files with 19 additions and 12 deletions

View File

@ -46,13 +46,17 @@ class BertModelSpec:
"""
downloaded_files: file_util.DownloadedFiles
hparams: hp.BaseHParams = hp.BaseHParams(
hparams: hp.BaseHParams = dataclasses.field(
default_factory=lambda: hp.BaseHParams(
epochs=3,
batch_size=32,
learning_rate=3e-5,
distribution_strategy='mirrored')
model_options: bert_model_options.BertModelOptions = (
bert_model_options.BertModelOptions())
distribution_strategy='mirrored',
)
)
model_options: bert_model_options.BertModelOptions = dataclasses.field(
default_factory=bert_model_options.BertModelOptions
)
do_lower_case: bool = True
tflite_input_name: Dict[str, str] = dataclasses.field(
default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME)

View File

@ -47,11 +47,14 @@ class AverageWordEmbeddingClassifierSpec:
"""
# `learning_rate` is unused for the average word embedding model
hparams: hp.AverageWordEmbeddingHParams = hp.AverageWordEmbeddingHParams(
hparams: hp.AverageWordEmbeddingHParams = dataclasses.field(
default_factory=lambda: hp.AverageWordEmbeddingHParams(
epochs=10, batch_size=32, learning_rate=0
)
model_options: mo.AverageWordEmbeddingModelOptions = (
mo.AverageWordEmbeddingModelOptions())
)
model_options: mo.AverageWordEmbeddingModelOptions = dataclasses.field(
default_factory=mo.AverageWordEmbeddingModelOptions
)
name: str = 'AverageWordEmbedding'
average_word_embedding_classifier_spec = functools.partial(
@ -66,7 +69,7 @@ class BertClassifierSpec(bert_model_spec.BertModelSpec):
inherited from the BertModelSpec.
"""
hparams: hp.BertHParams = hp.BertHParams()
hparams: hp.BertHParams = dataclasses.field(default_factory=hp.BertHParams)
mobilebert_classifier_spec = functools.partial(