From 9cbb76939dd069eacecae103c5c27b6e07c7e9c7 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 11 Jan 2023 20:33:26 -0800 Subject: [PATCH] Adds smaller MobileBERT model. PiperOrigin-RevId: 501451414 --- .../model_maker/models/text_classifier/BUILD | 45 ++++++++++ .../python/text/text_classifier/BUILD | 11 +++ .../python/text/text_classifier/model_spec.py | 13 +-- .../text/text_classifier/model_spec_test.py | 7 +- .../text/text_classifier/testdata/BUILD | 5 +- .../testdata/bert_metadata.json | 84 +++++++++++++++++++ .../text/text_classifier/text_classifier.py | 13 ++- .../text_classifier/text_classifier_test.py | 25 +++++- mediapipe/model_maker/setup.py | 12 ++- third_party/external_files.bzl | 30 +++++++ 10 files changed, 228 insertions(+), 17 deletions(-) create mode 100644 mediapipe/model_maker/models/text_classifier/BUILD create mode 100644 mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json diff --git a/mediapipe/model_maker/models/text_classifier/BUILD b/mediapipe/model_maker/models/text_classifier/BUILD new file mode 100644 index 000000000..4c54bbccc --- /dev/null +++ b/mediapipe/model_maker/models/text_classifier/BUILD @@ -0,0 +1,45 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load( + "//mediapipe/framework/tool:mediapipe_files.bzl", + "mediapipe_files", +) + +licenses(["notice"]) + +package( + default_visibility = ["//mediapipe/model_maker/python/text/text_classifier:__subpackages__"], +) + +mediapipe_files( + srcs = [ + "mobilebert_tiny/assets/vocab.txt", + "mobilebert_tiny/keras_metadata.pb", + "mobilebert_tiny/saved_model.pb", + "mobilebert_tiny/variables/variables.data-00000-of-00001", + "mobilebert_tiny/variables/variables.index", + ], +) + +filegroup( + name = "mobilebert_tiny", + srcs = [ + "mobilebert_tiny/assets/vocab.txt", + "mobilebert_tiny/keras_metadata.pb", + "mobilebert_tiny/saved_model.pb", + "mobilebert_tiny/variables/variables.data-00000-of-00001", + "mobilebert_tiny/variables/variables.index", + ], +) diff --git a/mediapipe/model_maker/python/text/text_classifier/BUILD b/mediapipe/model_maker/python/text/text_classifier/BUILD index 7bb41351e..43f2b6c75 100644 --- a/mediapipe/model_maker/python/text/text_classifier/BUILD +++ b/mediapipe/model_maker/python/text/text_classifier/BUILD @@ -53,6 +53,7 @@ py_library( deps = [ ":model_options", "//mediapipe/model_maker/python/core:hyperparameters", + "//mediapipe/model_maker/python/core/utils:file_util", "//mediapipe/model_maker/python/text/core:bert_model_spec", ], ) @@ -88,6 +89,9 @@ py_library( py_test( name = "preprocessor_test", srcs = ["preprocessor_test.py"], + data = [ + "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny", + ], tags = ["requires-net:external"], deps = [ ":dataset", @@ -109,6 +113,9 @@ py_library( py_library( name = "text_classifier", srcs = ["text_classifier.py"], + data = [ + "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny", + ], deps = [ ":dataset", ":model_options", @@ -130,6 +137,7 @@ py_test( size = "large", srcs = ["text_classifier_test.py"], data = [ + "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny", "//mediapipe/model_maker/python/text/text_classifier/testdata", ], tags = ["requires-net:external"], @@ -151,6 +159,9 @@ py_library( py_binary( name = "text_classifier_demo", srcs = ["text_classifier_demo.py"], + data = [ + "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny", + ], deps = [ ":text_classifier_demo_lib", ], diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec.py b/mediapipe/model_maker/python/text/text_classifier/model_spec.py index 9df7e1039..a6bdd9522 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_spec.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec.py @@ -18,12 +18,15 @@ import enum import functools from mediapipe.model_maker.python.core import hyperparameters as hp +from mediapipe.model_maker.python.core.utils import file_util from mediapipe.model_maker.python.text.core import bert_model_spec from mediapipe.model_maker.python.text.text_classifier import model_options as mo # BERT-based text classifier spec inherited from BertModelSpec BertClassifierSpec = bert_model_spec.BertModelSpec +MOBILEBERT_TINY_PATH = 'mediapipe/model_maker/models/text_classifier/mobilebert_tiny/' + @dataclasses.dataclass class AverageWordEmbeddingClassifierSpec: @@ -49,16 +52,14 @@ average_word_embedding_classifier_spec = functools.partial( mobilebert_classifier_spec = functools.partial( BertClassifierSpec, hparams=hp.BaseHParams( - epochs=3, - batch_size=48, - learning_rate=3e-5, - distribution_strategy='off'), + epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off' + ), name='MobileBert', - uri='https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1', + uri=file_util.get_absolute_path(MOBILEBERT_TINY_PATH), tflite_input_name={ 'ids': 'serving_default_input_1:0', 'mask': 'serving_default_input_3:0', - 'segment_ids': 'serving_default_input_2:0' + 'segment_ids': 'serving_default_input_2:0', }, ) diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py index dd7f880f3..3ea019b44 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py @@ -28,9 +28,10 @@ class ModelSpecTest(tf.test.TestCase): model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value() self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec) self.assertEqual(model_spec_obj.name, 'MobileBert') - self.assertEqual( - model_spec_obj.uri, 'https://tfhub.dev/tensorflow/' - 'mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1') + self.assertIn( + 'mediapipe/model_maker/models/text_classifier/mobilebert_tiny', + model_spec_obj.uri, + ) self.assertTrue(model_spec_obj.do_lower_case) self.assertEqual( model_spec_obj.tflite_input_name, { diff --git a/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD b/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD index 663c72082..a581462cf 100644 --- a/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD +++ b/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD @@ -19,5 +19,8 @@ package( filegroup( name = "testdata", - srcs = ["average_word_embedding_metadata.json"], + srcs = [ + "average_word_embedding_metadata.json", + "bert_metadata.json", + ], ) diff --git a/mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json b/mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json new file mode 100644 index 000000000..24214a80d --- /dev/null +++ b/mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json @@ -0,0 +1,84 @@ +{ + "name": "TextClassifier", + "description": "Classify the input text into a set of known categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "ids", + "description": "Tokenized ids of the input text.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + }, + { + "name": "mask", + "description": "Mask with 1 for real tokens and 0 for padding tokens.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + }, + { + "name": "segment_ids", + "description": "0 for the first sequence, 1 for the second sequence if exists.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + } + } + ], + "output_tensor_metadata": [ + { + "name": "score", + "description": "Score of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + } + ] + } + ], + "input_process_units": [ + { + "options_type": "BertTokenizerOptions", + "options": { + "vocab_file": [ + { + "name": "vocab.txt", + "description": "Vocabulary file to convert natural language words to embedding vectors.", + "type": "VOCABULARY" + } + ] + } + } + ] + } + ], + "min_parser_version": "1.1.0" +} diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py index 1a338e345..f6abc8bf0 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py @@ -269,16 +269,21 @@ class _AverageWordEmbeddingClassifier(TextClassifier): """Creates an Average Word Embedding model.""" self._model = tf.keras.Sequential([ tf.keras.layers.InputLayer( - input_shape=[self._model_options.seq_len], dtype=tf.int32), + input_shape=[self._model_options.seq_len], + dtype=tf.int32, + name="input_ids", + ), tf.keras.layers.Embedding( len(self._text_preprocessor.get_vocab()), self._model_options.wordvec_dim, - input_length=self._model_options.seq_len), + input_length=self._model_options.seq_len, + ), tf.keras.layers.GlobalAveragePooling1D(), tf.keras.layers.Dense( - self._model_options.wordvec_dim, activation=tf.nn.relu), + self._model_options.wordvec_dim, activation=tf.nn.relu + ), tf.keras.layers.Dropout(self._model_options.dropout_rate), - tf.keras.layers.Dense(self._num_classes, activation="softmax") + tf.keras.layers.Dense(self._num_classes, activation="softmax"), ]) def _save_vocab(self, vocab_filepath: str): diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py index eb4443b44..1ae2bc553 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py @@ -26,6 +26,9 @@ class TextClassifierTest(tf.test.TestCase): _AVERAGE_WORD_EMBEDDING_JSON_FILE = ( test_utils.get_test_data_path('average_word_embedding_metadata.json')) + _BERT_CLASSIFIER_JSON_FILE = test_utils.get_test_data_path( + 'bert_metadata.json' + ) def _get_data(self): labels_and_text = (('pos', 'super good'), (('neg', 'really bad'))) @@ -94,7 +97,27 @@ class TextClassifierTest(tf.test.TestCase): _, accuracy = bert_classifier.evaluate(validation_data) self.assertGreaterEqual(accuracy, 0.0) - # TODO: Add a unit test that does not run OOM. + + # Test export_model + bert_classifier.export_model() + output_metadata_file = os.path.join( + options.hparams.export_dir, 'metadata.json' + ) + output_tflite_file = os.path.join( + options.hparams.export_dir, 'model.tflite' + ) + + self.assertTrue(os.path.exists(output_tflite_file)) + self.assertGreater(os.path.getsize(output_tflite_file), 0) + + self.assertTrue(os.path.exists(output_metadata_file)) + self.assertGreater(os.path.getsize(output_metadata_file), 0) + filecmp.clear_cache() + self.assertTrue( + filecmp.cmp( + output_metadata_file, self._BERT_CLASSIFIER_JSON_FILE, shallow=False + ) + ) def test_label_mismatch(self): options = ( diff --git a/mediapipe/model_maker/setup.py b/mediapipe/model_maker/setup.py index 7114e2080..1dac6301a 100644 --- a/mediapipe/model_maker/setup.py +++ b/mediapipe/model_maker/setup.py @@ -81,7 +81,10 @@ def _setup_build_dir(): file.write(filedata) # Use bazel to download GCS model files - model_build_files = ['models/gesture_recognizer/BUILD'] + model_build_files = [ + 'models/gesture_recognizer/BUILD', + 'models/text_classifier/BUILD', + ] for model_build_file in model_build_files: build_target_file = os.path.join(BUILD_MM_DIR, model_build_file) os.makedirs(os.path.dirname(build_target_file), exist_ok=True) @@ -95,7 +98,12 @@ def _setup_build_dir(): 'models/gesture_recognizer/gesture_embedder/saved_model.pb', 'models/gesture_recognizer/gesture_embedder/variables/variables.data-00000-of-00001', 'models/gesture_recognizer/gesture_embedder/variables/variables.index', - ] + 'models/text_classifier/mobilebert_tiny/keras_metadata.pb', + 'models/text_classifier/mobilebert_tiny/saved_model.pb', + 'models/text_classifier/mobilebert_tiny/assets/vocab.txt', + 'models/text_classifier/mobilebert_tiny/variables/variables.data-00000-of-00001', + 'models/text_classifier/mobilebert_tiny/variables/variables.index', + ] for elem in external_files: external_file = os.path.join(f'{SRC_NAME}/mediapipe_model_maker', elem) sys.stderr.write('downloading file: %s\n' % external_file) diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 790486676..5adfbdfc6 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -1006,6 +1006,18 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/saved_model.pb?generation=1668550484904822"], ) + http_file( + name = "com_google_mediapipe_mobilebert_tiny_keras_metadata_pb", + sha256 = "cef8131a414c602b9d4742ac57f4f90bc5d8a42baec36b65deece884e2d0cf0f", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/keras_metadata.pb?generation=1673297965144159"], + ) + + http_file( + name = "com_google_mediapipe_mobilebert_tiny_saved_model_pb", + sha256 = "323c997cd3e17df1b2e3bdebe3cfe2b17c5ffd9488a26a4afb59ee819196837a", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/saved_model.pb?generation=1673297968138825"], + ) + http_file( name = "com_google_mediapipe_object_detection_saved_model_model_ckpt_data-00000-of-00001", sha256 = "ad2f733f271dd5000a8c7f926bfea1083e6408b34d4f3b60679e5a6f96251c97", @@ -1053,3 +1065,21 @@ def external_files(): sha256 = "76ea482b8da6bdb3d65d3b2ea989c1699c9fa0d6df0cb6d80863d1dc6fe7c4bd", urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668550490691823"], ) + + http_file( + name = "com_google_mediapipe_mobilebert_tiny_assets_vocab_txt", + sha256 = "07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/assets/vocab.txt?generation=1673297970948751"], + ) + + http_file( + name = "com_google_mediapipe_mobilebert_tiny_variables_variables_data-00000-of-00001", + sha256 = "c3857370046cd3a2f345657cf1bb259a4e7e09185d7f0808e57803e9d41ebba4", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/variables/variables.data-00000-of-00001?generation=1673297975132568"], + ) + + http_file( + name = "com_google_mediapipe_mobilebert_tiny_variables_variables_index", + sha256 = "4df4d7c0fefe99903ab6ebf44b7478196ce613082d2ca692a5a37a7f24e562ed", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/variables/variables.index?generation=1673297977586840"], + )