Internal change

PiperOrigin-RevId: 537613648
This commit is contained in:
MediaPipe Team 2023-06-03 20:17:28 -07:00 committed by Copybara-Service
parent 549e09cace
commit cbf1d97429
3 changed files with 32 additions and 13 deletions

View File

@ -15,9 +15,12 @@
import dataclasses import dataclasses
import tempfile import tempfile
from typing import Optional from typing import Optional
import tensorflow as tf
from official.common import distribute_utils
@dataclasses.dataclass @dataclasses.dataclass
class BaseHParams: class BaseHParams:
@ -43,10 +46,10 @@ class BaseHParams:
documentation for more details: documentation for more details:
https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy. https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy.
num_gpus: How many GPUs to use at each worker with the num_gpus: How many GPUs to use at each worker with the
DistributionStrategies API. The default is -1, which means utilize all DistributionStrategies API. The default is 0.
available GPUs. tpu: The TPU resource to be used for training. This should be either the
tpu: The Cloud TPU to use for training. This should be either the name used name used when creating the Cloud TPU, a grpc://ip.address.of.tpu:8470
when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url. url, or an empty string if using a local TPU.
""" """
# Parameters for train configuration # Parameters for train configuration
@ -63,5 +66,16 @@ class BaseHParams:
# Parameters for hardware acceleration # Parameters for hardware acceleration
distribution_strategy: str = 'off' distribution_strategy: str = 'off'
num_gpus: int = -1 # default value of -1 means use all available GPUs num_gpus: int = 0
tpu: str = '' tpu: str = ''
_strategy: tf.distribute.Strategy = dataclasses.field(init=False)
# Dataclass post-init hook: builds the distribution strategy from the
# hardware-acceleration fields and stores it on the instance.
# `_strategy` is declared with `init=False`, so it is populated here rather
# than being passed to the constructor. The `tpu` field is forwarded to the
# helper under the `tpu_address` keyword.
def __post_init__(self):
self._strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=self.distribution_strategy,
num_gpus=self.num_gpus,
tpu_address=self.tpu,
)
# Returns the strategy object stored in `_strategy` by `__post_init__`.
# Callers in this change enter `.scope()` on the returned object before
# creating the loss, metric, model and optimizer.
def get_strategy(self):
return self._strategy

View File

@ -85,8 +85,10 @@ class ModelSpecTest(tf.test.TestCase):
steps_per_epoch=None, steps_per_epoch=None,
shuffle=False, shuffle=False,
distribution_strategy='off', distribution_strategy='off',
num_gpus=-1, num_gpus=0,
tpu='')) tpu='',
),
)
def test_custom_bert_spec(self): def test_custom_bert_spec(self):
custom_bert_classifier_options = ( custom_bert_classifier_options = (

View File

@ -311,9 +311,11 @@ class _BertClassifier(TextClassifier):
label_names: Sequence[str]): label_names: Sequence[str]):
super().__init__(model_spec, hparams, label_names) super().__init__(model_spec, hparams, label_names)
self._model_options = model_options self._model_options = model_options
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy() with self._hparams.get_strategy().scope():
self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy( self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
"test_accuracy", dtype=tf.float32) self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy(
"test_accuracy", dtype=tf.float32
)
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
@classmethod @classmethod
@ -350,8 +352,9 @@ class _BertClassifier(TextClassifier):
""" """
(processed_train_data, processed_validation_data) = ( (processed_train_data, processed_validation_data) = (
self._load_and_run_preprocessor(train_data, validation_data)) self._load_and_run_preprocessor(train_data, validation_data))
self._create_model() with self._hparams.get_strategy().scope():
self._create_optimizer(processed_train_data) self._create_model()
self._create_optimizer(processed_train_data)
self._train_model(processed_train_data, processed_validation_data) self._train_model(processed_train_data, processed_validation_data)
def _load_and_run_preprocessor( def _load_and_run_preprocessor(