diff --git a/mediapipe/model_maker/python/core/hyperparameters.py b/mediapipe/model_maker/python/core/hyperparameters.py
index 3b3e3540b..224716550 100644
--- a/mediapipe/model_maker/python/core/hyperparameters.py
+++ b/mediapipe/model_maker/python/core/hyperparameters.py
@@ -15,9 +15,12 @@

 import dataclasses
 import tempfile
-
 from typing import Optional

+import tensorflow as tf
+
+from official.common import distribute_utils
+

 @dataclasses.dataclass
 class BaseHParams:
@@ -43,10 +46,10 @@ class BaseHParams:
       documentation for more details:
       https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy.
     num_gpus: How many GPUs to use at each worker with the
-      DistributionStrategies API. The default is -1, which means utilize all
-      available GPUs.
-    tpu: The Cloud TPU to use for training. This should be either the name used
-      when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.
+      DistributionStrategies API. The default is 0.
+    tpu: The TPU resource to be used for training. This should be either the
+      name used when creating the Cloud TPU, a grpc://ip.address.of.tpu:8470
+      url, or an empty string if using a local TPU.
   """

   # Parameters for train configuration
@@ -63,5 +66,16 @@ class BaseHParams:

   # Parameters for hardware acceleration
   distribution_strategy: str = 'off'
-  num_gpus: int = -1  # default value of -1 means use all available GPUs
+  num_gpus: int = 0
   tpu: str = ''
+  _strategy: tf.distribute.Strategy = dataclasses.field(init=False)
+
+  def __post_init__(self):
+    self._strategy = distribute_utils.get_distribution_strategy(
+        distribution_strategy=self.distribution_strategy,
+        num_gpus=self.num_gpus,
+        tpu_address=self.tpu,
+    )
+
+  def get_strategy(self):
+    return self._strategy
diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py
index d3daac540..a8d40558c 100644
--- a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py
+++ b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py
@@ -85,8 +85,10 @@ class ModelSpecTest(tf.test.TestCase):
             steps_per_epoch=None,
             shuffle=False,
             distribution_strategy='off',
-            num_gpus=-1,
-            tpu=''))
+            num_gpus=0,
+            tpu='',
+        ),
+    )

   def test_custom_bert_spec(self):
     custom_bert_classifier_options = (
diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py
index cd6ceb9b3..c3dd48be8 100644
--- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py
+++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py
@@ -311,9 +311,11 @@ class _BertClassifier(TextClassifier):
                label_names: Sequence[str]):
     super().__init__(model_spec, hparams, label_names)
     self._model_options = model_options
-    self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
-    self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy(
-        "test_accuracy", dtype=tf.float32)
+    with self._hparams.get_strategy().scope():
+      self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
+      self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy(
+          "test_accuracy", dtype=tf.float32
+      )
     self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None

   @classmethod
@@ -350,8 +352,9 @@ class _BertClassifier(TextClassifier):
     """
     (processed_train_data, processed_validation_data) = (
         self._load_and_run_preprocessor(train_data, validation_data))
-    self._create_model()
-    self._create_optimizer(processed_train_data)
+    with self._hparams.get_strategy().scope():
+      self._create_model()
+      self._create_optimizer(processed_train_data)
     self._train_model(processed_train_data, processed_validation_data)

   def _load_and_run_preprocessor(
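
Usage note (not part of the patch): a minimal sketch of how the new strategy plumbing in BaseHParams is meant to be consumed. The learning_rate, batch_size, and epochs keyword arguments are assumed from the existing BaseHParams fields, and 'mirrored' with two GPUs is only an assumed example of a value accepted by distribute_utils.get_distribution_strategy; adjust to whatever the docstring above and your hardware support.

# Hypothetical usage sketch, not part of this patch.
from mediapipe.model_maker.python.core import hyperparameters as hp

# __post_init__ resolves the tf.distribute.Strategy once from
# distribution_strategy / num_gpus / tpu via distribute_utils.
hparams = hp.BaseHParams(
    learning_rate=3e-5,
    batch_size=48,
    epochs=3,
    distribution_strategy='mirrored',  # assumed accepted value; 'off' is the default
    num_gpus=2,  # new default is 0, so GPUs must now be requested explicitly
)

# Variable-creating objects (model, optimizer, metrics) are built under the
# strategy scope, mirroring what _BertClassifier does in the diff above.
with hparams.get_strategy().scope():
  ...  # build the tf.keras model, optimizer, and metrics here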