diff --git a/mediapipe/model_maker/python/core/utils/BUILD b/mediapipe/model_maker/python/core/utils/BUILD index 2c29970bb..c5e031245 100644 --- a/mediapipe/model_maker/python/core/utils/BUILD +++ b/mediapipe/model_maker/python/core/utils/BUILD @@ -19,6 +19,13 @@ licenses(["notice"]) package(default_visibility = ["//mediapipe:__subpackages__"]) +filegroup( + name = "testdata", + srcs = glob([ + "testdata/**", + ]), +) + py_library( name = "test_util", testonly = 1, @@ -56,11 +63,26 @@ py_library( py_test( name = "file_util_test", srcs = ["file_util_test.py"], - data = ["//mediapipe/model_maker/python/core/utils/testdata"], + data = [":testdata"], tags = ["requires-net:external"], deps = [":file_util"], ) +py_library( + name = "hub_loader", + srcs = ["hub_loader.py"], +) + +py_test( + name = "hub_loader_test", + srcs = ["hub_loader_test.py"], + data = [":testdata"], + deps = [ + ":hub_loader", + "//mediapipe/tasks/python/test:test_utils", + ], +) + py_library( name = "loss_functions", srcs = ["loss_functions.py"], diff --git a/mediapipe/model_maker/python/core/utils/hub_loader.py b/mediapipe/model_maker/python/core/utils/hub_loader.py new file mode 100644 index 000000000..a52099884 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/hub_loader.py @@ -0,0 +1,97 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Handles both V1 and V2 modules.""" + +import tensorflow_hub as hub + + +class HubKerasLayerV1V2(hub.KerasLayer): + """Class to loads TF v1 and TF v2 hub modules that could be fine-tuned. + + Since TF v1 modules couldn't be retrained in hub.KerasLayer. This class + provides a workaround for retraining the whole tf1 model in tf2. In + particular, it extract self._func._self_unconditional_checkpoint_dependencies + into trainable variable in tf1. + + Doesn't update moving-mean/moving-variance for BatchNormalization during + fine-tuning. + """ + + def _setup_layer(self, trainable=False, **kwargs): + if self._is_hub_module_v1: + self._setup_layer_v1(trainable, **kwargs) + else: + # call _setup_layer from the base class for v2. + super(HubKerasLayerV1V2, self)._setup_layer(trainable, **kwargs) + + def _check_trainability(self): + if self._is_hub_module_v1: + self._check_trainability_v1() + else: + # call _check_trainability from the base class for v2. + super(HubKerasLayerV1V2, self)._check_trainability() + + def _setup_layer_v1(self, trainable=False, **kwargs): + """Constructs keras layer with relevant weights and losses.""" + # Initialize an empty layer, then add_weight() etc. as needed. + super(hub.KerasLayer, self).__init__(trainable=trainable, **kwargs) + + if not self._is_hub_module_v1: + raise ValueError( + 'Only supports to set up v1 hub module in this function.' + ) + + # v2 trainable_variable: + if hasattr(self._func, 'trainable_variables'): + for v in self._func.trainable_variables: + self._add_existing_weight(v, trainable=True) + trainable_variables = {id(v) for v in self._func.trainable_variables} + else: + trainable_variables = set() + + if not hasattr(self._func, '_self_unconditional_checkpoint_dependencies'): + raise ValueError( + "_func doesn't contains attribute " + '_self_unconditional_checkpoint_dependencies.' + ) + dependencies = self._func._self_unconditional_checkpoint_dependencies # pylint: disable=protected-access + + # Adds trainable variables. + for dep in dependencies: + if dep.name == 'variables': + for v in dep.ref: + if id(v) not in trainable_variables: + self._add_existing_weight(v, trainable=True) + trainable_variables.add(id(v)) + + # Adds non-trainable variables. + if hasattr(self._func, 'variables'): + for v in self._func.variables: + if id(v) not in trainable_variables: + self._add_existing_weight(v, trainable=False) + + # Forward the callable's regularization losses (if any). + if hasattr(self._func, 'regularization_losses'): + for l in self._func.regularization_losses: + if not callable(l): + raise ValueError( + 'hub.KerasLayer(obj) expects obj.regularization_losses to be an ' + 'iterable of callables, each returning a scalar loss term.' + ) + self.add_loss(self._call_loss_if_trainable(l)) # Supports callables. + + def _check_trainability_v1(self): + """Ignores trainability checks for V1.""" + if self._is_hub_module_v1: + return # Nothing to do. diff --git a/mediapipe/model_maker/python/core/utils/hub_loader_test.py b/mediapipe/model_maker/python/core/utils/hub_loader_test.py new file mode 100644 index 000000000..8ea15b5d1 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/hub_loader_test.py @@ -0,0 +1,59 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import parameterized +import tensorflow as tf + +from mediapipe.model_maker.python.core.utils import hub_loader +from mediapipe.tasks.python.test import test_utils + + +class HubKerasLayerV1V2Test(tf.test.TestCase, parameterized.TestCase): + + @parameterized.parameters( + ("hub_module_v1_mini", True), + ("saved_model_v2_mini", True), + ("hub_module_v1_mini", False), + ("saved_model_v2_mini", False), + ) + def test_load_with_defaults(self, module_name, trainable): + inputs, expected_outputs = 10.0, 11.0 # Test modules perform increment op. + path = test_utils.get_test_data_path(module_name) + layer = hub_loader.HubKerasLayerV1V2(path, trainable=trainable) + output = layer(inputs) + self.assertEqual(output, expected_outputs) + + def test_trainable_variable(self): + path = test_utils.get_test_data_path("hub_module_v1_mini_train") + layer = hub_loader.HubKerasLayerV1V2(path, trainable=True) + # Checks trainable variables. + self.assertLen(layer.trainable_variables, 2) + self.assertEqual(layer.trainable_variables[0].name, "a:0") + self.assertEqual(layer.trainable_variables[1].name, "b:0") + self.assertEqual(layer.variables, layer.trainable_variables) + # Checks non-trainable variables. + self.assertEmpty(layer.non_trainable_variables) + + layer = hub_loader.HubKerasLayerV1V2(path, trainable=False) + # Checks trainable variables. + self.assertEmpty(layer.trainable_variables) + # Checks non-trainable variables. + self.assertLen(layer.non_trainable_variables, 2) + self.assertEqual(layer.non_trainable_variables[0].name, "a:0") + self.assertEqual(layer.non_trainable_variables[1].name, "b:0") + self.assertEqual(layer.variables, layer.non_trainable_variables) + + +if __name__ == "__main__": + tf.test.main() diff --git a/mediapipe/model_maker/python/core/utils/testdata/BUILD b/mediapipe/model_maker/python/core/utils/testdata/BUILD deleted file mode 100644 index ea45f6140..000000000 --- a/mediapipe/model_maker/python/core/utils/testdata/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - default_visibility = ["//mediapipe/model_maker/python/core/utils:__subpackages__"], - licenses = ["notice"], # Apache 2.0 -) - -filegroup( - name = "testdata", - srcs = ["test.txt"], -) diff --git a/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini/saved_model.pb b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini/saved_model.pb new file mode 100644 index 000000000..e60e04a24 Binary files /dev/null and b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini/saved_model.pb differ diff --git a/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini/tfhub_module.pb b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini/tfhub_module.pb new file mode 100644 index 000000000..d65dd8f1d --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini/tfhub_module.pb @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/saved_model.pb b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/saved_model.pb new file mode 100644 index 000000000..69519fef7 Binary files /dev/null and b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/saved_model.pb differ diff --git a/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/tfhub_module.pb b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/tfhub_module.pb new file mode 100644 index 000000000..d65dd8f1d --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/tfhub_module.pb @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/variables/variables.data-00000-of-00001 b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/variables/variables.data-00000-of-00001 new file mode 100644 index 000000000..3474955ee --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/variables/variables.data-00000-of-00001 @@ -0,0 +1,2 @@ +øÌû¾âì¿ + diff --git a/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/variables/variables.index b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/variables/variables.index new file mode 100644 index 000000000..d0e35ab87 Binary files /dev/null and b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/variables/variables.index differ diff --git a/mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/saved_model.pb b/mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/saved_model.pb new file mode 100644 index 000000000..314ea74fc Binary files /dev/null and b/mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/saved_model.pb differ diff --git a/mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/variables/variables.data-00000-of-00001 b/mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/variables/variables.data-00000-of-00001 new file mode 100644 index 000000000..09dbb330d Binary files /dev/null and b/mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/variables/variables.data-00000-of-00001 differ diff --git a/mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/variables/variables.index b/mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/variables/variables.index new file mode 100644 index 000000000..7cfb9ffd4 Binary files /dev/null and b/mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/variables/variables.index differ diff --git a/mediapipe/model_maker/python/text/core/bert_model_spec.py b/mediapipe/model_maker/python/text/core/bert_model_spec.py index 80e92a06a..4a847ac33 100644 --- a/mediapipe/model_maker/python/text/core/bert_model_spec.py +++ b/mediapipe/model_maker/python/text/core/bert_model_spec.py @@ -14,7 +14,7 @@ """Specification for a BERT model.""" import dataclasses -from typing import Dict +from typing import Dict, Union from mediapipe.model_maker.python.core import hyperparameters as hp from mediapipe.model_maker.python.core.utils import file_util @@ -35,7 +35,9 @@ class BertModelSpec: Transformers for Language Understanding) for more details. Attributes: - downloaded_files: A DownloadedFiles object of the model files + files: Either a TFHub url string which can be passed directly to + hub.KerasLayer or a DownloadedFiles object of the model files. + is_tf2: If True, the checkpoint is TF2 format. Else use TF1 format. hparams: Hyperparameters used for training. model_options: Configurable options for a BERT model. do_lower_case: boolean, whether to lower case the input text. Should be @@ -45,7 +47,8 @@ class BertModelSpec: name: The name of the object. """ - downloaded_files: file_util.DownloadedFiles + files: Union[str, file_util.DownloadedFiles] + is_tf2: bool = True hparams: hp.BaseHParams = dataclasses.field( default_factory=lambda: hp.BaseHParams( epochs=3, @@ -61,3 +64,11 @@ class BertModelSpec: tflite_input_name: Dict[str, str] = dataclasses.field( default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME) name: str = 'Bert' + + def get_path(self) -> str: + if isinstance(self.files, file_util.DownloadedFiles): + return self.files.get_path() + elif isinstance(self.files, str): + return self.files + else: + raise ValueError(f'files has unsupported type: {type(self.files)}') diff --git a/mediapipe/model_maker/python/text/text_classifier/BUILD b/mediapipe/model_maker/python/text/text_classifier/BUILD index 322b1e1e5..e32733e31 100644 --- a/mediapipe/model_maker/python/text/text_classifier/BUILD +++ b/mediapipe/model_maker/python/text/text_classifier/BUILD @@ -131,6 +131,7 @@ py_library( ":text_classifier_options", "//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/tasks:classifier", + "//mediapipe/model_maker/python/core/utils:hub_loader", "//mediapipe/model_maker/python/core/utils:loss_functions", "//mediapipe/model_maker/python/core/utils:metrics", "//mediapipe/model_maker/python/core/utils:model_util", diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec.py b/mediapipe/model_maker/python/text/text_classifier/model_spec.py index 724aaf377..01d1432cb 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_spec.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec.py @@ -23,16 +23,8 @@ from mediapipe.model_maker.python.text.text_classifier import hyperparameters as from mediapipe.model_maker.python.text.text_classifier import model_options as mo -MOBILEBERT_TINY_FILES = file_util.DownloadedFiles( - 'text_classifier/mobilebert_tiny', - 'https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny.tar.gz', - is_folder=True, -) - -EXBERT_FILES = file_util.DownloadedFiles( - 'text_classifier/exbert', - 'https://storage.googleapis.com/mediapipe-assets/exbert.tar.gz', - is_folder=True, +MOBILEBERT_FILES = ( + 'https://tfhub.dev/google/mobilebert/uncased_L-24_H-128_B-512_A-4_F-4_OPT/1' ) @@ -71,23 +63,14 @@ class BertClassifierSpec(bert_model_spec.BertModelSpec): hparams: hp.BertHParams = dataclasses.field(default_factory=hp.BertHParams) - mobilebert_classifier_spec = functools.partial( BertClassifierSpec, - downloaded_files=MOBILEBERT_TINY_FILES, + files=MOBILEBERT_FILES, hparams=hp.BertHParams( epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off' ), - name='MobileBert', -) - -exbert_classifier_spec = functools.partial( - BertClassifierSpec, - downloaded_files=EXBERT_FILES, - hparams=hp.BertHParams( - epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off' - ), - name='ExBert', + name='MobileBERT', + is_tf2=False, ) @@ -96,4 +79,3 @@ class SupportedModels(enum.Enum): """Predefined text classifier model specs supported by Model Maker.""" AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec - EXBERT_CLASSIFIER = exbert_classifier_spec diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py index 4d42851d5..d1e578b81 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py @@ -42,8 +42,8 @@ class ModelSpecTest(tf.test.TestCase): def test_predefined_bert_spec(self): model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value() self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec) - self.assertEqual(model_spec_obj.name, 'MobileBert') - self.assertTrue(os.path.exists(model_spec_obj.downloaded_files.get_path())) + self.assertEqual(model_spec_obj.name, 'MobileBERT') + self.assertTrue(model_spec_obj.files) self.assertTrue(model_spec_obj.do_lower_case) self.assertEqual( model_spec_obj.tflite_input_name, diff --git a/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py b/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py index 28c12f96c..ff9015498 100644 --- a/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py @@ -87,11 +87,11 @@ class PreprocessorTest(tf.test.TestCase): csv_file = self._get_csv_file() dataset = text_classifier_ds.Dataset.from_csv( filename=csv_file, csv_params=self.CSV_PARAMS_) - bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value() + bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value() bert_preprocessor = preprocessor.BertClassifierPreprocessor( seq_len=5, do_lower_case=bert_spec.do_lower_case, - uri=bert_spec.downloaded_files.get_path(), + uri=bert_spec.get_path(), model_name=bert_spec.name, ) preprocessed_dataset = bert_preprocessor.preprocess(dataset) @@ -121,11 +121,11 @@ class PreprocessorTest(tf.test.TestCase): csv_params=self.CSV_PARAMS_, cache_dir=self.get_temp_dir(), ) - bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value() + bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value() bert_preprocessor = preprocessor.BertClassifierPreprocessor( seq_len=5, do_lower_case=bert_spec.do_lower_case, - uri=bert_spec.downloaded_files.get_path(), + uri=bert_spec.get_path(), model_name=bert_spec.name, ) ds_cache_files = dataset.tfrecord_cache_files @@ -153,7 +153,7 @@ class PreprocessorTest(tf.test.TestCase): bert_preprocessor = preprocessor.BertClassifierPreprocessor( seq_len=seq_len, do_lower_case=do_lower_case, - uri=bert_spec.downloaded_files.get_path(), + uri=bert_spec.get_path(), model_name=bert_spec.name, ) new_cf = bert_preprocessor._get_tfrecord_cache_files(cf) @@ -167,10 +167,6 @@ class PreprocessorTest(tf.test.TestCase): cache_dir=self.get_temp_dir(), num_shards=1, ) - exbert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value() - all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, True)) - all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 10, True)) - all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, False)) mobilebert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value() all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, True)) all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 10, True)) @@ -180,10 +176,10 @@ class PreprocessorTest(tf.test.TestCase): cache_dir=self.get_temp_dir(), num_shards=1, ) - all_cf_prefixes.add(self._get_new_prefix(new_cf, exbert_spec, 5, True)) + all_cf_prefixes.add(self._get_new_prefix(new_cf, mobilebert_spec, 5, True)) - # Each item of all_cf_prefixes should be unique, so 7 total. - self.assertLen(all_cf_prefixes, 7) + # Each item of all_cf_prefixes should be unique. + self.assertLen(all_cf_prefixes, 4) if __name__ == '__main__': diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py index 10d88110d..76043aa72 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py @@ -24,6 +24,7 @@ import tensorflow_hub as hub from mediapipe.model_maker.python.core.data import dataset as ds from mediapipe.model_maker.python.core.tasks import classifier +from mediapipe.model_maker.python.core.utils import hub_loader from mediapipe.model_maker.python.core.utils import loss_functions from mediapipe.model_maker.python.core.utils import metrics from mediapipe.model_maker.python.core.utils import model_util @@ -52,18 +53,21 @@ def _validate(options: text_classifier_options.TextClassifierOptions): if options.model_options is None: return - if (isinstance(options.model_options, mo.AverageWordEmbeddingModelOptions) and - (options.supported_model != - ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)): - raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER," - f" got {options.supported_model}") - if isinstance(options.model_options, mo.BertModelOptions) and ( - options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER - and options.supported_model != ms.SupportedModels.EXBERT_CLASSIFIER + if isinstance( + options.model_options, mo.AverageWordEmbeddingModelOptions + ) and ( + options.supported_model + != ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER ): raise ValueError( - "Expected a Bert Classifier(MobileBERT or EXBERT), got " - f"{options.supported_model}" + "Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER," + f" got {options.supported_model}" + ) + if isinstance(options.model_options, mo.BertModelOptions) and ( + not isinstance(options.supported_model.value(), ms.BertClassifierSpec) + ): + raise ValueError( + f"Expected a Bert Classifier, got {options.supported_model}" ) @@ -113,15 +117,13 @@ class TextClassifier(classifier.Classifier): if options.hparams is None: options.hparams = options.supported_model.value().hparams - if ( - options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER - or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER - ): + if isinstance(options.supported_model.value(), ms.BertClassifierSpec): text_classifier = _BertClassifier.create_bert_classifier( train_data, validation_data, options ) - elif (options.supported_model == - ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER): + elif isinstance( + options.supported_model.value(), ms.AverageWordEmbeddingClassifierSpec + ): text_classifier = _AverageWordEmbeddingClassifier.create_average_word_embedding_classifier( train_data, validation_data, options ) @@ -348,12 +350,12 @@ class _BertClassifier(TextClassifier): self._hparams = hparams self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir) self._model_options = model_options + self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None with self._hparams.get_strategy().scope(): self._loss_function = loss_functions.SparseFocalLoss( self._hparams.gamma, self._num_classes ) self._metric_functions = self._create_metrics() - self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None @classmethod def create_bert_classifier( @@ -410,7 +412,7 @@ class _BertClassifier(TextClassifier): self._text_preprocessor = preprocessor.BertClassifierPreprocessor( seq_len=self._model_options.seq_len, do_lower_case=self._model_spec.do_lower_case, - uri=self._model_spec.downloaded_files.get_path(), + uri=self._model_spec.get_path(), model_name=self._model_spec.name, ) return ( @@ -488,12 +490,26 @@ class _BertClassifier(TextClassifier): name="input_type_ids", ), ) - encoder = hub.KerasLayer( - self._model_spec.downloaded_files.get_path(), - trainable=self._model_options.do_fine_tuning, - ) - encoder_outputs = encoder(encoder_inputs) - pooled_output = encoder_outputs["pooled_output"] + if self._model_spec.is_tf2: + encoder = hub.KerasLayer( + self._model_spec.get_path(), + trainable=self._model_options.do_fine_tuning, + ) + encoder_outputs = encoder(encoder_inputs) + pooled_output = encoder_outputs["pooled_output"] + else: + renamed_inputs = dict( + input_ids=encoder_inputs["input_word_ids"], + input_mask=encoder_inputs["input_mask"], + segment_ids=encoder_inputs["input_type_ids"], + ) + encoder = hub_loader.HubKerasLayerV1V2( + self._model_spec.get_path(), + signature="tokens", + output_key="pooled_output", + trainable=self._model_options.do_fine_tuning, + ) + pooled_output = encoder(renamed_inputs) output = tf.keras.layers.Dropout(rate=self._model_options.dropout_rate)( pooled_output) diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py index be4646f68..122182ddd 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py @@ -104,13 +104,9 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase): @parameterized.named_parameters( # Skipping mobilebert b/c OSS test timeout/flakiness: b/275624089 - # dict( - # testcase_name='mobilebert', - # supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER, - # ), dict( - testcase_name='exbert', - supported_model=text_classifier.SupportedModels.EXBERT_CLASSIFIER, + testcase_name='mobilebert', + supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER, ), ) def test_create_and_train_bert(self, supported_model): @@ -156,7 +152,7 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase): def test_label_mismatch(self): options = text_classifier.TextClassifierOptions( - supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER) + supported_model=(text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER) ) train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]]) train_data = text_classifier.Dataset(train_tf_dataset, ['foo'], 1) @@ -174,13 +170,13 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase): train_data, validation_data = self._get_data() avg_options = text_classifier.TextClassifierOptions( - supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER), + supported_model=(text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER), model_options=text_classifier.AverageWordEmbeddingModelOptions(), ) with self.assertRaisesWithLiteralMatch( ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got' - ' SupportedModels.EXBERT_CLASSIFIER', + ' SupportedModels.MOBILEBERT_CLASSIFIER', ): text_classifier.TextClassifier.create( train_data, validation_data, avg_options @@ -194,7 +190,7 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase): ) with self.assertRaisesWithLiteralMatch( ValueError, - 'Expected a Bert Classifier(MobileBERT or EXBERT), got' + 'Expected a Bert Classifier, got' ' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER', ): text_classifier.TextClassifier.create( @@ -203,7 +199,7 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase): def test_bert_loss_and_metrics_creation(self): train_data, validation_data = self._get_data() - supported_model = text_classifier.SupportedModels.EXBERT_CLASSIFIER + supported_model = text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER hparams = text_classifier.BertHParams( desired_recalls=[0.2], desired_precisions=[0.9],