From 917af2ce6b628079508ac4bdc11a7657b207d016 Mon Sep 17 00:00:00 2001 From: Yilei Yang Date: Tue, 11 Jul 2023 17:48:46 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 547346939 --- .../python/text/core/bert_model_spec.py | 18 +++++++++++------- .../python/text/text_classifier/model_spec.py | 13 ++++++++----- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/mediapipe/model_maker/python/text/core/bert_model_spec.py b/mediapipe/model_maker/python/text/core/bert_model_spec.py index 792c2c9a6..80e92a06a 100644 --- a/mediapipe/model_maker/python/text/core/bert_model_spec.py +++ b/mediapipe/model_maker/python/text/core/bert_model_spec.py @@ -46,13 +46,17 @@ class BertModelSpec: """ downloaded_files: file_util.DownloadedFiles - hparams: hp.BaseHParams = hp.BaseHParams( - epochs=3, - batch_size=32, - learning_rate=3e-5, - distribution_strategy='mirrored') - model_options: bert_model_options.BertModelOptions = ( - bert_model_options.BertModelOptions()) + hparams: hp.BaseHParams = dataclasses.field( + default_factory=lambda: hp.BaseHParams( + epochs=3, + batch_size=32, + learning_rate=3e-5, + distribution_strategy='mirrored', + ) + ) + model_options: bert_model_options.BertModelOptions = dataclasses.field( + default_factory=bert_model_options.BertModelOptions + ) do_lower_case: bool = True tflite_input_name: Dict[str, str] = dataclasses.field( default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME) diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec.py b/mediapipe/model_maker/python/text/text_classifier/model_spec.py index 452e22679..8bd83143c 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_spec.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec.py @@ -47,11 +47,14 @@ class AverageWordEmbeddingClassifierSpec: """ # `learning_rate` is unused for the average word embedding model - hparams: hp.AverageWordEmbeddingHParams = hp.AverageWordEmbeddingHParams( - epochs=10, batch_size=32, learning_rate=0 + hparams: hp.AverageWordEmbeddingHParams = dataclasses.field( + default_factory=lambda: hp.AverageWordEmbeddingHParams( + epochs=10, batch_size=32, learning_rate=0 + ) + ) + model_options: mo.AverageWordEmbeddingModelOptions = dataclasses.field( + default_factory=mo.AverageWordEmbeddingModelOptions ) - model_options: mo.AverageWordEmbeddingModelOptions = ( - mo.AverageWordEmbeddingModelOptions()) name: str = 'AverageWordEmbedding' average_word_embedding_classifier_spec = functools.partial( @@ -66,7 +69,7 @@ class BertClassifierSpec(bert_model_spec.BertModelSpec): inherited from the BertModelSpec. """ - hparams: hp.BertHParams = hp.BertHParams() + hparams: hp.BertHParams = dataclasses.field(default_factory=hp.BertHParams) mobilebert_classifier_spec = functools.partial(