Internal change
PiperOrigin-RevId: 537613648
This commit is contained in:
parent
549e09cace
commit
cbf1d97429
|
@ -15,9 +15,12 @@
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from official.common import distribute_utils
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class BaseHParams:
|
class BaseHParams:
|
||||||
|
@ -43,10 +46,10 @@ class BaseHParams:
|
||||||
documentation for more details:
|
documentation for more details:
|
||||||
https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy.
|
https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy.
|
||||||
num_gpus: How many GPUs to use at each worker with the
|
num_gpus: How many GPUs to use at each worker with the
|
||||||
DistributionStrategies API. The default is -1, which means utilize all
|
DistributionStrategies API. The default is 0.
|
||||||
available GPUs.
|
tpu: The TPU resource to be used for training. This should be either the
|
||||||
tpu: The Cloud TPU to use for training. This should be either the name used
|
name used when creating the Cloud TPU, a grpc://ip.address.of.tpu:8470
|
||||||
when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.
|
url, or an empty string if using a local TPU.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Parameters for train configuration
|
# Parameters for train configuration
|
||||||
|
@ -63,5 +66,16 @@ class BaseHParams:
|
||||||
|
|
||||||
# Parameters for hardware acceleration
|
# Parameters for hardware acceleration
|
||||||
distribution_strategy: str = 'off'
|
distribution_strategy: str = 'off'
|
||||||
num_gpus: int = -1 # default value of -1 means use all available GPUs
|
num_gpus: int = 0
|
||||||
tpu: str = ''
|
tpu: str = ''
|
||||||
|
_strategy: tf.distribute.Strategy = dataclasses.field(init=False)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self._strategy = distribute_utils.get_distribution_strategy(
|
||||||
|
distribution_strategy=self.distribution_strategy,
|
||||||
|
num_gpus=self.num_gpus,
|
||||||
|
tpu_address=self.tpu,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_strategy(self):
|
||||||
|
return self._strategy
|
||||||
|
|
|
@ -85,8 +85,10 @@ class ModelSpecTest(tf.test.TestCase):
|
||||||
steps_per_epoch=None,
|
steps_per_epoch=None,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
distribution_strategy='off',
|
distribution_strategy='off',
|
||||||
num_gpus=-1,
|
num_gpus=0,
|
||||||
tpu=''))
|
tpu='',
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def test_custom_bert_spec(self):
|
def test_custom_bert_spec(self):
|
||||||
custom_bert_classifier_options = (
|
custom_bert_classifier_options = (
|
||||||
|
|
|
@ -311,9 +311,11 @@ class _BertClassifier(TextClassifier):
|
||||||
label_names: Sequence[str]):
|
label_names: Sequence[str]):
|
||||||
super().__init__(model_spec, hparams, label_names)
|
super().__init__(model_spec, hparams, label_names)
|
||||||
self._model_options = model_options
|
self._model_options = model_options
|
||||||
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
|
with self._hparams.get_strategy().scope():
|
||||||
self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy(
|
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||||
"test_accuracy", dtype=tf.float32)
|
self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy(
|
||||||
|
"test_accuracy", dtype=tf.float32
|
||||||
|
)
|
||||||
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
|
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -350,8 +352,9 @@ class _BertClassifier(TextClassifier):
|
||||||
"""
|
"""
|
||||||
(processed_train_data, processed_validation_data) = (
|
(processed_train_data, processed_validation_data) = (
|
||||||
self._load_and_run_preprocessor(train_data, validation_data))
|
self._load_and_run_preprocessor(train_data, validation_data))
|
||||||
self._create_model()
|
with self._hparams.get_strategy().scope():
|
||||||
self._create_optimizer(processed_train_data)
|
self._create_model()
|
||||||
|
self._create_optimizer(processed_train_data)
|
||||||
self._train_model(processed_train_data, processed_validation_data)
|
self._train_model(processed_train_data, processed_validation_data)
|
||||||
|
|
||||||
def _load_and_run_preprocessor(
|
def _load_and_run_preprocessor(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user