diff --git a/mediapipe/model_maker/python/text/core/bert_model_spec.py b/mediapipe/model_maker/python/text/core/bert_model_spec.py index 792c2c9a6..80e92a06a 100644 --- a/mediapipe/model_maker/python/text/core/bert_model_spec.py +++ b/mediapipe/model_maker/python/text/core/bert_model_spec.py @@ -46,13 +46,17 @@ class BertModelSpec: """ downloaded_files: file_util.DownloadedFiles - hparams: hp.BaseHParams = hp.BaseHParams( - epochs=3, - batch_size=32, - learning_rate=3e-5, - distribution_strategy='mirrored') - model_options: bert_model_options.BertModelOptions = ( - bert_model_options.BertModelOptions()) + hparams: hp.BaseHParams = dataclasses.field( + default_factory=lambda: hp.BaseHParams( + epochs=3, + batch_size=32, + learning_rate=3e-5, + distribution_strategy='mirrored', + ) + ) + model_options: bert_model_options.BertModelOptions = dataclasses.field( + default_factory=bert_model_options.BertModelOptions + ) do_lower_case: bool = True tflite_input_name: Dict[str, str] = dataclasses.field( default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME) diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec.py b/mediapipe/model_maker/python/text/text_classifier/model_spec.py index 452e22679..8bd83143c 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_spec.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec.py @@ -47,11 +47,14 @@ class AverageWordEmbeddingClassifierSpec: """ # `learning_rate` is unused for the average word embedding model - hparams: hp.AverageWordEmbeddingHParams = hp.AverageWordEmbeddingHParams( - epochs=10, batch_size=32, learning_rate=0 + hparams: hp.AverageWordEmbeddingHParams = dataclasses.field( + default_factory=lambda: hp.AverageWordEmbeddingHParams( + epochs=10, batch_size=32, learning_rate=0 + ) + ) + model_options: mo.AverageWordEmbeddingModelOptions = dataclasses.field( + default_factory=mo.AverageWordEmbeddingModelOptions ) - model_options: mo.AverageWordEmbeddingModelOptions = ( - mo.AverageWordEmbeddingModelOptions()) name: str = 'AverageWordEmbedding' average_word_embedding_classifier_spec = functools.partial( @@ -66,7 +69,7 @@ class BertClassifierSpec(bert_model_spec.BertModelSpec): inherited from the BertModelSpec. """ - hparams: hp.BertHParams = hp.BertHParams() + hparams: hp.BertHParams = dataclasses.field(default_factory=hp.BertHParams) mobilebert_classifier_spec = functools.partial(