PiperOrigin-RevId: 554595324
MediaPipe Team 2023-08-07 14:33:00 -07:00 committed by Copybara-Service
parent 22054cd468
commit c1c51c2fe7
20 changed files with 260 additions and 99 deletions

View File

@@ -19,6 +19,13 @@ licenses(["notice"])
package(default_visibility = ["//mediapipe:__subpackages__"])
filegroup(
name = "testdata",
srcs = glob([
"testdata/**",
]),
)
py_library(
name = "test_util",
testonly = 1,
@@ -56,11 +63,26 @@ py_library(
py_test(
name = "file_util_test",
srcs = ["file_util_test.py"],
data = ["//mediapipe/model_maker/python/core/utils/testdata"],
data = [":testdata"],
tags = ["requires-net:external"],
deps = [":file_util"],
)
py_library(
name = "hub_loader",
srcs = ["hub_loader.py"],
)
py_test(
name = "hub_loader_test",
srcs = ["hub_loader_test.py"],
data = [":testdata"],
deps = [
":hub_loader",
"//mediapipe/tasks/python/test:test_utils",
],
)
py_library(
name = "loss_functions",
srcs = ["loss_functions.py"],

View File

@@ -0,0 +1,97 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Handles both V1 and V2 modules."""
import tensorflow_hub as hub
class HubKerasLayerV1V2(hub.KerasLayer):
"""Class to loads TF v1 and TF v2 hub modules that could be fine-tuned.
Since TF v1 modules couldn't be retrained in hub.KerasLayer. This class
provides a workaround for retraining the whole tf1 model in tf2. In
particular, it extract self._func._self_unconditional_checkpoint_dependencies
into trainable variable in tf1.
Doesn't update moving-mean/moving-variance for BatchNormalization during
fine-tuning.
"""
def _setup_layer(self, trainable=False, **kwargs):
if self._is_hub_module_v1:
self._setup_layer_v1(trainable, **kwargs)
else:
# call _setup_layer from the base class for v2.
super(HubKerasLayerV1V2, self)._setup_layer(trainable, **kwargs)
def _check_trainability(self):
if self._is_hub_module_v1:
self._check_trainability_v1()
else:
# call _check_trainability from the base class for v2.
super(HubKerasLayerV1V2, self)._check_trainability()
def _setup_layer_v1(self, trainable=False, **kwargs):
"""Constructs keras layer with relevant weights and losses."""
# Initialize an empty layer, then add_weight() etc. as needed.
super(hub.KerasLayer, self).__init__(trainable=trainable, **kwargs)
if not self._is_hub_module_v1:
raise ValueError(
'This function only supports setting up a v1 hub module.'
)
# v2 trainable_variables:
if hasattr(self._func, 'trainable_variables'):
for v in self._func.trainable_variables:
self._add_existing_weight(v, trainable=True)
trainable_variables = {id(v) for v in self._func.trainable_variables}
else:
trainable_variables = set()
if not hasattr(self._func, '_self_unconditional_checkpoint_dependencies'):
raise ValueError(
"_func doesn't contains attribute "
'_self_unconditional_checkpoint_dependencies.'
)
dependencies = self._func._self_unconditional_checkpoint_dependencies # pylint: disable=protected-access
# Adds trainable variables.
for dep in dependencies:
if dep.name == 'variables':
for v in dep.ref:
if id(v) not in trainable_variables:
self._add_existing_weight(v, trainable=True)
trainable_variables.add(id(v))
# Adds non-trainable variables.
if hasattr(self._func, 'variables'):
for v in self._func.variables:
if id(v) not in trainable_variables:
self._add_existing_weight(v, trainable=False)
# Forward the callable's regularization losses (if any).
if hasattr(self._func, 'regularization_losses'):
for l in self._func.regularization_losses:
if not callable(l):
raise ValueError(
'hub.KerasLayer(obj) expects obj.regularization_losses to be an '
'iterable of callables, each returning a scalar loss term.'
)
self.add_loss(self._call_loss_if_trainable(l)) # Supports callables.
def _check_trainability_v1(self):
"""Ignores trainability checks for V1."""
if self._is_hub_module_v1:
return # Nothing to do.
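
For orientation, a minimal sketch of how the new layer is consumed, mirroring the text_classifier.py change later in this commit; the seq_len value and the functional-API wiring here are illustrative only:

import tensorflow as tf
from mediapipe.model_maker.python.core.utils import hub_loader

seq_len = 128  # Arbitrary sequence length, chosen for illustration.
input_ids = tf.keras.layers.Input(shape=(seq_len,), dtype=tf.int32)
input_mask = tf.keras.layers.Input(shape=(seq_len,), dtype=tf.int32)
segment_ids = tf.keras.layers.Input(shape=(seq_len,), dtype=tf.int32)

# Works for both TF1 hub modules and TF2 SavedModels; trainable=True
# exercises the TF1 retraining workaround implemented in _setup_layer_v1.
encoder = hub_loader.HubKerasLayerV1V2(
    "https://tfhub.dev/google/mobilebert/uncased_L-24_H-128_B-512_A-4_F-4_OPT/1",
    signature="tokens",
    output_key="pooled_output",
    trainable=True,
)
pooled_output = encoder(
    dict(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids)
)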

View File

@@ -0,0 +1,59 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import hub_loader
from mediapipe.tasks.python.test import test_utils
class HubKerasLayerV1V2Test(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(
("hub_module_v1_mini", True),
("saved_model_v2_mini", True),
("hub_module_v1_mini", False),
("saved_model_v2_mini", False),
)
def test_load_with_defaults(self, module_name, trainable):
inputs, expected_outputs = 10.0, 11.0 # Test modules perform increment op.
path = test_utils.get_test_data_path(module_name)
layer = hub_loader.HubKerasLayerV1V2(path, trainable=trainable)
output = layer(inputs)
self.assertEqual(output, expected_outputs)
def test_trainable_variable(self):
path = test_utils.get_test_data_path("hub_module_v1_mini_train")
layer = hub_loader.HubKerasLayerV1V2(path, trainable=True)
# Checks trainable variables.
self.assertLen(layer.trainable_variables, 2)
self.assertEqual(layer.trainable_variables[0].name, "a:0")
self.assertEqual(layer.trainable_variables[1].name, "b:0")
self.assertEqual(layer.variables, layer.trainable_variables)
# Checks non-trainable variables.
self.assertEmpty(layer.non_trainable_variables)
layer = hub_loader.HubKerasLayerV1V2(path, trainable=False)
# Checks trainable variables.
self.assertEmpty(layer.trainable_variables)
# Checks non-trainable variables.
self.assertLen(layer.non_trainable_variables, 2)
self.assertEqual(layer.non_trainable_variables[0].name, "a:0")
self.assertEqual(layer.non_trainable_variables[1].name, "b:0")
self.assertEqual(layer.variables, layer.non_trainable_variables)
if __name__ == "__main__":
tf.test.main()

View File

@@ -1,23 +0,0 @@
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(
default_visibility = ["//mediapipe/model_maker/python/core/utils:__subpackages__"],
licenses = ["notice"], # Apache 2.0
)
filegroup(
name = "testdata",
srcs = ["test.txt"],
)

View File

@@ -14,7 +14,7 @@
"""Specification for a BERT model."""
import dataclasses
from typing import Dict
from typing import Dict, Union
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.core.utils import file_util
@@ -35,7 +35,9 @@ class BertModelSpec:
Transformers for Language Understanding) for more details.
Attributes:
downloaded_files: A DownloadedFiles object of the model files
files: Either a TFHub URL string which can be passed directly to
hub.KerasLayer, or a DownloadedFiles object of the model files.
is_tf2: If True, the checkpoint is in TF2 format; otherwise it is in TF1 format.
hparams: Hyperparameters used for training.
model_options: Configurable options for a BERT model.
do_lower_case: boolean, whether to lower case the input text. Should be
@@ -45,7 +47,8 @@ class BertModelSpec:
name: The name of the object.
"""
downloaded_files: file_util.DownloadedFiles
files: Union[str, file_util.DownloadedFiles]
is_tf2: bool = True
hparams: hp.BaseHParams = dataclasses.field(
default_factory=lambda: hp.BaseHParams(
epochs=3,
@@ -61,3 +64,11 @@
tflite_input_name: Dict[str, str] = dataclasses.field(
default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME)
name: str = 'Bert'
def get_path(self) -> str:
if isinstance(self.files, file_util.DownloadedFiles):
return self.files.get_path()
elif isinstance(self.files, str):
return self.files
else:
raise ValueError(f'files has unsupported type: {type(self.files)}')
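
Both accepted forms of files resolve through get_path(); a short sketch, assuming the remaining BertModelSpec fields keep their defaults (the URLs are the ones used elsewhere in this commit):

# A plain TFHub handle passes through get_path() unchanged.
spec = BertModelSpec(
    files='https://tfhub.dev/google/mobilebert/uncased_L-24_H-128_B-512_A-4_F-4_OPT/1'
)
assert spec.get_path() == spec.files

# A DownloadedFiles object is fetched and cached on first use, and
# get_path() returns the local directory instead.
spec = BertModelSpec(
    files=file_util.DownloadedFiles(
        'text_classifier/mobilebert_tiny',
        'https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny.tar.gz',
        is_folder=True,
    )
)
local_dir = spec.get_path()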

View File

@@ -131,6 +131,7 @@ py_library(
":text_classifier_options",
"//mediapipe/model_maker/python/core/data:dataset",
"//mediapipe/model_maker/python/core/tasks:classifier",
"//mediapipe/model_maker/python/core/utils:hub_loader",
"//mediapipe/model_maker/python/core/utils:loss_functions",
"//mediapipe/model_maker/python/core/utils:metrics",
"//mediapipe/model_maker/python/core/utils:model_util",

View File

@@ -23,16 +23,8 @@ from mediapipe.model_maker.python.text.text_classifier import hyperparameters as
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
'text_classifier/mobilebert_tiny',
'https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny.tar.gz',
is_folder=True,
)
EXBERT_FILES = file_util.DownloadedFiles(
'text_classifier/exbert',
'https://storage.googleapis.com/mediapipe-assets/exbert.tar.gz',
is_folder=True,
MOBILEBERT_FILES = (
'https://tfhub.dev/google/mobilebert/uncased_L-24_H-128_B-512_A-4_F-4_OPT/1'
)
@@ -71,23 +63,14 @@ class BertClassifierSpec(bert_model_spec.BertModelSpec):
hparams: hp.BertHParams = dataclasses.field(default_factory=hp.BertHParams)
mobilebert_classifier_spec = functools.partial(
BertClassifierSpec,
downloaded_files=MOBILEBERT_TINY_FILES,
files=MOBILEBERT_FILES,
hparams=hp.BertHParams(
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
),
name='MobileBert',
)
exbert_classifier_spec = functools.partial(
BertClassifierSpec,
downloaded_files=EXBERT_FILES,
hparams=hp.BertHParams(
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
),
name='ExBert',
name='MobileBERT',
is_tf2=False,
)
@@ -96,4 +79,3 @@ class SupportedModels(enum.Enum):
"""Predefined text classifier model specs supported by Model Maker."""
AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec
MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec
EXBERT_CLASSIFIER = exbert_classifier_spec

View File

@@ -42,8 +42,8 @@ class ModelSpecTest(tf.test.TestCase):
def test_predefined_bert_spec(self):
model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value()
self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec)
self.assertEqual(model_spec_obj.name, 'MobileBert')
self.assertTrue(os.path.exists(model_spec_obj.downloaded_files.get_path()))
self.assertEqual(model_spec_obj.name, 'MobileBERT')
self.assertTrue(model_spec_obj.files)
self.assertTrue(model_spec_obj.do_lower_case)
self.assertEqual(
model_spec_obj.tflite_input_name,

View File

@@ -87,11 +87,11 @@ class PreprocessorTest(tf.test.TestCase):
csv_file = self._get_csv_file()
dataset = text_classifier_ds.Dataset.from_csv(
filename=csv_file, csv_params=self.CSV_PARAMS_)
bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=5,
do_lower_case=bert_spec.do_lower_case,
uri=bert_spec.downloaded_files.get_path(),
uri=bert_spec.get_path(),
model_name=bert_spec.name,
)
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
@@ -121,11 +121,11 @@ class PreprocessorTest(tf.test.TestCase):
csv_params=self.CSV_PARAMS_,
cache_dir=self.get_temp_dir(),
)
bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=5,
do_lower_case=bert_spec.do_lower_case,
uri=bert_spec.downloaded_files.get_path(),
uri=bert_spec.get_path(),
model_name=bert_spec.name,
)
ds_cache_files = dataset.tfrecord_cache_files
@@ -153,7 +153,7 @@ class PreprocessorTest(tf.test.TestCase):
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=seq_len,
do_lower_case=do_lower_case,
uri=bert_spec.downloaded_files.get_path(),
uri=bert_spec.get_path(),
model_name=bert_spec.name,
)
new_cf = bert_preprocessor._get_tfrecord_cache_files(cf)
@@ -167,10 +167,6 @@ class PreprocessorTest(tf.test.TestCase):
cache_dir=self.get_temp_dir(),
num_shards=1,
)
exbert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, True))
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 10, True))
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, False))
mobilebert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, True))
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 10, True))
@@ -180,10 +176,10 @@
cache_dir=self.get_temp_dir(),
num_shards=1,
)
all_cf_prefixes.add(self._get_new_prefix(new_cf, exbert_spec, 5, True))
all_cf_prefixes.add(self._get_new_prefix(new_cf, mobilebert_spec, 5, True))
# Each item of all_cf_prefixes should be unique, so 7 total.
self.assertLen(all_cf_prefixes, 7)
# Each item of all_cf_prefixes should be unique.
self.assertLen(all_cf_prefixes, 4)
if __name__ == '__main__':

View File

@@ -24,6 +24,7 @@ import tensorflow_hub as hub
from mediapipe.model_maker.python.core.data import dataset as ds
from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import hub_loader
from mediapipe.model_maker.python.core.utils import loss_functions
from mediapipe.model_maker.python.core.utils import metrics
from mediapipe.model_maker.python.core.utils import model_util
@@ -52,18 +53,21 @@ def _validate(options: text_classifier_options.TextClassifierOptions):
if options.model_options is None:
return
if (isinstance(options.model_options, mo.AverageWordEmbeddingModelOptions) and
(options.supported_model !=
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
f" got {options.supported_model}")
if isinstance(options.model_options, mo.BertModelOptions) and (
options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER
and options.supported_model != ms.SupportedModels.EXBERT_CLASSIFIER
if isinstance(
options.model_options, mo.AverageWordEmbeddingModelOptions
) and (
options.supported_model
!= ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
):
raise ValueError(
"Expected a Bert Classifier(MobileBERT or EXBERT), got "
f"{options.supported_model}"
"Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
f" got {options.supported_model}"
)
if isinstance(options.model_options, mo.BertModelOptions) and (
not isinstance(options.supported_model.value(), ms.BertClassifierSpec)
):
raise ValueError(
f"Expected a Bert Classifier, got {options.supported_model}"
)
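
With isinstance checks replacing the enum comparisons, any future SupportedModels member whose spec is a BertClassifierSpec passes validation without further edits here. A sketch of the resulting behavior; the constructed options are illustrative:

options = text_classifier_options.TextClassifierOptions(
    supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER,
    model_options=mo.BertModelOptions(),
)
_validate(options)  # Passes: the enum's spec is a BertClassifierSpec.

options.supported_model = ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
_validate(options)  # Raises ValueError('Expected a Bert Classifier, ...').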
@@ -113,15 +117,13 @@ class TextClassifier(classifier.Classifier):
if options.hparams is None:
options.hparams = options.supported_model.value().hparams
if (
options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER
or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER
):
if isinstance(options.supported_model.value(), ms.BertClassifierSpec):
text_classifier = _BertClassifier.create_bert_classifier(
train_data, validation_data, options
)
elif (options.supported_model ==
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
elif isinstance(
options.supported_model.value(), ms.AverageWordEmbeddingClassifierSpec
):
text_classifier = _AverageWordEmbeddingClassifier.create_average_word_embedding_classifier(
train_data, validation_data, options
)
@@ -348,12 +350,12 @@ class _BertClassifier(TextClassifier):
self._hparams = hparams
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
self._model_options = model_options
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
with self._hparams.get_strategy().scope():
self._loss_function = loss_functions.SparseFocalLoss(
self._hparams.gamma, self._num_classes
)
self._metric_functions = self._create_metrics()
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
@classmethod
def create_bert_classifier(
@@ -410,7 +412,7 @@ class _BertClassifier(TextClassifier):
self._text_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=self._model_options.seq_len,
do_lower_case=self._model_spec.do_lower_case,
uri=self._model_spec.downloaded_files.get_path(),
uri=self._model_spec.get_path(),
model_name=self._model_spec.name,
)
return (
@@ -488,12 +490,26 @@
name="input_type_ids",
),
)
encoder = hub.KerasLayer(
self._model_spec.downloaded_files.get_path(),
trainable=self._model_options.do_fine_tuning,
)
encoder_outputs = encoder(encoder_inputs)
pooled_output = encoder_outputs["pooled_output"]
if self._model_spec.is_tf2:
encoder = hub.KerasLayer(
self._model_spec.get_path(),
trainable=self._model_options.do_fine_tuning,
)
encoder_outputs = encoder(encoder_inputs)
pooled_output = encoder_outputs["pooled_output"]
else:
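# TF1 hub modules are loaded via the "tokens" signature, whose inputs use
# the older input_ids/input_mask/segment_ids names, so remap them here.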
renamed_inputs = dict(
input_ids=encoder_inputs["input_word_ids"],
input_mask=encoder_inputs["input_mask"],
segment_ids=encoder_inputs["input_type_ids"],
)
encoder = hub_loader.HubKerasLayerV1V2(
self._model_spec.get_path(),
signature="tokens",
output_key="pooled_output",
trainable=self._model_options.do_fine_tuning,
)
pooled_output = encoder(renamed_inputs)
output = tf.keras.layers.Dropout(rate=self._model_options.dropout_rate)(
pooled_output)

View File

@@ -104,13 +104,9 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
# Skipping mobilebert b/c OSS test timeout/flakiness: b/275624089
# dict(
# testcase_name='mobilebert',
# supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
# ),
dict(
testcase_name='exbert',
supported_model=text_classifier.SupportedModels.EXBERT_CLASSIFIER,
testcase_name='mobilebert',
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
),
)
def test_create_and_train_bert(self, supported_model):
@@ -156,7 +152,7 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
def test_label_mismatch(self):
options = text_classifier.TextClassifierOptions(
supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER)
supported_model=(text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER)
)
train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
train_data = text_classifier.Dataset(train_tf_dataset, ['foo'], 1)
@@ -174,13 +170,13 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
train_data, validation_data = self._get_data()
avg_options = text_classifier.TextClassifierOptions(
supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER),
supported_model=(text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER),
model_options=text_classifier.AverageWordEmbeddingModelOptions(),
)
with self.assertRaisesWithLiteralMatch(
ValueError,
'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
' SupportedModels.EXBERT_CLASSIFIER',
' SupportedModels.MOBILEBERT_CLASSIFIER',
):
text_classifier.TextClassifier.create(
train_data, validation_data, avg_options
@@ -194,7 +190,7 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
)
with self.assertRaisesWithLiteralMatch(
ValueError,
'Expected a Bert Classifier(MobileBERT or EXBERT), got'
'Expected a Bert Classifier, got'
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER',
):
text_classifier.TextClassifier.create(
@@ -203,7 +199,7 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
def test_bert_loss_and_metrics_creation(self):
train_data, validation_data = self._get_data()
supported_model = text_classifier.SupportedModels.EXBERT_CLASSIFIER
supported_model = text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER
hparams = text_classifier.BertHParams(
desired_recalls=[0.2],
desired_precisions=[0.9],