Adds smaller MobileBERT model.
PiperOrigin-RevId: 501451414
This commit is contained in:
parent
8830eefa0b
commit
9cbb76939d
45
mediapipe/model_maker/models/text_classifier/BUILD
Normal file
45
mediapipe/model_maker/models/text_classifier/BUILD
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
load(
|
||||||
|
"//mediapipe/framework/tool:mediapipe_files.bzl",
|
||||||
|
"mediapipe_files",
|
||||||
|
)
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//mediapipe/model_maker/python/text/text_classifier:__subpackages__"],
|
||||||
|
)
|
||||||
|
|
||||||
|
mediapipe_files(
|
||||||
|
srcs = [
|
||||||
|
"mobilebert_tiny/assets/vocab.txt",
|
||||||
|
"mobilebert_tiny/keras_metadata.pb",
|
||||||
|
"mobilebert_tiny/saved_model.pb",
|
||||||
|
"mobilebert_tiny/variables/variables.data-00000-of-00001",
|
||||||
|
"mobilebert_tiny/variables/variables.index",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "mobilebert_tiny",
|
||||||
|
srcs = [
|
||||||
|
"mobilebert_tiny/assets/vocab.txt",
|
||||||
|
"mobilebert_tiny/keras_metadata.pb",
|
||||||
|
"mobilebert_tiny/saved_model.pb",
|
||||||
|
"mobilebert_tiny/variables/variables.data-00000-of-00001",
|
||||||
|
"mobilebert_tiny/variables/variables.index",
|
||||||
|
],
|
||||||
|
)
|
|
@ -53,6 +53,7 @@ py_library(
|
||||||
deps = [
|
deps = [
|
||||||
":model_options",
|
":model_options",
|
||||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:file_util",
|
||||||
"//mediapipe/model_maker/python/text/core:bert_model_spec",
|
"//mediapipe/model_maker/python/text/core:bert_model_spec",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -88,6 +89,9 @@ py_library(
|
||||||
py_test(
|
py_test(
|
||||||
name = "preprocessor_test",
|
name = "preprocessor_test",
|
||||||
srcs = ["preprocessor_test.py"],
|
srcs = ["preprocessor_test.py"],
|
||||||
|
data = [
|
||||||
|
"//mediapipe/model_maker/models/text_classifier:mobilebert_tiny",
|
||||||
|
],
|
||||||
tags = ["requires-net:external"],
|
tags = ["requires-net:external"],
|
||||||
deps = [
|
deps = [
|
||||||
":dataset",
|
":dataset",
|
||||||
|
@ -109,6 +113,9 @@ py_library(
|
||||||
py_library(
|
py_library(
|
||||||
name = "text_classifier",
|
name = "text_classifier",
|
||||||
srcs = ["text_classifier.py"],
|
srcs = ["text_classifier.py"],
|
||||||
|
data = [
|
||||||
|
"//mediapipe/model_maker/models/text_classifier:mobilebert_tiny",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":dataset",
|
":dataset",
|
||||||
":model_options",
|
":model_options",
|
||||||
|
@ -130,6 +137,7 @@ py_test(
|
||||||
size = "large",
|
size = "large",
|
||||||
srcs = ["text_classifier_test.py"],
|
srcs = ["text_classifier_test.py"],
|
||||||
data = [
|
data = [
|
||||||
|
"//mediapipe/model_maker/models/text_classifier:mobilebert_tiny",
|
||||||
"//mediapipe/model_maker/python/text/text_classifier/testdata",
|
"//mediapipe/model_maker/python/text/text_classifier/testdata",
|
||||||
],
|
],
|
||||||
tags = ["requires-net:external"],
|
tags = ["requires-net:external"],
|
||||||
|
@ -151,6 +159,9 @@ py_library(
|
||||||
py_binary(
|
py_binary(
|
||||||
name = "text_classifier_demo",
|
name = "text_classifier_demo",
|
||||||
srcs = ["text_classifier_demo.py"],
|
srcs = ["text_classifier_demo.py"],
|
||||||
|
data = [
|
||||||
|
"//mediapipe/model_maker/models/text_classifier:mobilebert_tiny",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":text_classifier_demo_lib",
|
":text_classifier_demo_lib",
|
||||||
],
|
],
|
||||||
|
|
|
@ -18,12 +18,15 @@ import enum
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.core.utils import file_util
|
||||||
from mediapipe.model_maker.python.text.core import bert_model_spec
|
from mediapipe.model_maker.python.text.core import bert_model_spec
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||||
|
|
||||||
# BERT-based text classifier spec inherited from BertModelSpec
|
# BERT-based text classifier spec inherited from BertModelSpec
|
||||||
BertClassifierSpec = bert_model_spec.BertModelSpec
|
BertClassifierSpec = bert_model_spec.BertModelSpec
|
||||||
|
|
||||||
|
MOBILEBERT_TINY_PATH = 'mediapipe/model_maker/models/text_classifier/mobilebert_tiny/'
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class AverageWordEmbeddingClassifierSpec:
|
class AverageWordEmbeddingClassifierSpec:
|
||||||
|
@ -49,16 +52,14 @@ average_word_embedding_classifier_spec = functools.partial(
|
||||||
mobilebert_classifier_spec = functools.partial(
|
mobilebert_classifier_spec = functools.partial(
|
||||||
BertClassifierSpec,
|
BertClassifierSpec,
|
||||||
hparams=hp.BaseHParams(
|
hparams=hp.BaseHParams(
|
||||||
epochs=3,
|
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
|
||||||
batch_size=48,
|
),
|
||||||
learning_rate=3e-5,
|
|
||||||
distribution_strategy='off'),
|
|
||||||
name='MobileBert',
|
name='MobileBert',
|
||||||
uri='https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1',
|
uri=file_util.get_absolute_path(MOBILEBERT_TINY_PATH),
|
||||||
tflite_input_name={
|
tflite_input_name={
|
||||||
'ids': 'serving_default_input_1:0',
|
'ids': 'serving_default_input_1:0',
|
||||||
'mask': 'serving_default_input_3:0',
|
'mask': 'serving_default_input_3:0',
|
||||||
'segment_ids': 'serving_default_input_2:0'
|
'segment_ids': 'serving_default_input_2:0',
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -28,9 +28,10 @@ class ModelSpecTest(tf.test.TestCase):
|
||||||
model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||||
self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec)
|
self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec)
|
||||||
self.assertEqual(model_spec_obj.name, 'MobileBert')
|
self.assertEqual(model_spec_obj.name, 'MobileBert')
|
||||||
self.assertEqual(
|
self.assertIn(
|
||||||
model_spec_obj.uri, 'https://tfhub.dev/tensorflow/'
|
'mediapipe/model_maker/models/text_classifier/mobilebert_tiny',
|
||||||
'mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1')
|
model_spec_obj.uri,
|
||||||
|
)
|
||||||
self.assertTrue(model_spec_obj.do_lower_case)
|
self.assertTrue(model_spec_obj.do_lower_case)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
model_spec_obj.tflite_input_name, {
|
model_spec_obj.tflite_input_name, {
|
||||||
|
|
|
@ -19,5 +19,8 @@ package(
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "testdata",
|
name = "testdata",
|
||||||
srcs = ["average_word_embedding_metadata.json"],
|
srcs = [
|
||||||
|
"average_word_embedding_metadata.json",
|
||||||
|
"bert_metadata.json",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
84
mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json
vendored
Normal file
84
mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json
vendored
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
{
|
||||||
|
"name": "TextClassifier",
|
||||||
|
"description": "Classify the input text into a set of known categories.",
|
||||||
|
"subgraph_metadata": [
|
||||||
|
{
|
||||||
|
"input_tensor_metadata": [
|
||||||
|
{
|
||||||
|
"name": "ids",
|
||||||
|
"description": "Tokenized ids of the input text.",
|
||||||
|
"content": {
|
||||||
|
"content_properties_type": "FeatureProperties",
|
||||||
|
"content_properties": {
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"stats": {
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "mask",
|
||||||
|
"description": "Mask with 1 for real tokens and 0 for padding tokens.",
|
||||||
|
"content": {
|
||||||
|
"content_properties_type": "FeatureProperties",
|
||||||
|
"content_properties": {
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"stats": {
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "segment_ids",
|
||||||
|
"description": "0 for the first sequence, 1 for the second sequence if exists.",
|
||||||
|
"content": {
|
||||||
|
"content_properties_type": "FeatureProperties",
|
||||||
|
"content_properties": {
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"stats": {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"output_tensor_metadata": [
|
||||||
|
{
|
||||||
|
"name": "score",
|
||||||
|
"description": "Score of the labels respectively.",
|
||||||
|
"content": {
|
||||||
|
"content_properties_type": "FeatureProperties",
|
||||||
|
"content_properties": {
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"stats": {
|
||||||
|
"max": [
|
||||||
|
1.0
|
||||||
|
],
|
||||||
|
"min": [
|
||||||
|
0.0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"associated_files": [
|
||||||
|
{
|
||||||
|
"name": "labels.txt",
|
||||||
|
"description": "Labels for categories that the model can recognize.",
|
||||||
|
"type": "TENSOR_AXIS_LABELS"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"input_process_units": [
|
||||||
|
{
|
||||||
|
"options_type": "BertTokenizerOptions",
|
||||||
|
"options": {
|
||||||
|
"vocab_file": [
|
||||||
|
{
|
||||||
|
"name": "vocab.txt",
|
||||||
|
"description": "Vocabulary file to convert natural language words to embedding vectors.",
|
||||||
|
"type": "VOCABULARY"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"min_parser_version": "1.1.0"
|
||||||
|
}
|
|
@ -269,16 +269,21 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
|
||||||
"""Creates an Average Word Embedding model."""
|
"""Creates an Average Word Embedding model."""
|
||||||
self._model = tf.keras.Sequential([
|
self._model = tf.keras.Sequential([
|
||||||
tf.keras.layers.InputLayer(
|
tf.keras.layers.InputLayer(
|
||||||
input_shape=[self._model_options.seq_len], dtype=tf.int32),
|
input_shape=[self._model_options.seq_len],
|
||||||
|
dtype=tf.int32,
|
||||||
|
name="input_ids",
|
||||||
|
),
|
||||||
tf.keras.layers.Embedding(
|
tf.keras.layers.Embedding(
|
||||||
len(self._text_preprocessor.get_vocab()),
|
len(self._text_preprocessor.get_vocab()),
|
||||||
self._model_options.wordvec_dim,
|
self._model_options.wordvec_dim,
|
||||||
input_length=self._model_options.seq_len),
|
input_length=self._model_options.seq_len,
|
||||||
|
),
|
||||||
tf.keras.layers.GlobalAveragePooling1D(),
|
tf.keras.layers.GlobalAveragePooling1D(),
|
||||||
tf.keras.layers.Dense(
|
tf.keras.layers.Dense(
|
||||||
self._model_options.wordvec_dim, activation=tf.nn.relu),
|
self._model_options.wordvec_dim, activation=tf.nn.relu
|
||||||
|
),
|
||||||
tf.keras.layers.Dropout(self._model_options.dropout_rate),
|
tf.keras.layers.Dropout(self._model_options.dropout_rate),
|
||||||
tf.keras.layers.Dense(self._num_classes, activation="softmax")
|
tf.keras.layers.Dense(self._num_classes, activation="softmax"),
|
||||||
])
|
])
|
||||||
|
|
||||||
def _save_vocab(self, vocab_filepath: str):
|
def _save_vocab(self, vocab_filepath: str):
|
||||||
|
|
|
@ -26,6 +26,9 @@ class TextClassifierTest(tf.test.TestCase):
|
||||||
|
|
||||||
_AVERAGE_WORD_EMBEDDING_JSON_FILE = (
|
_AVERAGE_WORD_EMBEDDING_JSON_FILE = (
|
||||||
test_utils.get_test_data_path('average_word_embedding_metadata.json'))
|
test_utils.get_test_data_path('average_word_embedding_metadata.json'))
|
||||||
|
_BERT_CLASSIFIER_JSON_FILE = test_utils.get_test_data_path(
|
||||||
|
'bert_metadata.json'
|
||||||
|
)
|
||||||
|
|
||||||
def _get_data(self):
|
def _get_data(self):
|
||||||
labels_and_text = (('pos', 'super good'), (('neg', 'really bad')))
|
labels_and_text = (('pos', 'super good'), (('neg', 'really bad')))
|
||||||
|
@ -94,7 +97,27 @@ class TextClassifierTest(tf.test.TestCase):
|
||||||
|
|
||||||
_, accuracy = bert_classifier.evaluate(validation_data)
|
_, accuracy = bert_classifier.evaluate(validation_data)
|
||||||
self.assertGreaterEqual(accuracy, 0.0)
|
self.assertGreaterEqual(accuracy, 0.0)
|
||||||
# TODO: Add a unit test that does not run OOM.
|
|
||||||
|
# Test export_model
|
||||||
|
bert_classifier.export_model()
|
||||||
|
output_metadata_file = os.path.join(
|
||||||
|
options.hparams.export_dir, 'metadata.json'
|
||||||
|
)
|
||||||
|
output_tflite_file = os.path.join(
|
||||||
|
options.hparams.export_dir, 'model.tflite'
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(os.path.exists(output_tflite_file))
|
||||||
|
self.assertGreater(os.path.getsize(output_tflite_file), 0)
|
||||||
|
|
||||||
|
self.assertTrue(os.path.exists(output_metadata_file))
|
||||||
|
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
||||||
|
filecmp.clear_cache()
|
||||||
|
self.assertTrue(
|
||||||
|
filecmp.cmp(
|
||||||
|
output_metadata_file, self._BERT_CLASSIFIER_JSON_FILE, shallow=False
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def test_label_mismatch(self):
|
def test_label_mismatch(self):
|
||||||
options = (
|
options = (
|
||||||
|
|
|
@ -81,7 +81,10 @@ def _setup_build_dir():
|
||||||
file.write(filedata)
|
file.write(filedata)
|
||||||
|
|
||||||
# Use bazel to download GCS model files
|
# Use bazel to download GCS model files
|
||||||
model_build_files = ['models/gesture_recognizer/BUILD']
|
model_build_files = [
|
||||||
|
'models/gesture_recognizer/BUILD',
|
||||||
|
'models/text_classifier/BUILD',
|
||||||
|
]
|
||||||
for model_build_file in model_build_files:
|
for model_build_file in model_build_files:
|
||||||
build_target_file = os.path.join(BUILD_MM_DIR, model_build_file)
|
build_target_file = os.path.join(BUILD_MM_DIR, model_build_file)
|
||||||
os.makedirs(os.path.dirname(build_target_file), exist_ok=True)
|
os.makedirs(os.path.dirname(build_target_file), exist_ok=True)
|
||||||
|
@ -95,7 +98,12 @@ def _setup_build_dir():
|
||||||
'models/gesture_recognizer/gesture_embedder/saved_model.pb',
|
'models/gesture_recognizer/gesture_embedder/saved_model.pb',
|
||||||
'models/gesture_recognizer/gesture_embedder/variables/variables.data-00000-of-00001',
|
'models/gesture_recognizer/gesture_embedder/variables/variables.data-00000-of-00001',
|
||||||
'models/gesture_recognizer/gesture_embedder/variables/variables.index',
|
'models/gesture_recognizer/gesture_embedder/variables/variables.index',
|
||||||
]
|
'models/text_classifier/mobilebert_tiny/keras_metadata.pb',
|
||||||
|
'models/text_classifier/mobilebert_tiny/saved_model.pb',
|
||||||
|
'models/text_classifier/mobilebert_tiny/assets/vocab.txt',
|
||||||
|
'models/text_classifier/mobilebert_tiny/variables/variables.data-00000-of-00001',
|
||||||
|
'models/text_classifier/mobilebert_tiny/variables/variables.index',
|
||||||
|
]
|
||||||
for elem in external_files:
|
for elem in external_files:
|
||||||
external_file = os.path.join(f'{SRC_NAME}/mediapipe_model_maker', elem)
|
external_file = os.path.join(f'{SRC_NAME}/mediapipe_model_maker', elem)
|
||||||
sys.stderr.write('downloading file: %s\n' % external_file)
|
sys.stderr.write('downloading file: %s\n' % external_file)
|
||||||
|
|
30
third_party/external_files.bzl
vendored
30
third_party/external_files.bzl
vendored
|
@ -1006,6 +1006,18 @@ def external_files():
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/saved_model.pb?generation=1668550484904822"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/saved_model.pb?generation=1668550484904822"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
http_file(
|
||||||
|
name = "com_google_mediapipe_mobilebert_tiny_keras_metadata_pb",
|
||||||
|
sha256 = "cef8131a414c602b9d4742ac57f4f90bc5d8a42baec36b65deece884e2d0cf0f",
|
||||||
|
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/keras_metadata.pb?generation=1673297965144159"],
|
||||||
|
)
|
||||||
|
|
||||||
|
http_file(
|
||||||
|
name = "com_google_mediapipe_mobilebert_tiny_saved_model_pb",
|
||||||
|
sha256 = "323c997cd3e17df1b2e3bdebe3cfe2b17c5ffd9488a26a4afb59ee819196837a",
|
||||||
|
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/saved_model.pb?generation=1673297968138825"],
|
||||||
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_object_detection_saved_model_model_ckpt_data-00000-of-00001",
|
name = "com_google_mediapipe_object_detection_saved_model_model_ckpt_data-00000-of-00001",
|
||||||
sha256 = "ad2f733f271dd5000a8c7f926bfea1083e6408b34d4f3b60679e5a6f96251c97",
|
sha256 = "ad2f733f271dd5000a8c7f926bfea1083e6408b34d4f3b60679e5a6f96251c97",
|
||||||
|
@ -1053,3 +1065,21 @@ def external_files():
|
||||||
sha256 = "76ea482b8da6bdb3d65d3b2ea989c1699c9fa0d6df0cb6d80863d1dc6fe7c4bd",
|
sha256 = "76ea482b8da6bdb3d65d3b2ea989c1699c9fa0d6df0cb6d80863d1dc6fe7c4bd",
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668550490691823"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668550490691823"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
http_file(
|
||||||
|
name = "com_google_mediapipe_mobilebert_tiny_assets_vocab_txt",
|
||||||
|
sha256 = "07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3",
|
||||||
|
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/assets/vocab.txt?generation=1673297970948751"],
|
||||||
|
)
|
||||||
|
|
||||||
|
http_file(
|
||||||
|
name = "com_google_mediapipe_mobilebert_tiny_variables_variables_data-00000-of-00001",
|
||||||
|
sha256 = "c3857370046cd3a2f345657cf1bb259a4e7e09185d7f0808e57803e9d41ebba4",
|
||||||
|
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/variables/variables.data-00000-of-00001?generation=1673297975132568"],
|
||||||
|
)
|
||||||
|
|
||||||
|
http_file(
|
||||||
|
name = "com_google_mediapipe_mobilebert_tiny_variables_variables_index",
|
||||||
|
sha256 = "4df4d7c0fefe99903ab6ebf44b7478196ce613082d2ca692a5a37a7f24e562ed",
|
||||||
|
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny/variables/variables.index?generation=1673297977586840"],
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user