Adds smaller MobileBERT model.
PiperOrigin-RevId: 501451414
parent 8830eefa0b · commit 9cbb76939d
mediapipe/model_maker/models/text_classifier/BUILD (new file, +45 lines)
@@ -0,0 +1,45 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load(
+    "//mediapipe/framework/tool:mediapipe_files.bzl",
+    "mediapipe_files",
+)
+
+licenses(["notice"])
+
+package(
+    default_visibility = ["//mediapipe/model_maker/python/text/text_classifier:__subpackages__"],
+)
+
+mediapipe_files(
+    srcs = [
+        "mobilebert_tiny/assets/vocab.txt",
+        "mobilebert_tiny/keras_metadata.pb",
+        "mobilebert_tiny/saved_model.pb",
+        "mobilebert_tiny/variables/variables.data-00000-of-00001",
+        "mobilebert_tiny/variables/variables.index",
+    ],
+)
+
+filegroup(
+    name = "mobilebert_tiny",
+    srcs = [
+        "mobilebert_tiny/assets/vocab.txt",
+        "mobilebert_tiny/keras_metadata.pb",
+        "mobilebert_tiny/saved_model.pb",
+        "mobilebert_tiny/variables/variables.data-00000-of-00001",
+        "mobilebert_tiny/variables/variables.index",
+    ],
+)
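The five checkpoint pieces above are fetched from GCS at build time (via the http_file rules added to third_party/external_files.bzl below) and exposed as runfiles to any Python target that lists `:mobilebert_tiny` in its `data`. A minimal sanity-check sketch a consuming test could run; the base-directory constant and helper name are illustrative assumptions, not part of this commit:

```python
import os

# Assumed runfiles-relative location of the :mobilebert_tiny filegroup.
MODEL_DIR = 'mediapipe/model_maker/models/text_classifier/mobilebert_tiny'

EXPECTED_FILES = (
    'assets/vocab.txt',
    'keras_metadata.pb',
    'saved_model.pb',
    'variables/variables.data-00000-of-00001',
    'variables/variables.index',
)


def assert_model_files_present(base_dir=MODEL_DIR):
  """Raises if any piece of the downloaded SavedModel is missing."""
  missing = [f for f in EXPECTED_FILES
             if not os.path.exists(os.path.join(base_dir, f))]
  if missing:
    raise FileNotFoundError(f'mobilebert_tiny files not materialized: {missing}')
```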
mediapipe/model_maker/python/text/text_classifier/BUILD

@@ -53,6 +53,7 @@ py_library(
     deps = [
         ":model_options",
         "//mediapipe/model_maker/python/core:hyperparameters",
+        "//mediapipe/model_maker/python/core/utils:file_util",
         "//mediapipe/model_maker/python/text/core:bert_model_spec",
     ],
 )
@@ -88,6 +89,9 @@ py_library(
 py_test(
     name = "preprocessor_test",
     srcs = ["preprocessor_test.py"],
+    data = [
+        "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny",
+    ],
     tags = ["requires-net:external"],
     deps = [
         ":dataset",
@@ -109,6 +113,9 @@ py_library(
 py_library(
     name = "text_classifier",
     srcs = ["text_classifier.py"],
+    data = [
+        "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny",
+    ],
     deps = [
         ":dataset",
         ":model_options",
@@ -130,6 +137,7 @@ py_test(
     size = "large",
     srcs = ["text_classifier_test.py"],
     data = [
+        "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny",
        "//mediapipe/model_maker/python/text/text_classifier/testdata",
     ],
     tags = ["requires-net:external"],
@@ -151,6 +159,9 @@ py_library(
 py_binary(
     name = "text_classifier_demo",
     srcs = ["text_classifier_demo.py"],
+    data = [
+        "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny",
+    ],
     deps = [
         ":text_classifier_demo_lib",
     ],
mediapipe/model_maker/python/text/text_classifier/model_spec.py

@@ -18,12 +18,15 @@ import enum
 import functools
 
 from mediapipe.model_maker.python.core import hyperparameters as hp
+from mediapipe.model_maker.python.core.utils import file_util
 from mediapipe.model_maker.python.text.core import bert_model_spec
 from mediapipe.model_maker.python.text.text_classifier import model_options as mo
 
 # BERT-based text classifier spec inherited from BertModelSpec
 BertClassifierSpec = bert_model_spec.BertModelSpec
 
+MOBILEBERT_TINY_PATH = 'mediapipe/model_maker/models/text_classifier/mobilebert_tiny/'
+
 
 @dataclasses.dataclass
 class AverageWordEmbeddingClassifierSpec:
@@ -49,16 +52,14 @@ average_word_embedding_classifier_spec = functools.partial(
 mobilebert_classifier_spec = functools.partial(
     BertClassifierSpec,
     hparams=hp.BaseHParams(
-        epochs=3,
-        batch_size=48,
-        learning_rate=3e-5,
-        distribution_strategy='off'),
+        epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
+    ),
     name='MobileBert',
-    uri='https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1',
+    uri=file_util.get_absolute_path(MOBILEBERT_TINY_PATH),
     tflite_input_name={
         'ids': 'serving_default_input_1:0',
         'mask': 'serving_default_input_3:0',
-        'segment_ids': 'serving_default_input_2:0'
+        'segment_ids': 'serving_default_input_2:0',
     },
 )
 
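For context, the `functools.partial` above only pre-binds constructor defaults, so callers can still override any field at call time. An illustrative usage sketch (the override values are made up; the aliases match the imports in this file):

```python
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms

# Defaults baked into the partial: 3 epochs, batch size 48, lr 3e-5.
spec = ms.mobilebert_classifier_spec()
print(spec.uri)  # Now a local mobilebert_tiny path rather than a TF Hub URL.

# Any bound default can still be overridden at call time.
longer = ms.mobilebert_classifier_spec(
    hparams=hp.BaseHParams(
        epochs=10, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
    )
)
```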
mediapipe/model_maker/python/text/text_classifier/model_spec_test.py

@@ -28,9 +28,10 @@ class ModelSpecTest(tf.test.TestCase):
     model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value()
     self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec)
     self.assertEqual(model_spec_obj.name, 'MobileBert')
-    self.assertEqual(
-        model_spec_obj.uri, 'https://tfhub.dev/tensorflow/'
-        'mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1')
+    self.assertIn(
+        'mediapipe/model_maker/models/text_classifier/mobilebert_tiny',
+        model_spec_obj.uri,
+    )
     self.assertTrue(model_spec_obj.do_lower_case)
     self.assertEqual(
         model_spec_obj.tflite_input_name, {
mediapipe/model_maker/python/text/text_classifier/testdata/BUILD

@@ -19,5 +19,8 @@ package(
 
 filegroup(
     name = "testdata",
-    srcs = ["average_word_embedding_metadata.json"],
+    srcs = [
+        "average_word_embedding_metadata.json",
+        "bert_metadata.json",
+    ],
 )
mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json (new file, +84 lines, vendored)
@@ -0,0 +1,84 @@
+{
+  "name": "TextClassifier",
+  "description": "Classify the input text into a set of known categories.",
+  "subgraph_metadata": [
+    {
+      "input_tensor_metadata": [
+        {
+          "name": "ids",
+          "description": "Tokenized ids of the input text.",
+          "content": {
+            "content_properties_type": "FeatureProperties",
+            "content_properties": {
+            }
+          },
+          "stats": {
+          }
+        },
+        {
+          "name": "mask",
+          "description": "Mask with 1 for real tokens and 0 for padding tokens.",
+          "content": {
+            "content_properties_type": "FeatureProperties",
+            "content_properties": {
+            }
+          },
+          "stats": {
+          }
+        },
+        {
+          "name": "segment_ids",
+          "description": "0 for the first sequence, 1 for the second sequence if exists.",
+          "content": {
+            "content_properties_type": "FeatureProperties",
+            "content_properties": {
+            }
+          },
+          "stats": {
+          }
+        }
+      ],
+      "output_tensor_metadata": [
+        {
+          "name": "score",
+          "description": "Score of the labels respectively.",
+          "content": {
+            "content_properties_type": "FeatureProperties",
+            "content_properties": {
+            }
+          },
+          "stats": {
+            "max": [
+              1.0
+            ],
+            "min": [
+              0.0
+            ]
+          },
+          "associated_files": [
+            {
+              "name": "labels.txt",
+              "description": "Labels for categories that the model can recognize.",
+              "type": "TENSOR_AXIS_LABELS"
+            }
+          ]
+        }
+      ],
+      "input_process_units": [
+        {
+          "options_type": "BertTokenizerOptions",
+          "options": {
+            "vocab_file": [
+              {
+                "name": "vocab.txt",
+                "description": "Vocabulary file to convert natural language words to embedding vectors.",
+                "type": "VOCABULARY"
+              }
+            ]
+          }
+        }
+      ]
+    }
+  ],
+  "min_parser_version": "1.1.0"
+}
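Since this golden metadata ships as plain JSON, it can be inspected with the standard library; a minimal sketch that recovers the tensor names defined above (the local file path is an assumption):

```python
import json

with open('bert_metadata.json') as f:  # Path assumed local.
  meta = json.load(f)

subgraph = meta['subgraph_metadata'][0]
print([t['name'] for t in subgraph['input_tensor_metadata']])   # ['ids', 'mask', 'segment_ids']
print([t['name'] for t in subgraph['output_tensor_metadata']])  # ['score']
print(meta['min_parser_version'])                               # '1.1.0'
```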
mediapipe/model_maker/python/text/text_classifier/text_classifier.py

@@ -269,16 +269,21 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
     """Creates an Average Word Embedding model."""
     self._model = tf.keras.Sequential([
         tf.keras.layers.InputLayer(
-            input_shape=[self._model_options.seq_len], dtype=tf.int32),
+            input_shape=[self._model_options.seq_len],
+            dtype=tf.int32,
+            name="input_ids",
+        ),
         tf.keras.layers.Embedding(
             len(self._text_preprocessor.get_vocab()),
             self._model_options.wordvec_dim,
-            input_length=self._model_options.seq_len),
+            input_length=self._model_options.seq_len,
+        ),
         tf.keras.layers.GlobalAveragePooling1D(),
         tf.keras.layers.Dense(
-            self._model_options.wordvec_dim, activation=tf.nn.relu),
+            self._model_options.wordvec_dim, activation=tf.nn.relu
+        ),
         tf.keras.layers.Dropout(self._model_options.dropout_rate),
-        tf.keras.layers.Dense(self._num_classes, activation="softmax")
+        tf.keras.layers.Dense(self._num_classes, activation="softmax"),
     ])
 
   def _save_vocab(self, vocab_filepath: str):
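Besides reformatting, the change above gives the input layer an explicit name, `input_ids`, which makes the exported model's input tensor identifiable. Stripped of the class plumbing, the same architecture as a standalone sketch with assumed sizes:

```python
import tensorflow as tf

# Assumed sizes; the real values come from the preprocessor vocab and model options.
VOCAB_SIZE, SEQ_LEN, WORDVEC_DIM, NUM_CLASSES, DROPOUT = 10000, 256, 16, 2, 0.2

model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(
        input_shape=[SEQ_LEN], dtype=tf.int32, name='input_ids'),
    tf.keras.layers.Embedding(VOCAB_SIZE, WORDVEC_DIM, input_length=SEQ_LEN),
    tf.keras.layers.GlobalAveragePooling1D(),
    tf.keras.layers.Dense(WORDVEC_DIM, activation=tf.nn.relu),
    tf.keras.layers.Dropout(DROPOUT),
    tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'),
])
model.summary()
```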
mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py

@@ -26,6 +26,9 @@ class TextClassifierTest(tf.test.TestCase):
 
   _AVERAGE_WORD_EMBEDDING_JSON_FILE = (
       test_utils.get_test_data_path('average_word_embedding_metadata.json'))
+  _BERT_CLASSIFIER_JSON_FILE = test_utils.get_test_data_path(
+      'bert_metadata.json'
+  )
 
   def _get_data(self):
     labels_and_text = (('pos', 'super good'), (('neg', 'really bad')))
@@ -94,7 +97,27 @@ class TextClassifierTest(tf.test.TestCase):
 
     _, accuracy = bert_classifier.evaluate(validation_data)
     self.assertGreaterEqual(accuracy, 0.0)
-    # TODO: Add a unit test that does not run OOM.
+
+    # Test export_model
+    bert_classifier.export_model()
+    output_metadata_file = os.path.join(
+        options.hparams.export_dir, 'metadata.json'
+    )
+    output_tflite_file = os.path.join(
+        options.hparams.export_dir, 'model.tflite'
+    )
+
+    self.assertTrue(os.path.exists(output_tflite_file))
+    self.assertGreater(os.path.getsize(output_tflite_file), 0)
+
+    self.assertTrue(os.path.exists(output_metadata_file))
+    self.assertGreater(os.path.getsize(output_metadata_file), 0)
+    filecmp.clear_cache()
+    self.assertTrue(
+        filecmp.cmp(
+            output_metadata_file, self._BERT_CLASSIFIER_JSON_FILE, shallow=False
+        )
+    )
 
   def test_label_mismatch(self):
     options = (
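A note on the comparison idiom used above: `filecmp` memoizes results keyed on each file's `stat()` signature, so the test clears that cache before byte-comparing the freshly written metadata against the golden file. A self-contained illustration of the same pattern:

```python
import filecmp
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp:
  a = os.path.join(tmp, 'exported_metadata.json')
  b = os.path.join(tmp, 'golden_metadata.json')
  for path in (a, b):
    with open(path, 'w') as f:
      f.write('{"name": "TextClassifier"}')

  filecmp.clear_cache()  # Drop memoized results from earlier comparisons.
  assert filecmp.cmp(a, b, shallow=False)  # shallow=False compares bytes, not stat().
```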
setup.py

@@ -81,7 +81,10 @@ def _setup_build_dir():
       file.write(filedata)
 
   # Use bazel to download GCS model files
-  model_build_files = ['models/gesture_recognizer/BUILD']
+  model_build_files = [
+      'models/gesture_recognizer/BUILD',
+      'models/text_classifier/BUILD',
+  ]
   for model_build_file in model_build_files:
     build_target_file = os.path.join(BUILD_MM_DIR, model_build_file)
     os.makedirs(os.path.dirname(build_target_file), exist_ok=True)
@@ -95,7 +98,12 @@ def _setup_build_dir():
       'models/gesture_recognizer/gesture_embedder/saved_model.pb',
       'models/gesture_recognizer/gesture_embedder/variables/variables.data-00000-of-00001',
       'models/gesture_recognizer/gesture_embedder/variables/variables.index',
-  ]
+      'models/text_classifier/mobilebert_tiny/keras_metadata.pb',
+      'models/text_classifier/mobilebert_tiny/saved_model.pb',
+      'models/text_classifier/mobilebert_tiny/assets/vocab.txt',
+      'models/text_classifier/mobilebert_tiny/variables/variables.data-00000-of-00001',
+      'models/text_classifier/mobilebert_tiny/variables/variables.index',
+  ]
   for elem in external_files:
     external_file = os.path.join(f'{SRC_NAME}/mediapipe_model_maker', elem)
     sys.stderr.write('downloading file: %s\n' % external_file)
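Each entry in `external_files` corresponds to an `http_file` rule pinned in third_party/external_files.bzl (see below), i.e. a fixed GCS URL. For illustration only, one of those files could be fetched directly without Bazel; the URL is taken from the .bzl change below, and the destination path is arbitrary:

```python
import urllib.request

# URL pinned by the http_file rule added below; the generation query selects the exact blob.
URL = ('https://storage.googleapis.com/mediapipe-assets/'
       'mobilebert_tiny/saved_model.pb?generation=1673297968138825')

urllib.request.urlretrieve(URL, 'saved_model.pb')  # Destination path is arbitrary.
```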
third_party/external_files.bzl (30 lines changed, vendored)
@@ -1006,6 +1006,18 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/saved_model.pb?generation=1668550484904822"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_mobilebert_tiny_keras_metadata_pb",
+        sha256 = "cef8131a414c602b9d4742ac57f4f90bc5d8a42baec36b65deece884e2d0cf0f",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/keras_metadata.pb?generation=1673297965144159"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_mobilebert_tiny_saved_model_pb",
+        sha256 = "323c997cd3e17df1b2e3bdebe3cfe2b17c5ffd9488a26a4afb59ee819196837a",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/saved_model.pb?generation=1673297968138825"],
+    )
+
     http_file(
         name = "com_google_mediapipe_object_detection_saved_model_model_ckpt_data-00000-of-00001",
         sha256 = "ad2f733f271dd5000a8c7f926bfea1083e6408b34d4f3b60679e5a6f96251c97",
@@ -1053,3 +1065,21 @@ def external_files():
         sha256 = "76ea482b8da6bdb3d65d3b2ea989c1699c9fa0d6df0cb6d80863d1dc6fe7c4bd",
         urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668550490691823"],
     )
+
+    http_file(
+        name = "com_google_mediapipe_mobilebert_tiny_assets_vocab_txt",
+        sha256 = "07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/assets/vocab.txt?generation=1673297970948751"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_mobilebert_tiny_variables_variables_data-00000-of-00001",
+        sha256 = "c3857370046cd3a2f345657cf1bb259a4e7e09185d7f0808e57803e9d41ebba4",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/variables/variables.data-00000-of-00001?generation=1673297975132568"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_mobilebert_tiny_variables_variables_index",
+        sha256 = "4df4d7c0fefe99903ab6ebf44b7478196ce613082d2ca692a5a37a7f24e562ed",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/variables/variables.index?generation=1673297977586840"],
+    )
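Because every `http_file` pins a sha256, a downloaded asset can be verified offline; a minimal sketch using the vocab.txt digest from above (the local path is an assumption):

```python
import hashlib

# sha256 pinned for mobilebert_tiny/assets/vocab.txt in the rule above.
EXPECTED = '07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3'


def sha256_of(path):
  digest = hashlib.sha256()
  with open(path, 'rb') as f:
    for chunk in iter(lambda: f.read(1 << 20), b''):  # Read in 1 MiB chunks.
      digest.update(chunk)
  return digest.hexdigest()

assert sha256_of('vocab.txt') == EXPECTED  # Local path assumed.
```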