Internal change

PiperOrigin-RevId: 547346939
This commit is contained in:
Yilei Yang 2023-07-11 17:48:46 -07:00 committed by Copybara-Service
parent f2f49b9fc8
commit 917af2ce6b
2 changed files with 19 additions and 12 deletions

View File

@ -46,13 +46,17 @@ class BertModelSpec:
""" """
downloaded_files: file_util.DownloadedFiles downloaded_files: file_util.DownloadedFiles
hparams: hp.BaseHParams = hp.BaseHParams( hparams: hp.BaseHParams = dataclasses.field(
default_factory=lambda: hp.BaseHParams(
epochs=3, epochs=3,
batch_size=32, batch_size=32,
learning_rate=3e-5, learning_rate=3e-5,
distribution_strategy='mirrored') distribution_strategy='mirrored',
model_options: bert_model_options.BertModelOptions = ( )
bert_model_options.BertModelOptions()) )
model_options: bert_model_options.BertModelOptions = dataclasses.field(
default_factory=bert_model_options.BertModelOptions
)
do_lower_case: bool = True do_lower_case: bool = True
tflite_input_name: Dict[str, str] = dataclasses.field( tflite_input_name: Dict[str, str] = dataclasses.field(
default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME) default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME)

View File

@ -47,11 +47,14 @@ class AverageWordEmbeddingClassifierSpec:
""" """
# `learning_rate` is unused for the average word embedding model # `learning_rate` is unused for the average word embedding model
hparams: hp.AverageWordEmbeddingHParams = hp.AverageWordEmbeddingHParams( hparams: hp.AverageWordEmbeddingHParams = dataclasses.field(
default_factory=lambda: hp.AverageWordEmbeddingHParams(
epochs=10, batch_size=32, learning_rate=0 epochs=10, batch_size=32, learning_rate=0
) )
model_options: mo.AverageWordEmbeddingModelOptions = ( )
mo.AverageWordEmbeddingModelOptions()) model_options: mo.AverageWordEmbeddingModelOptions = dataclasses.field(
default_factory=mo.AverageWordEmbeddingModelOptions
)
name: str = 'AverageWordEmbedding' name: str = 'AverageWordEmbedding'
average_word_embedding_classifier_spec = functools.partial( average_word_embedding_classifier_spec = functools.partial(
@ -66,7 +69,7 @@ class BertClassifierSpec(bert_model_spec.BertModelSpec):
inherited from the BertModelSpec. inherited from the BertModelSpec.
""" """
hparams: hp.BertHParams = hp.BertHParams() hparams: hp.BertHParams = dataclasses.field(default_factory=hp.BertHParams)
mobilebert_classifier_spec = functools.partial( mobilebert_classifier_spec = functools.partial(