From c1c51c2fe7200914b2da2992e824bbef4ed5a88e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 7 Aug 2023 14:33:00 -0700 Subject: [PATCH] Internal PiperOrigin-RevId: 554595324 --- mediapipe/model_maker/python/core/utils/BUILD | 24 ++++- .../python/core/utils/hub_loader.py | 97 ++++++++++++++++++ .../python/core/utils/hub_loader_test.py | 59 +++++++++++ .../python/core/utils/testdata/BUILD | 23 ----- .../hub_module_v1_mini/saved_model.pb | Bin 0 -> 485 bytes .../hub_module_v1_mini/tfhub_module.pb | 1 + .../hub_module_v1_mini_train/saved_model.pb | Bin 0 -> 4441 bytes .../hub_module_v1_mini_train/tfhub_module.pb | 1 + .../variables/variables.data-00000-of-00001 | 2 + .../variables/variables.index | Bin 0 -> 134 bytes .../saved_model_v2_mini/saved_model.pb | Bin 0 -> 8863 bytes .../variables/variables.data-00000-of-00001 | Bin 0 -> 126 bytes .../variables/variables.index | Bin 0 -> 199 bytes .../python/text/core/bert_model_spec.py | 17 ++- .../python/text/text_classifier/BUILD | 1 + .../python/text/text_classifier/model_spec.py | 28 +---- .../text/text_classifier/model_spec_test.py | 4 +- .../text/text_classifier/preprocessor_test.py | 20 ++-- .../text/text_classifier/text_classifier.py | 64 +++++++----- .../text_classifier/text_classifier_test.py | 18 ++-- 20 files changed, 260 insertions(+), 99 deletions(-) create mode 100644 mediapipe/model_maker/python/core/utils/hub_loader.py create mode 100644 mediapipe/model_maker/python/core/utils/hub_loader_test.py delete mode 100644 mediapipe/model_maker/python/core/utils/testdata/BUILD create mode 100644 mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini/saved_model.pb create mode 100644 mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini/tfhub_module.pb create mode 100644 mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/saved_model.pb create mode 100644 mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/tfhub_module.pb create mode 100644 mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/variables/variables.data-00000-of-00001 create mode 100644 mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/variables/variables.index create mode 100644 mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/saved_model.pb create mode 100644 mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/variables/variables.data-00000-of-00001 create mode 100644 mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/variables/variables.index diff --git a/mediapipe/model_maker/python/core/utils/BUILD b/mediapipe/model_maker/python/core/utils/BUILD index 2c29970bb..c5e031245 100644 --- a/mediapipe/model_maker/python/core/utils/BUILD +++ b/mediapipe/model_maker/python/core/utils/BUILD @@ -19,6 +19,13 @@ licenses(["notice"]) package(default_visibility = ["//mediapipe:__subpackages__"]) +filegroup( + name = "testdata", + srcs = glob([ + "testdata/**", + ]), +) + py_library( name = "test_util", testonly = 1, @@ -56,11 +63,26 @@ py_library( py_test( name = "file_util_test", srcs = ["file_util_test.py"], - data = ["//mediapipe/model_maker/python/core/utils/testdata"], + data = [":testdata"], tags = ["requires-net:external"], deps = [":file_util"], ) +py_library( + name = "hub_loader", + srcs = ["hub_loader.py"], +) + +py_test( + name = "hub_loader_test", + srcs = ["hub_loader_test.py"], + data = [":testdata"], + deps = [ + ":hub_loader", + "//mediapipe/tasks/python/test:test_utils", + ], +) + py_library( 
name = "loss_functions", srcs = ["loss_functions.py"], diff --git a/mediapipe/model_maker/python/core/utils/hub_loader.py b/mediapipe/model_maker/python/core/utils/hub_loader.py new file mode 100644 index 000000000..a52099884 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/hub_loader.py @@ -0,0 +1,97 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Handles both V1 and V2 modules.""" + +import tensorflow_hub as hub + + +class HubKerasLayerV1V2(hub.KerasLayer): + """Class to loads TF v1 and TF v2 hub modules that could be fine-tuned. + + Since TF v1 modules couldn't be retrained in hub.KerasLayer. This class + provides a workaround for retraining the whole tf1 model in tf2. In + particular, it extract self._func._self_unconditional_checkpoint_dependencies + into trainable variable in tf1. + + Doesn't update moving-mean/moving-variance for BatchNormalization during + fine-tuning. + """ + + def _setup_layer(self, trainable=False, **kwargs): + if self._is_hub_module_v1: + self._setup_layer_v1(trainable, **kwargs) + else: + # call _setup_layer from the base class for v2. + super(HubKerasLayerV1V2, self)._setup_layer(trainable, **kwargs) + + def _check_trainability(self): + if self._is_hub_module_v1: + self._check_trainability_v1() + else: + # call _check_trainability from the base class for v2. + super(HubKerasLayerV1V2, self)._check_trainability() + + def _setup_layer_v1(self, trainable=False, **kwargs): + """Constructs keras layer with relevant weights and losses.""" + # Initialize an empty layer, then add_weight() etc. as needed. + super(hub.KerasLayer, self).__init__(trainable=trainable, **kwargs) + + if not self._is_hub_module_v1: + raise ValueError( + 'Only supports to set up v1 hub module in this function.' + ) + + # v2 trainable_variable: + if hasattr(self._func, 'trainable_variables'): + for v in self._func.trainable_variables: + self._add_existing_weight(v, trainable=True) + trainable_variables = {id(v) for v in self._func.trainable_variables} + else: + trainable_variables = set() + + if not hasattr(self._func, '_self_unconditional_checkpoint_dependencies'): + raise ValueError( + "_func doesn't contains attribute " + '_self_unconditional_checkpoint_dependencies.' + ) + dependencies = self._func._self_unconditional_checkpoint_dependencies # pylint: disable=protected-access + + # Adds trainable variables. + for dep in dependencies: + if dep.name == 'variables': + for v in dep.ref: + if id(v) not in trainable_variables: + self._add_existing_weight(v, trainable=True) + trainable_variables.add(id(v)) + + # Adds non-trainable variables. + if hasattr(self._func, 'variables'): + for v in self._func.variables: + if id(v) not in trainable_variables: + self._add_existing_weight(v, trainable=False) + + # Forward the callable's regularization losses (if any). 
+ if hasattr(self._func, 'regularization_losses'): + for l in self._func.regularization_losses: + if not callable(l): + raise ValueError( + 'hub.KerasLayer(obj) expects obj.regularization_losses to be an ' + 'iterable of callables, each returning a scalar loss term.' + ) + self.add_loss(self._call_loss_if_trainable(l)) # Supports callables. + + def _check_trainability_v1(self): + """Ignores trainability checks for V1.""" + if self._is_hub_module_v1: + return # Nothing to do. diff --git a/mediapipe/model_maker/python/core/utils/hub_loader_test.py b/mediapipe/model_maker/python/core/utils/hub_loader_test.py new file mode 100644 index 000000000..8ea15b5d1 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/hub_loader_test.py @@ -0,0 +1,59 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import parameterized +import tensorflow as tf + +from mediapipe.model_maker.python.core.utils import hub_loader +from mediapipe.tasks.python.test import test_utils + + +class HubKerasLayerV1V2Test(tf.test.TestCase, parameterized.TestCase): + + @parameterized.parameters( + ("hub_module_v1_mini", True), + ("saved_model_v2_mini", True), + ("hub_module_v1_mini", False), + ("saved_model_v2_mini", False), + ) + def test_load_with_defaults(self, module_name, trainable): + inputs, expected_outputs = 10.0, 11.0 # Test modules perform increment op. + path = test_utils.get_test_data_path(module_name) + layer = hub_loader.HubKerasLayerV1V2(path, trainable=trainable) + output = layer(inputs) + self.assertEqual(output, expected_outputs) + + def test_trainable_variable(self): + path = test_utils.get_test_data_path("hub_module_v1_mini_train") + layer = hub_loader.HubKerasLayerV1V2(path, trainable=True) + # Checks trainable variables. + self.assertLen(layer.trainable_variables, 2) + self.assertEqual(layer.trainable_variables[0].name, "a:0") + self.assertEqual(layer.trainable_variables[1].name, "b:0") + self.assertEqual(layer.variables, layer.trainable_variables) + # Checks non-trainable variables. + self.assertEmpty(layer.non_trainable_variables) + + layer = hub_loader.HubKerasLayerV1V2(path, trainable=False) + # Checks trainable variables. + self.assertEmpty(layer.trainable_variables) + # Checks non-trainable variables. + self.assertLen(layer.non_trainable_variables, 2) + self.assertEqual(layer.non_trainable_variables[0].name, "a:0") + self.assertEqual(layer.non_trainable_variables[1].name, "b:0") + self.assertEqual(layer.variables, layer.non_trainable_variables) + + +if __name__ == "__main__": + tf.test.main() diff --git a/mediapipe/model_maker/python/core/utils/testdata/BUILD b/mediapipe/model_maker/python/core/utils/testdata/BUILD deleted file mode 100644 index ea45f6140..000000000 --- a/mediapipe/model_maker/python/core/utils/testdata/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - default_visibility = ["//mediapipe/model_maker/python/core/utils:__subpackages__"], - licenses = ["notice"], # Apache 2.0 -) - -filegroup( - name = "testdata", - srcs = ["test.txt"], -) diff --git a/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini/saved_model.pb b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini/saved_model.pb new file mode 100644 index 0000000000000000000000000000000000000000..e60e04a242efd314674ebf22024c4b2b7e64775b GIT binary patch literal 485 zcmZ{g!A`tGl`f2Oqi;q9M^p66#Z<>FY(!!6 z7+gPK-a(y~0Pg@eKsL&=qU0?M@~WIxrC@3`AX>;?lC;tklsYT&nP{NxSRjrtyi1cu zJ;_t0XZ!QJ-^B^CVvb-{c0dOGZvVX7i)vN&oaL|C0O$+w!o;D!bQx~E-5$se zHph0?@ful}+-n@(B1@DyTaXPUPd6D3Z?F#JNZ)am*Soa#9?eHLcJLBC!ye^;O;nhDuUSUn?r=&_tim<(Xgy_xV-Vtz(CZEfC%9fo%mUYN6L?9J{TBV=w+u zwnai5Ir0Y}kl;#B960vC{{Rva{|j&8II*+#*iBDdGKur%z3;p??=v3ymkxib^rZ?W zD5thfRUnr-`CN?8ab!5Yz)^n(@L@ZCe&LvV&>L;{4oIu5c9boRZo(;aaPFf~g#17VoWb5{6uD3QM?7@7GfmSngo8#UEc|*8pWaOKVdDpp}!4X%w*`;W1ml*!e1*n zNA>xfiv`~TiXrbtcYQZ@x#zxc?A1-jDD>QD>@F?(fKv_i#`POFuMbC?iT|Y^eC?kQ z`cEIe0i6$)-oimRtLHdL3p$t!nC(%cn^mVZRE$3K#_5h7&x|5WBJK*hVeDr+S36TR zg^EXqppGXVk0RV58dbeuV&3K${F*fQ)b(lF9>!7TbNdW5f*#v{|Nif*a=&?<#CYq9 z&Nj5njr!pHQX2fDHwpeL0Ykd^bf14BfqS( zaQ!SSsgPb($KUEiPLYxVmFzV#ULzORy{_D!2)rw+JQYV;O?X^UoyGe3^nh@h2b*Z8 z6a&~c6~)Psff%$%IE4~k*@X{h{K@ldPy)CT=rM`Yd!X?_dNk83M>I}k(h3yt0r%mE z55yre`^cR%D<_Z8Q^-QqensM1jfvY3rB*()1kD0^(%|b6w-8=RTzT-dh}$r@Ag(mJ zoVe1dv+l@%pp>>qLzRNK@=y!nO3kK%xN?g+LRghOD~K!iy&$gCtdQ{w|#g_}dVIH%ch=_zXt z-zTH@WbaGn@MPxA65g^s$c{^QQ0h(A#=@KalEBvgNU!lRY}(Fao_I0+0In^69KlZR Wif<102;RC0?Jq6x!igo%{`n6aiv1k` literal 0 HcmV?d00001 diff --git a/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/tfhub_module.pb b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/tfhub_module.pb new file mode 100644 index 000000000..d65dd8f1d --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/tfhub_module.pb @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/variables/variables.data-00000-of-00001 b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/variables/variables.data-00000-of-00001 new file mode 100644 index 000000000..3474955ee --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/variables/variables.data-00000-of-00001 @@ -0,0 +1,2 @@ +øÌû¾âì¿ + diff --git a/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/variables/variables.index b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/variables/variables.index new file mode 100644 index 0000000000000000000000000000000000000000..d0e35ab879470e3b996cc09607b980876f5a1a70 GIT binary patch literal 134 zcmZQzVB=tvV&Y(AVB}8ZU=(7|U@O=@=tL{vPVWw(36+&VxPwaHO zlR8<)X*+P>*MZe&K^)*R7o;K(jF3RAI3eu?B#v<59{_QI8@IjjwH@2fj-AP_SXQV~ zYWqFE_x=6-e810mLSJmcU$XT5EUZFyqh)j~HUT58uGQ{Z<*aHwYis2R$Q~)pu0}^J z&FHku@;JC7G-Hp$2xRWJ9<Qf3dg zC{Kg^m_k$USj+ZOpIIj58E^(8PYXaAG}GuBb(}C#Zy62!fR1!5Q#TG+nx@JO#@wfJ 
z+iY2_(AYeTYbRR0Yqd<--i9B94H(^SDs}BitEp;cKlfo~arbQJiGpjm|5&D$E)n=5 z6n8X5-B(OqdEC?5Tj-}{^X7bW6pM_Brr>Uv3p&W)36S1-~ijWsPZduNDH1kLk zN2*t<*RQNdxvuflXdN4?g#Kd{nlRQ;jx@R6>KYdA%2OwfdW)rNDf`ATbedi%i08rt z6=`S{@eCAY*E?i~nhwnXt^pccK=K1*HK zEss}J?MSa{wTqbK}HKVS{zD(p}Q)#y~ zQ@*jfigUxEhS3t_K)AjAa580IpztS&l)jrF`7+tE~U%#Dw|caxcZ07Y9Lrrgk*n&U@FqD$03$RpsV>EvLP6s1^%2KGva zV6N4QOSbGU?wp4OUiPwOVpLXT$I@2duBvzQ0FWQSVp34g~RU>JhQDM;9!D{TAFuQB7tEq4GVpzr5CqWThvP|3| zhP0G4m#`1}IVO$>1vL7Aai-!6m7R~g<`N!^PX}!J52Kn$2szSJA5rX=w#7PjTyf#J zgPc5lL>G=&Y=)-uWPRA|%_}k|vc9C7hmo4IqGPrFG_mJ4sCA#CHrtwv=uqx75uqf} z%PDr?j^eCmcVp{-uIMtBlFlY>$o@*87XWDQKVuQq)w~2+TR~6iXUNwdyO_XMFvAh@ z5J4@fv6((Grx8VAQRZM&l$(tt3KGiRC4qg{7IN5kGxT}t2)u69O=gI#4zqp z^-(dT2VLxY@xSRq(=mjxQ-XoOQ83DC2tTTN=C-{4x^Hd^^2VLbjW2CK*t);xK;;MP zcQ-foqML+C&v1c1OXH5?*v}&hYFRt4;y+o1p`)Q zTd6;NMaYOc2QLYaYF6lxFIS@sU_#!*Dl{khnK?P=rlsX!Nw=Ld1`yGhXRI=uGy>&> z;72lR@v4}Kql=l4-|7K6>R*L%&h(>xIuoUtlwHvfynP&i<=r6u3kIgVrvb}xoBKw| zsW%Dx*y?rkYy&KBd9;Q_9>~(^V8q{AJPYGkpCU;0+7EUPPT7BG4l(+S)3G$^ zrJ2AbsPEym%lQ<21D0S$#&Xn@J8iA5H*~L8GfsKc!-r|O0&{%P(rB5-22GAOR3UU4T1-##scft1kU*P!(`h35n?SRk;Q&UfCUssvr|OLKn2zE7af z0BPS7HrTRRj0bYHMk=PioW(5{#P@%s?wnD$204 zH)vP(M*@7zNHp+8ewBC=U*e}k+4GH@!6ScO)Zi1_tkD`QCyj^5X1Dm}e014({L?!+ zW8dI$0GDg{49lh-MdL2ZKC(~VqI-GpA~9KpUU#TEBt6c=x+j^rM~(Y zAs!O_Ynby3aw1-|+r;xb*zQO`TpzzZ{0W^>*uAXCF$eovY?3HNrUlAR_Dv6q*~Oh_ z%by_PDrfV5Ad_S4E=+c?Ko3q{YAZD9OaOaINkejVHC?^{)Am}beArU8rhJHprFcX{ z5#YOVMLdCt$RQcUV@CwmbRJ)g9-q}#CJ95gpi7YNome7mFclSUZyIB6Ky5Tz$Cch< z73|TgdfgH~53{lZ3_ap>L$}Yb=n}n@Gz2cyHVL)gW#Il0=Z<68K+el@9X}k9$&*O}Gc}qo(3BK9J%Kl2I-j9YF_mXwJUGfbRNB2YJ|NQX zv3Zu@LG&jC-Xp&i-o@Xd;DC#V}gv=8o$3daW^f$ p_ENypDbA>-pQk6eTZc5dku29`Hh(IVVCqyTLTG8QnJfqV{RflkHqZb7 literal 0 HcmV?d00001 diff --git a/mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/variables/variables.data-00000-of-00001 b/mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/variables/variables.data-00000-of-00001 new file mode 100644 index 0000000000000000000000000000000000000000..09dbb330ded529a9c09c79d8285a66ab3de8a684 GIT binary patch literal 126 zcmZQzfPliJH}`IEDRBvKFbeS$CzhqgC+C;um82GN@o_K-aTRB#=Ovbu7Nr(*c?!96 r@r5}Cc{)1zxWco>X1@L)~~PU7!nfX=@c5`8my11jf(*QrxhfW literal 0 HcmV?d00001 diff --git a/mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/variables/variables.index b/mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/variables/variables.index new file mode 100644 index 0000000000000000000000000000000000000000..7cfb9ffd40b028a98a66248c864a6c8644258a31 GIT binary patch literal 199 zcmZQzVB=tvV&Y(Akl~GY_HcFf4)FK%3vqPvagFzP@^W str: + if isinstance(self.files, file_util.DownloadedFiles): + return self.files.get_path() + elif isinstance(self.files, str): + return self.files + else: + raise ValueError(f'files has unsupported type: {type(self.files)}') diff --git a/mediapipe/model_maker/python/text/text_classifier/BUILD b/mediapipe/model_maker/python/text/text_classifier/BUILD index 322b1e1e5..e32733e31 100644 --- a/mediapipe/model_maker/python/text/text_classifier/BUILD +++ b/mediapipe/model_maker/python/text/text_classifier/BUILD @@ -131,6 +131,7 @@ py_library( ":text_classifier_options", "//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/tasks:classifier", + "//mediapipe/model_maker/python/core/utils:hub_loader", "//mediapipe/model_maker/python/core/utils:loss_functions", "//mediapipe/model_maker/python/core/utils:metrics", "//mediapipe/model_maker/python/core/utils:model_util", diff --git 
a/mediapipe/model_maker/python/text/text_classifier/model_spec.py b/mediapipe/model_maker/python/text/text_classifier/model_spec.py index 724aaf377..01d1432cb 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_spec.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec.py @@ -23,16 +23,8 @@ from mediapipe.model_maker.python.text.text_classifier import hyperparameters as from mediapipe.model_maker.python.text.text_classifier import model_options as mo -MOBILEBERT_TINY_FILES = file_util.DownloadedFiles( - 'text_classifier/mobilebert_tiny', - 'https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny.tar.gz', - is_folder=True, -) - -EXBERT_FILES = file_util.DownloadedFiles( - 'text_classifier/exbert', - 'https://storage.googleapis.com/mediapipe-assets/exbert.tar.gz', - is_folder=True, +MOBILEBERT_FILES = ( + 'https://tfhub.dev/google/mobilebert/uncased_L-24_H-128_B-512_A-4_F-4_OPT/1' ) @@ -71,23 +63,14 @@ class BertClassifierSpec(bert_model_spec.BertModelSpec): hparams: hp.BertHParams = dataclasses.field(default_factory=hp.BertHParams) - mobilebert_classifier_spec = functools.partial( BertClassifierSpec, - downloaded_files=MOBILEBERT_TINY_FILES, + files=MOBILEBERT_FILES, hparams=hp.BertHParams( epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off' ), - name='MobileBert', -) - -exbert_classifier_spec = functools.partial( - BertClassifierSpec, - downloaded_files=EXBERT_FILES, - hparams=hp.BertHParams( - epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off' - ), - name='ExBert', + name='MobileBERT', + is_tf2=False, ) @@ -96,4 +79,3 @@ class SupportedModels(enum.Enum): """Predefined text classifier model specs supported by Model Maker.""" AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec - EXBERT_CLASSIFIER = exbert_classifier_spec diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py index 4d42851d5..d1e578b81 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py @@ -42,8 +42,8 @@ class ModelSpecTest(tf.test.TestCase): def test_predefined_bert_spec(self): model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value() self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec) - self.assertEqual(model_spec_obj.name, 'MobileBert') - self.assertTrue(os.path.exists(model_spec_obj.downloaded_files.get_path())) + self.assertEqual(model_spec_obj.name, 'MobileBERT') + self.assertTrue(model_spec_obj.files) self.assertTrue(model_spec_obj.do_lower_case) self.assertEqual( model_spec_obj.tflite_input_name, diff --git a/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py b/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py index 28c12f96c..ff9015498 100644 --- a/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py @@ -87,11 +87,11 @@ class PreprocessorTest(tf.test.TestCase): csv_file = self._get_csv_file() dataset = text_classifier_ds.Dataset.from_csv( filename=csv_file, csv_params=self.CSV_PARAMS_) - bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value() + bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value() bert_preprocessor = preprocessor.BertClassifierPreprocessor( seq_len=5, 
do_lower_case=bert_spec.do_lower_case, - uri=bert_spec.downloaded_files.get_path(), + uri=bert_spec.get_path(), model_name=bert_spec.name, ) preprocessed_dataset = bert_preprocessor.preprocess(dataset) @@ -121,11 +121,11 @@ class PreprocessorTest(tf.test.TestCase): csv_params=self.CSV_PARAMS_, cache_dir=self.get_temp_dir(), ) - bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value() + bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value() bert_preprocessor = preprocessor.BertClassifierPreprocessor( seq_len=5, do_lower_case=bert_spec.do_lower_case, - uri=bert_spec.downloaded_files.get_path(), + uri=bert_spec.get_path(), model_name=bert_spec.name, ) ds_cache_files = dataset.tfrecord_cache_files @@ -153,7 +153,7 @@ class PreprocessorTest(tf.test.TestCase): bert_preprocessor = preprocessor.BertClassifierPreprocessor( seq_len=seq_len, do_lower_case=do_lower_case, - uri=bert_spec.downloaded_files.get_path(), + uri=bert_spec.get_path(), model_name=bert_spec.name, ) new_cf = bert_preprocessor._get_tfrecord_cache_files(cf) @@ -167,10 +167,6 @@ class PreprocessorTest(tf.test.TestCase): cache_dir=self.get_temp_dir(), num_shards=1, ) - exbert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value() - all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, True)) - all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 10, True)) - all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, False)) mobilebert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value() all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, True)) all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 10, True)) @@ -180,10 +176,10 @@ class PreprocessorTest(tf.test.TestCase): cache_dir=self.get_temp_dir(), num_shards=1, ) - all_cf_prefixes.add(self._get_new_prefix(new_cf, exbert_spec, 5, True)) + all_cf_prefixes.add(self._get_new_prefix(new_cf, mobilebert_spec, 5, True)) - # Each item of all_cf_prefixes should be unique, so 7 total. - self.assertLen(all_cf_prefixes, 7) + # Each item of all_cf_prefixes should be unique. 
+ self.assertLen(all_cf_prefixes, 4) if __name__ == '__main__': diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py index 10d88110d..76043aa72 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py @@ -24,6 +24,7 @@ import tensorflow_hub as hub from mediapipe.model_maker.python.core.data import dataset as ds from mediapipe.model_maker.python.core.tasks import classifier +from mediapipe.model_maker.python.core.utils import hub_loader from mediapipe.model_maker.python.core.utils import loss_functions from mediapipe.model_maker.python.core.utils import metrics from mediapipe.model_maker.python.core.utils import model_util @@ -52,18 +53,21 @@ def _validate(options: text_classifier_options.TextClassifierOptions): if options.model_options is None: return - if (isinstance(options.model_options, mo.AverageWordEmbeddingModelOptions) and - (options.supported_model != - ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)): - raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER," - f" got {options.supported_model}") - if isinstance(options.model_options, mo.BertModelOptions) and ( - options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER - and options.supported_model != ms.SupportedModels.EXBERT_CLASSIFIER + if isinstance( + options.model_options, mo.AverageWordEmbeddingModelOptions + ) and ( + options.supported_model + != ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER ): raise ValueError( - "Expected a Bert Classifier(MobileBERT or EXBERT), got " - f"{options.supported_model}" + "Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER," + f" got {options.supported_model}" + ) + if isinstance(options.model_options, mo.BertModelOptions) and ( + not isinstance(options.supported_model.value(), ms.BertClassifierSpec) + ): + raise ValueError( + f"Expected a Bert Classifier, got {options.supported_model}" ) @@ -113,15 +117,13 @@ class TextClassifier(classifier.Classifier): if options.hparams is None: options.hparams = options.supported_model.value().hparams - if ( - options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER - or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER - ): + if isinstance(options.supported_model.value(), ms.BertClassifierSpec): text_classifier = _BertClassifier.create_bert_classifier( train_data, validation_data, options ) - elif (options.supported_model == - ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER): + elif isinstance( + options.supported_model.value(), ms.AverageWordEmbeddingClassifierSpec + ): text_classifier = _AverageWordEmbeddingClassifier.create_average_word_embedding_classifier( train_data, validation_data, options ) @@ -348,12 +350,12 @@ class _BertClassifier(TextClassifier): self._hparams = hparams self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir) self._model_options = model_options + self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None with self._hparams.get_strategy().scope(): self._loss_function = loss_functions.SparseFocalLoss( self._hparams.gamma, self._num_classes ) self._metric_functions = self._create_metrics() - self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None @classmethod def create_bert_classifier( @@ -410,7 +412,7 @@ class _BertClassifier(TextClassifier): self._text_preprocessor = preprocessor.BertClassifierPreprocessor( 
seq_len=self._model_options.seq_len, do_lower_case=self._model_spec.do_lower_case, - uri=self._model_spec.downloaded_files.get_path(), + uri=self._model_spec.get_path(), model_name=self._model_spec.name, ) return ( @@ -488,12 +490,26 @@ class _BertClassifier(TextClassifier): name="input_type_ids", ), ) - encoder = hub.KerasLayer( - self._model_spec.downloaded_files.get_path(), - trainable=self._model_options.do_fine_tuning, - ) - encoder_outputs = encoder(encoder_inputs) - pooled_output = encoder_outputs["pooled_output"] + if self._model_spec.is_tf2: + encoder = hub.KerasLayer( + self._model_spec.get_path(), + trainable=self._model_options.do_fine_tuning, + ) + encoder_outputs = encoder(encoder_inputs) + pooled_output = encoder_outputs["pooled_output"] + else: + renamed_inputs = dict( + input_ids=encoder_inputs["input_word_ids"], + input_mask=encoder_inputs["input_mask"], + segment_ids=encoder_inputs["input_type_ids"], + ) + encoder = hub_loader.HubKerasLayerV1V2( + self._model_spec.get_path(), + signature="tokens", + output_key="pooled_output", + trainable=self._model_options.do_fine_tuning, + ) + pooled_output = encoder(renamed_inputs) output = tf.keras.layers.Dropout(rate=self._model_options.dropout_rate)( pooled_output) diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py index be4646f68..122182ddd 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py @@ -104,13 +104,9 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase): @parameterized.named_parameters( # Skipping mobilebert b/c OSS test timeout/flakiness: b/275624089 - # dict( - # testcase_name='mobilebert', - # supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER, - # ), dict( - testcase_name='exbert', - supported_model=text_classifier.SupportedModels.EXBERT_CLASSIFIER, + testcase_name='mobilebert', + supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER, ), ) def test_create_and_train_bert(self, supported_model): @@ -156,7 +152,7 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase): def test_label_mismatch(self): options = text_classifier.TextClassifierOptions( - supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER) + supported_model=(text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER) ) train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]]) train_data = text_classifier.Dataset(train_tf_dataset, ['foo'], 1) @@ -174,13 +170,13 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase): train_data, validation_data = self._get_data() avg_options = text_classifier.TextClassifierOptions( - supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER), + supported_model=(text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER), model_options=text_classifier.AverageWordEmbeddingModelOptions(), ) with self.assertRaisesWithLiteralMatch( ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got' - ' SupportedModels.EXBERT_CLASSIFIER', + ' SupportedModels.MOBILEBERT_CLASSIFIER', ): text_classifier.TextClassifier.create( train_data, validation_data, avg_options @@ -194,7 +190,7 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase): ) with self.assertRaisesWithLiteralMatch( ValueError, - 'Expected a Bert Classifier(MobileBERT or EXBERT), got' + 'Expected a Bert Classifier, got' ' 
SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER', ): text_classifier.TextClassifier.create( @@ -203,7 +199,7 @@ class TextClassifierTest(tf.test.TestCase, parameterized.TestCase): def test_bert_loss_and_metrics_creation(self): train_data, validation_data = self._get_data() - supported_model = text_classifier.SupportedModels.EXBERT_CLASSIFIER + supported_model = text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER hparams = text_classifier.BertHParams( desired_recalls=[0.2], desired_precisions=[0.9],