Internal change
PiperOrigin-RevId: 537613648
This commit is contained in:
parent
549e09cace
commit
cbf1d97429
|
@ -15,9 +15,12 @@
|
|||
|
||||
import dataclasses
|
||||
import tempfile
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from official.common import distribute_utils
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BaseHParams:
|
||||
|
@ -43,10 +46,10 @@ class BaseHParams:
|
|||
documentation for more details:
|
||||
https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy.
|
||||
num_gpus: How many GPUs to use at each worker with the
|
||||
DistributionStrategies API. The default is -1, which means utilize all
|
||||
available GPUs.
|
||||
tpu: The Cloud TPU to use for training. This should be either the name used
|
||||
when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.
|
||||
DistributionStrategies API. The default is 0.
|
||||
tpu: The TPU resource to be used for training. This should be either the
|
||||
name used when creating the Cloud TPU, a grpc://ip.address.of.tpu:8470
|
||||
url, or an empty string if using a local TPU.
|
||||
"""
|
||||
|
||||
# Parameters for train configuration
|
||||
|
@ -63,5 +66,16 @@ class BaseHParams:
|
|||
|
||||
# Parameters for hardware acceleration
|
||||
distribution_strategy: str = 'off'
|
||||
num_gpus: int = -1 # default value of -1 means use all available GPUs
|
||||
num_gpus: int = 0
|
||||
tpu: str = ''
|
||||
_strategy: tf.distribute.Strategy = dataclasses.field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
self._strategy = distribute_utils.get_distribution_strategy(
|
||||
distribution_strategy=self.distribution_strategy,
|
||||
num_gpus=self.num_gpus,
|
||||
tpu_address=self.tpu,
|
||||
)
|
||||
|
||||
def get_strategy(self):
|
||||
return self._strategy
|
||||
|
|
|
@ -85,8 +85,10 @@ class ModelSpecTest(tf.test.TestCase):
|
|||
steps_per_epoch=None,
|
||||
shuffle=False,
|
||||
distribution_strategy='off',
|
||||
num_gpus=-1,
|
||||
tpu=''))
|
||||
num_gpus=0,
|
||||
tpu='',
|
||||
),
|
||||
)
|
||||
|
||||
def test_custom_bert_spec(self):
|
||||
custom_bert_classifier_options = (
|
||||
|
|
|
@ -311,9 +311,11 @@ class _BertClassifier(TextClassifier):
|
|||
label_names: Sequence[str]):
|
||||
super().__init__(model_spec, hparams, label_names)
|
||||
self._model_options = model_options
|
||||
with self._hparams.get_strategy().scope():
|
||||
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||
self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy(
|
||||
"test_accuracy", dtype=tf.float32)
|
||||
"test_accuracy", dtype=tf.float32
|
||||
)
|
||||
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
|
||||
|
||||
@classmethod
|
||||
|
@ -350,6 +352,7 @@ class _BertClassifier(TextClassifier):
|
|||
"""
|
||||
(processed_train_data, processed_validation_data) = (
|
||||
self._load_and_run_preprocessor(train_data, validation_data))
|
||||
with self._hparams.get_strategy().scope():
|
||||
self._create_model()
|
||||
self._create_optimizer(processed_train_data)
|
||||
self._train_model(processed_train_data, processed_validation_data)
|
||||
|
|
Loading…
Reference in New Issue
Block a user