Adds smaller MobileBERT model.

PiperOrigin-RevId: 501451414
This commit is contained in:
MediaPipe Team 2023-01-11 20:33:26 -08:00 committed by Copybara-Service
parent 8830eefa0b
commit 9cbb76939d
10 changed files with 228 additions and 17 deletions

View File

@ -0,0 +1,45 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"//mediapipe/framework/tool:mediapipe_files.bzl",
"mediapipe_files",
)
licenses(["notice"])
package(
default_visibility = ["//mediapipe/model_maker/python/text/text_classifier:__subpackages__"],
)
mediapipe_files(
srcs = [
"mobilebert_tiny/assets/vocab.txt",
"mobilebert_tiny/keras_metadata.pb",
"mobilebert_tiny/saved_model.pb",
"mobilebert_tiny/variables/variables.data-00000-of-00001",
"mobilebert_tiny/variables/variables.index",
],
)
filegroup(
name = "mobilebert_tiny",
srcs = [
"mobilebert_tiny/assets/vocab.txt",
"mobilebert_tiny/keras_metadata.pb",
"mobilebert_tiny/saved_model.pb",
"mobilebert_tiny/variables/variables.data-00000-of-00001",
"mobilebert_tiny/variables/variables.index",
],
)

View File

@ -53,6 +53,7 @@ py_library(
deps = [ deps = [
":model_options", ":model_options",
"//mediapipe/model_maker/python/core:hyperparameters", "//mediapipe/model_maker/python/core:hyperparameters",
"//mediapipe/model_maker/python/core/utils:file_util",
"//mediapipe/model_maker/python/text/core:bert_model_spec", "//mediapipe/model_maker/python/text/core:bert_model_spec",
], ],
) )
@ -88,6 +89,9 @@ py_library(
py_test( py_test(
name = "preprocessor_test", name = "preprocessor_test",
srcs = ["preprocessor_test.py"], srcs = ["preprocessor_test.py"],
data = [
"//mediapipe/model_maker/models/text_classifier:mobilebert_tiny",
],
tags = ["requires-net:external"], tags = ["requires-net:external"],
deps = [ deps = [
":dataset", ":dataset",
@ -109,6 +113,9 @@ py_library(
py_library( py_library(
name = "text_classifier", name = "text_classifier",
srcs = ["text_classifier.py"], srcs = ["text_classifier.py"],
data = [
"//mediapipe/model_maker/models/text_classifier:mobilebert_tiny",
],
deps = [ deps = [
":dataset", ":dataset",
":model_options", ":model_options",
@ -130,6 +137,7 @@ py_test(
size = "large", size = "large",
srcs = ["text_classifier_test.py"], srcs = ["text_classifier_test.py"],
data = [ data = [
"//mediapipe/model_maker/models/text_classifier:mobilebert_tiny",
"//mediapipe/model_maker/python/text/text_classifier/testdata", "//mediapipe/model_maker/python/text/text_classifier/testdata",
], ],
tags = ["requires-net:external"], tags = ["requires-net:external"],
@ -151,6 +159,9 @@ py_library(
py_binary( py_binary(
name = "text_classifier_demo", name = "text_classifier_demo",
srcs = ["text_classifier_demo.py"], srcs = ["text_classifier_demo.py"],
data = [
"//mediapipe/model_maker/models/text_classifier:mobilebert_tiny",
],
deps = [ deps = [
":text_classifier_demo_lib", ":text_classifier_demo_lib",
], ],

View File

@ -18,12 +18,15 @@ import enum
import functools import functools
from mediapipe.model_maker.python.core import hyperparameters as hp from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.core.utils import file_util
from mediapipe.model_maker.python.text.core import bert_model_spec from mediapipe.model_maker.python.text.core import bert_model_spec
from mediapipe.model_maker.python.text.text_classifier import model_options as mo from mediapipe.model_maker.python.text.text_classifier import model_options as mo
# BERT-based text classifier spec inherited from BertModelSpec # BERT-based text classifier spec inherited from BertModelSpec
BertClassifierSpec = bert_model_spec.BertModelSpec BertClassifierSpec = bert_model_spec.BertModelSpec
MOBILEBERT_TINY_PATH = 'mediapipe/model_maker/models/text_classifier/mobilebert_tiny/'
@dataclasses.dataclass @dataclasses.dataclass
class AverageWordEmbeddingClassifierSpec: class AverageWordEmbeddingClassifierSpec:
@ -49,16 +52,14 @@ average_word_embedding_classifier_spec = functools.partial(
mobilebert_classifier_spec = functools.partial( mobilebert_classifier_spec = functools.partial(
BertClassifierSpec, BertClassifierSpec,
hparams=hp.BaseHParams( hparams=hp.BaseHParams(
epochs=3, epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
batch_size=48, ),
learning_rate=3e-5,
distribution_strategy='off'),
name='MobileBert', name='MobileBert',
uri='https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1', uri=file_util.get_absolute_path(MOBILEBERT_TINY_PATH),
tflite_input_name={ tflite_input_name={
'ids': 'serving_default_input_1:0', 'ids': 'serving_default_input_1:0',
'mask': 'serving_default_input_3:0', 'mask': 'serving_default_input_3:0',
'segment_ids': 'serving_default_input_2:0' 'segment_ids': 'serving_default_input_2:0',
}, },
) )

View File

@ -28,9 +28,10 @@ class ModelSpecTest(tf.test.TestCase):
model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value() model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value()
self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec) self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec)
self.assertEqual(model_spec_obj.name, 'MobileBert') self.assertEqual(model_spec_obj.name, 'MobileBert')
self.assertEqual( self.assertIn(
model_spec_obj.uri, 'https://tfhub.dev/tensorflow/' 'mediapipe/model_maker/models/text_classifier/mobilebert_tiny',
'mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1') model_spec_obj.uri,
)
self.assertTrue(model_spec_obj.do_lower_case) self.assertTrue(model_spec_obj.do_lower_case)
self.assertEqual( self.assertEqual(
model_spec_obj.tflite_input_name, { model_spec_obj.tflite_input_name, {

View File

@ -19,5 +19,8 @@ package(
filegroup( filegroup(
name = "testdata", name = "testdata",
srcs = ["average_word_embedding_metadata.json"], srcs = [
"average_word_embedding_metadata.json",
"bert_metadata.json",
],
) )

View File

@ -0,0 +1,84 @@
{
"name": "TextClassifier",
"description": "Classify the input text into a set of known categories.",
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "ids",
"description": "Tokenized ids of the input text.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
},
{
"name": "mask",
"description": "Mask with 1 for real tokens and 0 for padding tokens.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
},
{
"name": "segment_ids",
"description": "0 for the first sequence, 1 for the second sequence if exists.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
}
],
"output_tensor_metadata": [
{
"name": "score",
"description": "Score of the labels respectively.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS"
}
]
}
],
"input_process_units": [
{
"options_type": "BertTokenizerOptions",
"options": {
"vocab_file": [
{
"name": "vocab.txt",
"description": "Vocabulary file to convert natural language words to embedding vectors.",
"type": "VOCABULARY"
}
]
}
}
]
}
],
"min_parser_version": "1.1.0"
}

View File

@ -269,16 +269,21 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
"""Creates an Average Word Embedding model.""" """Creates an Average Word Embedding model."""
self._model = tf.keras.Sequential([ self._model = tf.keras.Sequential([
tf.keras.layers.InputLayer( tf.keras.layers.InputLayer(
input_shape=[self._model_options.seq_len], dtype=tf.int32), input_shape=[self._model_options.seq_len],
dtype=tf.int32,
name="input_ids",
),
tf.keras.layers.Embedding( tf.keras.layers.Embedding(
len(self._text_preprocessor.get_vocab()), len(self._text_preprocessor.get_vocab()),
self._model_options.wordvec_dim, self._model_options.wordvec_dim,
input_length=self._model_options.seq_len), input_length=self._model_options.seq_len,
),
tf.keras.layers.GlobalAveragePooling1D(), tf.keras.layers.GlobalAveragePooling1D(),
tf.keras.layers.Dense( tf.keras.layers.Dense(
self._model_options.wordvec_dim, activation=tf.nn.relu), self._model_options.wordvec_dim, activation=tf.nn.relu
),
tf.keras.layers.Dropout(self._model_options.dropout_rate), tf.keras.layers.Dropout(self._model_options.dropout_rate),
tf.keras.layers.Dense(self._num_classes, activation="softmax") tf.keras.layers.Dense(self._num_classes, activation="softmax"),
]) ])
def _save_vocab(self, vocab_filepath: str): def _save_vocab(self, vocab_filepath: str):

View File

@ -26,6 +26,9 @@ class TextClassifierTest(tf.test.TestCase):
_AVERAGE_WORD_EMBEDDING_JSON_FILE = ( _AVERAGE_WORD_EMBEDDING_JSON_FILE = (
test_utils.get_test_data_path('average_word_embedding_metadata.json')) test_utils.get_test_data_path('average_word_embedding_metadata.json'))
_BERT_CLASSIFIER_JSON_FILE = test_utils.get_test_data_path(
'bert_metadata.json'
)
def _get_data(self): def _get_data(self):
labels_and_text = (('pos', 'super good'), (('neg', 'really bad'))) labels_and_text = (('pos', 'super good'), (('neg', 'really bad')))
@ -94,7 +97,27 @@ class TextClassifierTest(tf.test.TestCase):
_, accuracy = bert_classifier.evaluate(validation_data) _, accuracy = bert_classifier.evaluate(validation_data)
self.assertGreaterEqual(accuracy, 0.0) self.assertGreaterEqual(accuracy, 0.0)
# TODO: Add a unit test that does not run OOM.
# Test export_model
bert_classifier.export_model()
output_metadata_file = os.path.join(
options.hparams.export_dir, 'metadata.json'
)
output_tflite_file = os.path.join(
options.hparams.export_dir, 'model.tflite'
)
self.assertTrue(os.path.exists(output_tflite_file))
self.assertGreater(os.path.getsize(output_tflite_file), 0)
self.assertTrue(os.path.exists(output_metadata_file))
self.assertGreater(os.path.getsize(output_metadata_file), 0)
filecmp.clear_cache()
self.assertTrue(
filecmp.cmp(
output_metadata_file, self._BERT_CLASSIFIER_JSON_FILE, shallow=False
)
)
def test_label_mismatch(self): def test_label_mismatch(self):
options = ( options = (

View File

@ -81,7 +81,10 @@ def _setup_build_dir():
file.write(filedata) file.write(filedata)
# Use bazel to download GCS model files # Use bazel to download GCS model files
model_build_files = ['models/gesture_recognizer/BUILD'] model_build_files = [
'models/gesture_recognizer/BUILD',
'models/text_classifier/BUILD',
]
for model_build_file in model_build_files: for model_build_file in model_build_files:
build_target_file = os.path.join(BUILD_MM_DIR, model_build_file) build_target_file = os.path.join(BUILD_MM_DIR, model_build_file)
os.makedirs(os.path.dirname(build_target_file), exist_ok=True) os.makedirs(os.path.dirname(build_target_file), exist_ok=True)
@ -95,6 +98,11 @@ def _setup_build_dir():
'models/gesture_recognizer/gesture_embedder/saved_model.pb', 'models/gesture_recognizer/gesture_embedder/saved_model.pb',
'models/gesture_recognizer/gesture_embedder/variables/variables.data-00000-of-00001', 'models/gesture_recognizer/gesture_embedder/variables/variables.data-00000-of-00001',
'models/gesture_recognizer/gesture_embedder/variables/variables.index', 'models/gesture_recognizer/gesture_embedder/variables/variables.index',
'models/text_classifier/mobilebert_tiny/keras_metadata.pb',
'models/text_classifier/mobilebert_tiny/saved_model.pb',
'models/text_classifier/mobilebert_tiny/assets/vocab.txt',
'models/text_classifier/mobilebert_tiny/variables/variables.data-00000-of-00001',
'models/text_classifier/mobilebert_tiny/variables/variables.index',
] ]
for elem in external_files: for elem in external_files:
external_file = os.path.join(f'{SRC_NAME}/mediapipe_model_maker', elem) external_file = os.path.join(f'{SRC_NAME}/mediapipe_model_maker', elem)

View File

@ -1006,6 +1006,18 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/saved_model.pb?generation=1668550484904822"], urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/saved_model.pb?generation=1668550484904822"],
) )
http_file(
name = "com_google_mediapipe_mobilebert_tiny_keras_metadata_pb",
sha256 = "cef8131a414c602b9d4742ac57f4f90bc5d8a42baec36b65deece884e2d0cf0f",
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/keras_metadata.pb?generation=1673297965144159"],
)
http_file(
name = "com_google_mediapipe_mobilebert_tiny_saved_model_pb",
sha256 = "323c997cd3e17df1b2e3bdebe3cfe2b17c5ffd9488a26a4afb59ee819196837a",
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/saved_model.pb?generation=1673297968138825"],
)
http_file( http_file(
name = "com_google_mediapipe_object_detection_saved_model_model_ckpt_data-00000-of-00001", name = "com_google_mediapipe_object_detection_saved_model_model_ckpt_data-00000-of-00001",
sha256 = "ad2f733f271dd5000a8c7f926bfea1083e6408b34d4f3b60679e5a6f96251c97", sha256 = "ad2f733f271dd5000a8c7f926bfea1083e6408b34d4f3b60679e5a6f96251c97",
@ -1053,3 +1065,21 @@ def external_files():
sha256 = "76ea482b8da6bdb3d65d3b2ea989c1699c9fa0d6df0cb6d80863d1dc6fe7c4bd", sha256 = "76ea482b8da6bdb3d65d3b2ea989c1699c9fa0d6df0cb6d80863d1dc6fe7c4bd",
urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668550490691823"], urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668550490691823"],
) )
http_file(
name = "com_google_mediapipe_mobilebert_tiny_assets_vocab_txt",
sha256 = "07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3",
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/assets/vocab.txt?generation=1673297970948751"],
)
http_file(
name = "com_google_mediapipe_mobilebert_tiny_variables_variables_data-00000-of-00001",
sha256 = "c3857370046cd3a2f345657cf1bb259a4e7e09185d7f0808e57803e9d41ebba4",
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/variables/variables.data-00000-of-00001?generation=1673297975132568"],
)
http_file(
name = "com_google_mediapipe_mobilebert_tiny_variables_variables_index",
sha256 = "4df4d7c0fefe99903ab6ebf44b7478196ce613082d2ca692a5a37a7f24e562ed",
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/variables/variables.index?generation=1673297977586840"],
)