Internal change

PiperOrigin-RevId: 537613648
This commit is contained in:
MediaPipe Team 2023-06-03 20:17:28 -07:00 committed by Copybara-Service
parent 549e09cace
commit cbf1d97429
3 changed files with 32 additions and 13 deletions

View File

@ -15,9 +15,12 @@
import dataclasses
import tempfile
from typing import Optional
import tensorflow as tf
from official.common import distribute_utils
@dataclasses.dataclass
class BaseHParams:
@ -43,10 +46,10 @@ class BaseHParams:
documentation for more details:
https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy.
num_gpus: How many GPUs to use at each worker with the
DistributionStrategies API. The default is -1, which means utilize all
available GPUs.
tpu: The Cloud TPU to use for training. This should be either the name used
when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.
DistributionStrategies API. The default is 0.
tpu: The TPU resource to be used for training. This should be either the
name used when creating the Cloud TPU, a grpc://ip.address.of.tpu:8470
url, or an empty string if using a local TPU.
"""
# Parameters for train configuration
@ -63,5 +66,16 @@ class BaseHParams:
# Parameters for hardware acceleration
distribution_strategy: str = 'off'
num_gpus: int = -1 # default value of -1 means use all available GPUs
num_gpus: int = 0
tpu: str = ''
_strategy: tf.distribute.Strategy = dataclasses.field(init=False)
def __post_init__(self):
  """Creates the distribution strategy from the hardware-related fields.

  Dataclasses invoke this hook right after the generated `__init__` has
  assigned the fields, so `distribution_strategy`, `num_gpus` and `tpu` are
  already populated here. The resulting strategy is stored on the private
  `_strategy` field, which is declared with `init=False` and is therefore not
  a constructor argument.
  """
  # Build first, assign second: `_strategy` is only set if the helper
  # returns successfully.
  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=self.distribution_strategy,
      num_gpus=self.num_gpus,
      tpu_address=self.tpu,
  )
  self._strategy = strategy
def get_strategy(self) -> tf.distribute.Strategy:
  """Returns the distribution strategy stored on `_strategy`.

  The value is the one assigned in `__post_init__`; this method only reads
  it and does not build a new strategy.
  """
  return self._strategy

View File

@ -85,8 +85,10 @@ class ModelSpecTest(tf.test.TestCase):
steps_per_epoch=None,
shuffle=False,
distribution_strategy='off',
num_gpus=-1,
tpu=''))
num_gpus=0,
tpu='',
),
)
def test_custom_bert_spec(self):
custom_bert_classifier_options = (

View File

@ -311,9 +311,11 @@ class _BertClassifier(TextClassifier):
label_names: Sequence[str]):
super().__init__(model_spec, hparams, label_names)
self._model_options = model_options
with self._hparams.get_strategy().scope():
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy(
"test_accuracy", dtype=tf.float32)
"test_accuracy", dtype=tf.float32
)
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
@classmethod
@ -350,6 +352,7 @@ class _BertClassifier(TextClassifier):
"""
(processed_train_data, processed_validation_data) = (
self._load_and_run_preprocessor(train_data, validation_data))
with self._hparams.get_strategy().scope():
self._create_model()
self._create_optimizer(processed_train_data)
self._train_model(processed_train_data, processed_validation_data)