Internal
PiperOrigin-RevId: 554595324
This commit is contained in:
parent
22054cd468
commit
c1c51c2fe7
|
@ -19,6 +19,13 @@ licenses(["notice"])
|
|||
|
||||
package(default_visibility = ["//mediapipe:__subpackages__"])
|
||||
|
||||
filegroup(
|
||||
name = "testdata",
|
||||
srcs = glob([
|
||||
"testdata/**",
|
||||
]),
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "test_util",
|
||||
testonly = 1,
|
||||
|
@ -56,11 +63,26 @@ py_library(
|
|||
py_test(
|
||||
name = "file_util_test",
|
||||
srcs = ["file_util_test.py"],
|
||||
data = ["//mediapipe/model_maker/python/core/utils/testdata"],
|
||||
data = [":testdata"],
|
||||
tags = ["requires-net:external"],
|
||||
deps = [":file_util"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "hub_loader",
|
||||
srcs = ["hub_loader.py"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "hub_loader_test",
|
||||
srcs = ["hub_loader_test.py"],
|
||||
data = [":testdata"],
|
||||
deps = [
|
||||
":hub_loader",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "loss_functions",
|
||||
srcs = ["loss_functions.py"],
|
||||
|
|
97
mediapipe/model_maker/python/core/utils/hub_loader.py
Normal file
97
mediapipe/model_maker/python/core/utils/hub_loader.py
Normal file
|
@ -0,0 +1,97 @@
|
|||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Handles both V1 and V2 modules."""
|
||||
|
||||
import tensorflow_hub as hub
|
||||
|
||||
|
||||
class HubKerasLayerV1V2(hub.KerasLayer):
  """Loads TF v1 and TF v2 hub modules so that they can be fine-tuned.

  hub.KerasLayer cannot retrain TF v1 modules. This class works around that
  limitation so the whole TF1 model can be retrained in TF2: for v1 modules it
  registers the variables found under
  self._func._self_unconditional_checkpoint_dependencies as trainable weights.
  TF v2 modules are delegated to the hub.KerasLayer implementation unchanged.

  Doesn't update moving-mean/moving-variance for BatchNormalization during
  fine-tuning.
  """

  def _setup_layer(self, trainable=False, **kwargs):
    """Sets up the layer, dispatching on the hub module version.

    Args:
      trainable: Forwarded to the version-specific setup routine.
      **kwargs: Forwarded to the version-specific setup routine.
    """
    if not self._is_hub_module_v1:
      # v2 modules need no workaround; reuse the base class implementation.
      super()._setup_layer(trainable, **kwargs)
    else:
      self._setup_layer_v1(trainable, **kwargs)

  def _check_trainability(self):
    """Runs the trainability check that matches the hub module version."""
    if not self._is_hub_module_v1:
      # v2 modules keep the base class trainability check.
      super()._check_trainability()
    else:
      self._check_trainability_v1()

  def _setup_layer_v1(self, trainable=False, **kwargs):
    """Constructs keras layer with relevant weights and losses for v1 modules.

    Args:
      trainable: Passed as `trainable` to the Keras layer constructor.
      **kwargs: Extra keyword arguments for the Keras layer constructor.

    Raises:
      ValueError: If the loaded module is not a v1 hub module, if `_func` has
        no `_self_unconditional_checkpoint_dependencies` attribute, or if one
        of its regularization losses is not callable.
    """
    # Deliberately skip hub.KerasLayer in the MRO: start from an empty layer
    # and attach weights and losses by hand below.
    super(hub.KerasLayer, self).__init__(trainable=trainable, **kwargs)

    if not self._is_hub_module_v1:
      raise ValueError(
          'Only supports to set up v1 hub module in this function.'
      )

    # Ids of the variables already registered as trainable weights.
    seen_ids = set()

    # Variables the callable itself reports as trainable (v2-style attribute).
    if hasattr(self._func, 'trainable_variables'):
      for var in self._func.trainable_variables:
        self._add_existing_weight(var, trainable=True)
      seen_ids = {id(var) for var in self._func.trainable_variables}

    if not hasattr(self._func, '_self_unconditional_checkpoint_dependencies'):
      raise ValueError(
          "_func doesn't contains attribute "
          '_self_unconditional_checkpoint_dependencies.'
      )
    dependencies = self._func._self_unconditional_checkpoint_dependencies  # pylint: disable=protected-access

    # Register every variable of the 'variables' checkpoint dependency as
    # trainable, unless it has been registered already.
    for dep in dependencies:
      if dep.name != 'variables':
        continue
      for var in dep.ref:
        if id(var) in seen_ids:
          continue
        self._add_existing_weight(var, trainable=True)
        seen_ids.add(id(var))

    # Whatever is left in `variables` is registered as non-trainable.
    if hasattr(self._func, 'variables'):
      for var in self._func.variables:
        if id(var) not in seen_ids:
          self._add_existing_weight(var, trainable=False)

    # Forward the callable's regularization losses (if any).
    if hasattr(self._func, 'regularization_losses'):
      for loss_fn in self._func.regularization_losses:
        if not callable(loss_fn):
          raise ValueError(
              'hub.KerasLayer(obj) expects obj.regularization_losses to be an '
              'iterable of callables, each returning a scalar loss term.'
          )
        # add_loss() accepts callables.
        self.add_loss(self._call_loss_if_trainable(loss_fn))

  def _check_trainability_v1(self):
    """Ignores trainability checks for V1."""
    if self._is_hub_module_v1:
      return  # Nothing to do.
|
59
mediapipe/model_maker/python/core/utils/hub_loader_test.py
Normal file
59
mediapipe/model_maker/python/core/utils/hub_loader_test.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import hub_loader
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
|
||||
class HubKerasLayerV1V2Test(tf.test.TestCase, parameterized.TestCase):
  """Tests hub_loader.HubKerasLayerV1V2 against the bundled test modules."""

  @parameterized.parameters(
      ("hub_module_v1_mini", True),
      ("saved_model_v2_mini", True),
      ("hub_module_v1_mini", False),
      ("saved_model_v2_mini", False),
  )
  def test_load_with_defaults(self, module_name, trainable):
    """Loads a test module and checks the value it returns for one input."""
    # Test modules perform increment op.
    inputs = 10.0
    expected_outputs = 11.0
    module_path = test_utils.get_test_data_path(module_name)
    layer = hub_loader.HubKerasLayerV1V2(module_path, trainable=trainable)
    output = layer(inputs)
    self.assertEqual(output, expected_outputs)

  def test_trainable_variable(self):
    """Checks how variables are split by the layer's `trainable` flag."""
    module_path = test_utils.get_test_data_path("hub_module_v1_mini_train")
    expected_names = ["a:0", "b:0"]

    # trainable=True: both variables are trainable, none is non-trainable.
    layer = hub_loader.HubKerasLayerV1V2(module_path, trainable=True)
    self.assertEqual(
        [var.name for var in layer.trainable_variables], expected_names
    )
    self.assertEqual(layer.variables, layer.trainable_variables)
    self.assertEmpty(layer.non_trainable_variables)

    # trainable=False: the same two variables are reported as non-trainable.
    layer = hub_loader.HubKerasLayerV1V2(module_path, trainable=False)
    self.assertEmpty(layer.trainable_variables)
    self.assertEqual(
        [var.name for var in layer.non_trainable_variables], expected_names
    )
    self.assertEqual(layer.variables, layer.non_trainable_variables)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright 2022 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe/model_maker/python/core/utils:__subpackages__"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "testdata",
|
||||
srcs = ["test.txt"],
|
||||
)
|
BIN
mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini/saved_model.pb
vendored
Normal file
BIN
mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini/saved_model.pb
vendored
Normal file
Binary file not shown.
1
mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini/tfhub_module.pb
vendored
Normal file
1
mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini/tfhub_module.pb
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
|
BIN
mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/saved_model.pb
vendored
Normal file
BIN
mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/saved_model.pb
vendored
Normal file
Binary file not shown.
1
mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/tfhub_module.pb
vendored
Normal file
1
mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/tfhub_module.pb
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
|
|
@ -0,0 +1,2 @@
|
|||
<EFBFBD>釮粮ツソ
|
||||
|
BIN
mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/variables/variables.index
vendored
Normal file
BIN
mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/variables/variables.index
vendored
Normal file
Binary file not shown.
BIN
mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/saved_model.pb
vendored
Normal file
BIN
mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/saved_model.pb
vendored
Normal file
Binary file not shown.
Binary file not shown.
BIN
mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/variables/variables.index
vendored
Normal file
BIN
mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/variables/variables.index
vendored
Normal file
Binary file not shown.
|
@ -14,7 +14,7 @@
|
|||
"""Specification for a BERT model."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Dict
|
||||
from typing import Dict, Union
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.core.utils import file_util
|
||||
|
@ -35,7 +35,9 @@ class BertModelSpec:
|
|||
Transformers for Language Understanding) for more details.
|
||||
|
||||
Attributes:
|
||||
downloaded_files: A DownloadedFiles object of the model files
|
||||
files: Either a TFHub url string which can be passed directly to
|
||||
hub.KerasLayer or a DownloadedFiles object of the model files.
|
||||
is_tf2: If True, the checkpoint is TF2 format. Else use TF1 format.
|
||||
hparams: Hyperparameters used for training.
|
||||
model_options: Configurable options for a BERT model.
|
||||
do_lower_case: boolean, whether to lower case the input text. Should be
|
||||
|
@ -45,7 +47,8 @@ class BertModelSpec:
|
|||
name: The name of the object.
|
||||
"""
|
||||
|
||||
downloaded_files: file_util.DownloadedFiles
|
||||
files: Union[str, file_util.DownloadedFiles]
|
||||
is_tf2: bool = True
|
||||
hparams: hp.BaseHParams = dataclasses.field(
|
||||
default_factory=lambda: hp.BaseHParams(
|
||||
epochs=3,
|
||||
|
@ -61,3 +64,11 @@ class BertModelSpec:
|
|||
tflite_input_name: Dict[str, str] = dataclasses.field(
|
||||
default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME)
|
||||
name: str = 'Bert'
|
||||
|
||||
def get_path(self) -> str:
  """Returns the location of the model files as a string.

  Returns:
    The result of `self.files.get_path()` when `files` is a
    `file_util.DownloadedFiles` object, or `files` itself when it is a string.

  Raises:
    ValueError: If `files` is neither a `file_util.DownloadedFiles` object nor
      a string.
  """
  model_files = self.files
  if isinstance(model_files, file_util.DownloadedFiles):
    return model_files.get_path()
  if isinstance(model_files, str):
    return model_files
  raise ValueError(f'files has unsupported type: {type(self.files)}')
|
||||
|
|
|
@ -131,6 +131,7 @@ py_library(
|
|||
":text_classifier_options",
|
||||
"//mediapipe/model_maker/python/core/data:dataset",
|
||||
"//mediapipe/model_maker/python/core/tasks:classifier",
|
||||
"//mediapipe/model_maker/python/core/utils:hub_loader",
|
||||
"//mediapipe/model_maker/python/core/utils:loss_functions",
|
||||
"//mediapipe/model_maker/python/core/utils:metrics",
|
||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||
|
|
|
@ -23,16 +23,8 @@ from mediapipe.model_maker.python.text.text_classifier import hyperparameters as
|
|||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||
|
||||
|
||||
MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
|
||||
'text_classifier/mobilebert_tiny',
|
||||
'https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny.tar.gz',
|
||||
is_folder=True,
|
||||
)
|
||||
|
||||
EXBERT_FILES = file_util.DownloadedFiles(
|
||||
'text_classifier/exbert',
|
||||
'https://storage.googleapis.com/mediapipe-assets/exbert.tar.gz',
|
||||
is_folder=True,
|
||||
MOBILEBERT_FILES = (
|
||||
'https://tfhub.dev/google/mobilebert/uncased_L-24_H-128_B-512_A-4_F-4_OPT/1'
|
||||
)
|
||||
|
||||
|
||||
|
@ -71,23 +63,14 @@ class BertClassifierSpec(bert_model_spec.BertModelSpec):
|
|||
|
||||
hparams: hp.BertHParams = dataclasses.field(default_factory=hp.BertHParams)
|
||||
|
||||
|
||||
mobilebert_classifier_spec = functools.partial(
|
||||
BertClassifierSpec,
|
||||
downloaded_files=MOBILEBERT_TINY_FILES,
|
||||
files=MOBILEBERT_FILES,
|
||||
hparams=hp.BertHParams(
|
||||
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
|
||||
),
|
||||
name='MobileBert',
|
||||
)
|
||||
|
||||
exbert_classifier_spec = functools.partial(
|
||||
BertClassifierSpec,
|
||||
downloaded_files=EXBERT_FILES,
|
||||
hparams=hp.BertHParams(
|
||||
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
|
||||
),
|
||||
name='ExBert',
|
||||
name='MobileBERT',
|
||||
is_tf2=False,
|
||||
)
|
||||
|
||||
|
||||
|
@ -96,4 +79,3 @@ class SupportedModels(enum.Enum):
|
|||
"""Predefined text classifier model specs supported by Model Maker."""
|
||||
AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec
|
||||
MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec
|
||||
EXBERT_CLASSIFIER = exbert_classifier_spec
|
||||
|
|
|
@ -42,8 +42,8 @@ class ModelSpecTest(tf.test.TestCase):
|
|||
def test_predefined_bert_spec(self):
|
||||
model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||
self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec)
|
||||
self.assertEqual(model_spec_obj.name, 'MobileBert')
|
||||
self.assertTrue(os.path.exists(model_spec_obj.downloaded_files.get_path()))
|
||||
self.assertEqual(model_spec_obj.name, 'MobileBERT')
|
||||
self.assertTrue(model_spec_obj.files)
|
||||
self.assertTrue(model_spec_obj.do_lower_case)
|
||||
self.assertEqual(
|
||||
model_spec_obj.tflite_input_name,
|
||||
|
|
|
@ -87,11 +87,11 @@ class PreprocessorTest(tf.test.TestCase):
|
|||
csv_file = self._get_csv_file()
|
||||
dataset = text_classifier_ds.Dataset.from_csv(
|
||||
filename=csv_file, csv_params=self.CSV_PARAMS_)
|
||||
bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
|
||||
bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
|
||||
seq_len=5,
|
||||
do_lower_case=bert_spec.do_lower_case,
|
||||
uri=bert_spec.downloaded_files.get_path(),
|
||||
uri=bert_spec.get_path(),
|
||||
model_name=bert_spec.name,
|
||||
)
|
||||
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
|
||||
|
@ -121,11 +121,11 @@ class PreprocessorTest(tf.test.TestCase):
|
|||
csv_params=self.CSV_PARAMS_,
|
||||
cache_dir=self.get_temp_dir(),
|
||||
)
|
||||
bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
|
||||
bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
|
||||
seq_len=5,
|
||||
do_lower_case=bert_spec.do_lower_case,
|
||||
uri=bert_spec.downloaded_files.get_path(),
|
||||
uri=bert_spec.get_path(),
|
||||
model_name=bert_spec.name,
|
||||
)
|
||||
ds_cache_files = dataset.tfrecord_cache_files
|
||||
|
@ -153,7 +153,7 @@ class PreprocessorTest(tf.test.TestCase):
|
|||
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
|
||||
seq_len=seq_len,
|
||||
do_lower_case=do_lower_case,
|
||||
uri=bert_spec.downloaded_files.get_path(),
|
||||
uri=bert_spec.get_path(),
|
||||
model_name=bert_spec.name,
|
||||
)
|
||||
new_cf = bert_preprocessor._get_tfrecord_cache_files(cf)
|
||||
|
@ -167,10 +167,6 @@ class PreprocessorTest(tf.test.TestCase):
|
|||
cache_dir=self.get_temp_dir(),
|
||||
num_shards=1,
|
||||
)
|
||||
exbert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
|
||||
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, True))
|
||||
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 10, True))
|
||||
all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, False))
|
||||
mobilebert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, True))
|
||||
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 10, True))
|
||||
|
@ -180,10 +176,10 @@ class PreprocessorTest(tf.test.TestCase):
|
|||
cache_dir=self.get_temp_dir(),
|
||||
num_shards=1,
|
||||
)
|
||||
all_cf_prefixes.add(self._get_new_prefix(new_cf, exbert_spec, 5, True))
|
||||
all_cf_prefixes.add(self._get_new_prefix(new_cf, mobilebert_spec, 5, True))
|
||||
|
||||
# Each item of all_cf_prefixes should be unique, so 7 total.
|
||||
self.assertLen(all_cf_prefixes, 7)
|
||||
# Each item of all_cf_prefixes should be unique.
|
||||
self.assertLen(all_cf_prefixes, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -24,6 +24,7 @@ import tensorflow_hub as hub
|
|||
|
||||
from mediapipe.model_maker.python.core.data import dataset as ds
|
||||
from mediapipe.model_maker.python.core.tasks import classifier
|
||||
from mediapipe.model_maker.python.core.utils import hub_loader
|
||||
from mediapipe.model_maker.python.core.utils import loss_functions
|
||||
from mediapipe.model_maker.python.core.utils import metrics
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
|
@ -52,18 +53,21 @@ def _validate(options: text_classifier_options.TextClassifierOptions):
|
|||
if options.model_options is None:
|
||||
return
|
||||
|
||||
if (isinstance(options.model_options, mo.AverageWordEmbeddingModelOptions) and
|
||||
(options.supported_model !=
|
||||
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
|
||||
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
|
||||
f" got {options.supported_model}")
|
||||
if isinstance(options.model_options, mo.BertModelOptions) and (
|
||||
options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER
|
||||
and options.supported_model != ms.SupportedModels.EXBERT_CLASSIFIER
|
||||
if isinstance(
|
||||
options.model_options, mo.AverageWordEmbeddingModelOptions
|
||||
) and (
|
||||
options.supported_model
|
||||
!= ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
|
||||
):
|
||||
raise ValueError(
|
||||
"Expected a Bert Classifier(MobileBERT or EXBERT), got "
|
||||
f"{options.supported_model}"
|
||||
"Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
|
||||
f" got {options.supported_model}"
|
||||
)
|
||||
if isinstance(options.model_options, mo.BertModelOptions) and (
|
||||
not isinstance(options.supported_model.value(), ms.BertClassifierSpec)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Expected a Bert Classifier, got {options.supported_model}"
|
||||
)
|
||||
|
||||
|
||||
|
@ -113,15 +117,13 @@ class TextClassifier(classifier.Classifier):
|
|||
if options.hparams is None:
|
||||
options.hparams = options.supported_model.value().hparams
|
||||
|
||||
if (
|
||||
options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER
|
||||
or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER
|
||||
):
|
||||
if isinstance(options.supported_model.value(), ms.BertClassifierSpec):
|
||||
text_classifier = _BertClassifier.create_bert_classifier(
|
||||
train_data, validation_data, options
|
||||
)
|
||||
elif (options.supported_model ==
|
||||
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
|
||||
elif isinstance(
|
||||
options.supported_model.value(), ms.AverageWordEmbeddingClassifierSpec
|
||||
):
|
||||
text_classifier = _AverageWordEmbeddingClassifier.create_average_word_embedding_classifier(
|
||||
train_data, validation_data, options
|
||||
)
|
||||
|
@ -348,12 +350,12 @@ class _BertClassifier(TextClassifier):
|
|||
self._hparams = hparams
|
||||
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
|
||||
self._model_options = model_options
|
||||
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
|
||||
with self._hparams.get_strategy().scope():
|
||||
self._loss_function = loss_functions.SparseFocalLoss(
|
||||
self._hparams.gamma, self._num_classes
|
||||
)
|
||||
self._metric_functions = self._create_metrics()
|
||||
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
|
||||
|
||||
@classmethod
|
||||
def create_bert_classifier(
|
||||
|
@ -410,7 +412,7 @@ class _BertClassifier(TextClassifier):
|
|||
self._text_preprocessor = preprocessor.BertClassifierPreprocessor(
|
||||
seq_len=self._model_options.seq_len,
|
||||
do_lower_case=self._model_spec.do_lower_case,
|
||||
uri=self._model_spec.downloaded_files.get_path(),
|
||||
uri=self._model_spec.get_path(),
|
||||
model_name=self._model_spec.name,
|
||||
)
|
||||
return (
|
||||
|
@ -488,12 +490,26 @@ class _BertClassifier(TextClassifier):
|
|||
name="input_type_ids",
|
||||
),
|
||||
)
|
||||
if self._model_spec.is_tf2:
|
||||
encoder = hub.KerasLayer(
|
||||
self._model_spec.downloaded_files.get_path(),
|
||||
self._model_spec.get_path(),
|
||||
trainable=self._model_options.do_fine_tuning,
|
||||
)
|
||||
encoder_outputs = encoder(encoder_inputs)
|
||||
pooled_output = encoder_outputs["pooled_output"]
|
||||
else:
|
||||
renamed_inputs = dict(
|
||||
input_ids=encoder_inputs["input_word_ids"],
|
||||
input_mask=encoder_inputs["input_mask"],
|
||||
segment_ids=encoder_inputs["input_type_ids"],
|
||||
)
|
||||
encoder = hub_loader.HubKerasLayerV1V2(
|
||||
self._model_spec.get_path(),
|
||||
signature="tokens",
|
||||
output_key="pooled_output",
|
||||
trainable=self._model_options.do_fine_tuning,
|
||||
)
|
||||
pooled_output = encoder(renamed_inputs)
|
||||
|
||||
output = tf.keras.layers.Dropout(rate=self._model_options.dropout_rate)(
|
||||
pooled_output)
|
||||
|
|
|
@ -104,13 +104,9 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
@parameterized.named_parameters(
|
||||
# Skipping mobilebert b/c OSS test timeout/flakiness: b/275624089
|
||||
# dict(
|
||||
# testcase_name='mobilebert',
|
||||
# supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
|
||||
# ),
|
||||
dict(
|
||||
testcase_name='exbert',
|
||||
supported_model=text_classifier.SupportedModels.EXBERT_CLASSIFIER,
|
||||
testcase_name='mobilebert',
|
||||
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
|
||||
),
|
||||
)
|
||||
def test_create_and_train_bert(self, supported_model):
|
||||
|
@ -156,7 +152,7 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
def test_label_mismatch(self):
|
||||
options = text_classifier.TextClassifierOptions(
|
||||
supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER)
|
||||
supported_model=(text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER)
|
||||
)
|
||||
train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
|
||||
train_data = text_classifier.Dataset(train_tf_dataset, ['foo'], 1)
|
||||
|
@ -174,13 +170,13 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
|||
train_data, validation_data = self._get_data()
|
||||
|
||||
avg_options = text_classifier.TextClassifierOptions(
|
||||
supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER),
|
||||
supported_model=(text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER),
|
||||
model_options=text_classifier.AverageWordEmbeddingModelOptions(),
|
||||
)
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
|
||||
' SupportedModels.EXBERT_CLASSIFIER',
|
||||
' SupportedModels.MOBILEBERT_CLASSIFIER',
|
||||
):
|
||||
text_classifier.TextClassifier.create(
|
||||
train_data, validation_data, avg_options
|
||||
|
@ -194,7 +190,7 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
|||
)
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
'Expected a Bert Classifier(MobileBERT or EXBERT), got'
|
||||
'Expected a Bert Classifier, got'
|
||||
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER',
|
||||
):
|
||||
text_classifier.TextClassifier.create(
|
||||
|
@ -203,7 +199,7 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
def test_bert_loss_and_metrics_creation(self):
|
||||
train_data, validation_data = self._get_data()
|
||||
supported_model = text_classifier.SupportedModels.EXBERT_CLASSIFIER
|
||||
supported_model = text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER
|
||||
hparams = text_classifier.BertHParams(
|
||||
desired_recalls=[0.2],
|
||||
desired_precisions=[0.9],
|
||||
|
|
Loading…
Reference in New Issue
Block a user