Internal change
PiperOrigin-RevId: 547346939
This commit is contained in:
		
							parent
							
								
									f2f49b9fc8
								
							
						
					
					
						commit
						917af2ce6b
					
				| 
						 | 
					@ -46,13 +46,17 @@ class BertModelSpec:
 | 
				
			||||||
  """
 | 
					  """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  downloaded_files: file_util.DownloadedFiles
 | 
					  downloaded_files: file_util.DownloadedFiles
 | 
				
			||||||
  hparams: hp.BaseHParams = hp.BaseHParams(
 | 
					  hparams: hp.BaseHParams = dataclasses.field(
 | 
				
			||||||
      epochs=3,
 | 
					      default_factory=lambda: hp.BaseHParams(
 | 
				
			||||||
      batch_size=32,
 | 
					          epochs=3,
 | 
				
			||||||
      learning_rate=3e-5,
 | 
					          batch_size=32,
 | 
				
			||||||
      distribution_strategy='mirrored')
 | 
					          learning_rate=3e-5,
 | 
				
			||||||
  model_options: bert_model_options.BertModelOptions = (
 | 
					          distribution_strategy='mirrored',
 | 
				
			||||||
      bert_model_options.BertModelOptions())
 | 
					      )
 | 
				
			||||||
 | 
					  )
 | 
				
			||||||
 | 
					  model_options: bert_model_options.BertModelOptions = dataclasses.field(
 | 
				
			||||||
 | 
					      default_factory=bert_model_options.BertModelOptions
 | 
				
			||||||
 | 
					  )
 | 
				
			||||||
  do_lower_case: bool = True
 | 
					  do_lower_case: bool = True
 | 
				
			||||||
  tflite_input_name: Dict[str, str] = dataclasses.field(
 | 
					  tflite_input_name: Dict[str, str] = dataclasses.field(
 | 
				
			||||||
      default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME)
 | 
					      default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -47,11 +47,14 @@ class AverageWordEmbeddingClassifierSpec:
 | 
				
			||||||
  """
 | 
					  """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  # `learning_rate` is unused for the average word embedding model
 | 
					  # `learning_rate` is unused for the average word embedding model
 | 
				
			||||||
  hparams: hp.AverageWordEmbeddingHParams = hp.AverageWordEmbeddingHParams(
 | 
					  hparams: hp.AverageWordEmbeddingHParams = dataclasses.field(
 | 
				
			||||||
      epochs=10, batch_size=32, learning_rate=0
 | 
					      default_factory=lambda: hp.AverageWordEmbeddingHParams(
 | 
				
			||||||
 | 
					          epochs=10, batch_size=32, learning_rate=0
 | 
				
			||||||
 | 
					      )
 | 
				
			||||||
 | 
					  )
 | 
				
			||||||
 | 
					  model_options: mo.AverageWordEmbeddingModelOptions = dataclasses.field(
 | 
				
			||||||
 | 
					      default_factory=mo.AverageWordEmbeddingModelOptions
 | 
				
			||||||
  )
 | 
					  )
 | 
				
			||||||
  model_options: mo.AverageWordEmbeddingModelOptions = (
 | 
					 | 
				
			||||||
      mo.AverageWordEmbeddingModelOptions())
 | 
					 | 
				
			||||||
  name: str = 'AverageWordEmbedding'
 | 
					  name: str = 'AverageWordEmbedding'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
average_word_embedding_classifier_spec = functools.partial(
 | 
					average_word_embedding_classifier_spec = functools.partial(
 | 
				
			||||||
| 
						 | 
					@ -66,7 +69,7 @@ class BertClassifierSpec(bert_model_spec.BertModelSpec):
 | 
				
			||||||
  inherited from the BertModelSpec.
 | 
					  inherited from the BertModelSpec.
 | 
				
			||||||
  """
 | 
					  """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  hparams: hp.BertHParams = hp.BertHParams()
 | 
					  hparams: hp.BertHParams = dataclasses.field(default_factory=hp.BertHParams)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
mobilebert_classifier_spec = functools.partial(
 | 
					mobilebert_classifier_spec = functools.partial(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user