Add download model on demand to text classifier
PiperOrigin-RevId: 508441452
This commit is contained in:
parent
28f728bed5
commit
99fc975f49
|
@ -32,14 +32,3 @@ mediapipe_files(
|
||||||
"mobilebert_tiny/variables/variables.index",
|
"mobilebert_tiny/variables/variables.index",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
filegroup(
|
|
||||||
name = "mobilebert_tiny",
|
|
||||||
srcs = [
|
|
||||||
"mobilebert_tiny/assets/vocab.txt",
|
|
||||||
"mobilebert_tiny/keras_metadata.pb",
|
|
||||||
"mobilebert_tiny/saved_model.pb",
|
|
||||||
"mobilebert_tiny/variables/variables.data-00000-of-00001",
|
|
||||||
"mobilebert_tiny/variables/variables.index",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
|
@ -21,8 +21,6 @@ import tarfile
|
||||||
import tempfile
|
import tempfile
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
# resources dependency
|
|
||||||
|
|
||||||
|
|
||||||
_TEMPDIR_FOLDER = 'model_maker'
|
_TEMPDIR_FOLDER = 'model_maker'
|
||||||
|
|
||||||
|
@ -97,29 +95,3 @@ class DownloadedFiles:
|
||||||
with open(absolute_path, 'wb') as f:
|
with open(absolute_path, 'wb') as f:
|
||||||
f.write(r.content)
|
f.write(r.content)
|
||||||
return str(absolute_path)
|
return str(absolute_path)
|
||||||
|
|
||||||
|
|
||||||
# TODO Remove after text_classifier supports downloading on demand.
|
|
||||||
def get_absolute_path(file_path: str) -> str:
|
|
||||||
"""Gets the absolute path of a file in the model_maker directory.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: The path to a file relative to the `mediapipe` dir
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The full path of the file
|
|
||||||
"""
|
|
||||||
# Extract the file path before and including 'model_maker' as the
|
|
||||||
# `mm_base_dir`. By joining it with the `path` after 'model_maker/', it
|
|
||||||
# yields the absolute path of the model files directory. We must join
|
|
||||||
# on 'model_maker' because in the pypi package, the 'model_maker' directory
|
|
||||||
# is renamed to 'mediapipe_model_maker'. So we have to join on model_maker
|
|
||||||
# to ensure that the `mm_base_dir` path includes the renamed
|
|
||||||
# 'mediapipe_model_maker' directory.
|
|
||||||
cwd = os.path.dirname(__file__)
|
|
||||||
cwd_stop_idx = cwd.rfind('model_maker') + len('model_maker')
|
|
||||||
mm_base_dir = cwd[:cwd_stop_idx]
|
|
||||||
file_path_start_idx = file_path.find('model_maker') + len('model_maker') + 1
|
|
||||||
mm_relative_path = file_path[file_path_start_idx:]
|
|
||||||
absolute_path = os.path.join(mm_base_dir, mm_relative_path)
|
|
||||||
return absolute_path
|
|
||||||
|
|
|
@ -74,11 +74,6 @@ class FileUtilTest(absltest.TestCase):
|
||||||
self.assertEqual(model_path, model_path_2)
|
self.assertEqual(model_path, model_path_2)
|
||||||
self.assertEqual(mock_get.call_count, 1)
|
self.assertEqual(mock_get.call_count, 1)
|
||||||
|
|
||||||
def test_get_absolute_path(self):
|
|
||||||
test_file = 'mediapipe/model_maker/python/core/utils/testdata/test.txt'
|
|
||||||
absolute_path = file_util.get_absolute_path(test_file)
|
|
||||||
self.assertTrue(os.path.exists(absolute_path))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
|
@ -31,5 +31,6 @@ py_library(
|
||||||
deps = [
|
deps = [
|
||||||
":bert_model_options",
|
":bert_model_options",
|
||||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:file_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -17,6 +17,7 @@ import dataclasses
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.core.utils import file_util
|
||||||
from mediapipe.model_maker.python.text.core import bert_model_options
|
from mediapipe.model_maker.python.text.core import bert_model_options
|
||||||
|
|
||||||
_DEFAULT_TFLITE_INPUT_NAME = {
|
_DEFAULT_TFLITE_INPUT_NAME = {
|
||||||
|
@ -34,16 +35,17 @@ class BertModelSpec:
|
||||||
Transformers for Language Understanding) for more details.
|
Transformers for Language Understanding) for more details.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
|
downloaded_files: A DownloadedFiles object of the model files
|
||||||
hparams: Hyperparameters used for training.
|
hparams: Hyperparameters used for training.
|
||||||
model_options: Configurable options for a BERT model.
|
model_options: Configurable options for a BERT model.
|
||||||
do_lower_case: boolean, whether to lower case the input text. Should be
|
do_lower_case: boolean, whether to lower case the input text. Should be
|
||||||
True / False for uncased / cased models respectively, where the models
|
True / False for uncased / cased models respectively, where the models
|
||||||
are specified by the `uri`.
|
are specified by `downloaded_files`.
|
||||||
tflite_input_name: Dict, input names for the TFLite model.
|
tflite_input_name: Dict, input names for the TFLite model.
|
||||||
uri: URI for the BERT module.
|
|
||||||
name: The name of the object.
|
name: The name of the object.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
downloaded_files: file_util.DownloadedFiles
|
||||||
hparams: hp.BaseHParams = hp.BaseHParams(
|
hparams: hp.BaseHParams = hp.BaseHParams(
|
||||||
epochs=3,
|
epochs=3,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
|
@ -54,5 +56,4 @@ class BertModelSpec:
|
||||||
do_lower_case: bool = True
|
do_lower_case: bool = True
|
||||||
tflite_input_name: Dict[str, str] = dataclasses.field(
|
tflite_input_name: Dict[str, str] = dataclasses.field(
|
||||||
default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME)
|
default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME)
|
||||||
uri: str = 'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1'
|
|
||||||
name: str = 'Bert'
|
name: str = 'Bert'
|
||||||
|
|
|
@ -61,6 +61,7 @@ py_library(
|
||||||
py_test(
|
py_test(
|
||||||
name = "model_spec_test",
|
name = "model_spec_test",
|
||||||
srcs = ["model_spec_test.py"],
|
srcs = ["model_spec_test.py"],
|
||||||
|
tags = ["requires-net:external"],
|
||||||
deps = [
|
deps = [
|
||||||
":model_options",
|
":model_options",
|
||||||
":model_spec",
|
":model_spec",
|
||||||
|
@ -89,9 +90,6 @@ py_library(
|
||||||
py_test(
|
py_test(
|
||||||
name = "preprocessor_test",
|
name = "preprocessor_test",
|
||||||
srcs = ["preprocessor_test.py"],
|
srcs = ["preprocessor_test.py"],
|
||||||
data = [
|
|
||||||
"//mediapipe/model_maker/models/text_classifier:mobilebert_tiny",
|
|
||||||
],
|
|
||||||
tags = ["requires-net:external"],
|
tags = ["requires-net:external"],
|
||||||
deps = [
|
deps = [
|
||||||
":dataset",
|
":dataset",
|
||||||
|
@ -113,9 +111,6 @@ py_library(
|
||||||
py_library(
|
py_library(
|
||||||
name = "text_classifier",
|
name = "text_classifier",
|
||||||
srcs = ["text_classifier.py"],
|
srcs = ["text_classifier.py"],
|
||||||
data = [
|
|
||||||
"//mediapipe/model_maker/models/text_classifier:mobilebert_tiny",
|
|
||||||
],
|
|
||||||
deps = [
|
deps = [
|
||||||
":dataset",
|
":dataset",
|
||||||
":model_options",
|
":model_options",
|
||||||
|
@ -137,7 +132,6 @@ py_test(
|
||||||
size = "large",
|
size = "large",
|
||||||
srcs = ["text_classifier_test.py"],
|
srcs = ["text_classifier_test.py"],
|
||||||
data = [
|
data = [
|
||||||
"//mediapipe/model_maker/models/text_classifier:mobilebert_tiny",
|
|
||||||
"//mediapipe/model_maker/python/text/text_classifier/testdata",
|
"//mediapipe/model_maker/python/text/text_classifier/testdata",
|
||||||
],
|
],
|
||||||
tags = [
|
tags = [
|
||||||
|
@ -163,9 +157,6 @@ py_library(
|
||||||
py_binary(
|
py_binary(
|
||||||
name = "text_classifier_demo",
|
name = "text_classifier_demo",
|
||||||
srcs = ["text_classifier_demo.py"],
|
srcs = ["text_classifier_demo.py"],
|
||||||
data = [
|
|
||||||
"//mediapipe/model_maker/models/text_classifier:mobilebert_tiny",
|
|
||||||
],
|
|
||||||
deps = [
|
deps = [
|
||||||
":text_classifier_demo_lib",
|
":text_classifier_demo_lib",
|
||||||
],
|
],
|
||||||
|
|
|
@ -25,7 +25,11 @@ from mediapipe.model_maker.python.text.text_classifier import model_options as m
|
||||||
# BERT-based text classifier spec inherited from BertModelSpec
|
# BERT-based text classifier spec inherited from BertModelSpec
|
||||||
BertClassifierSpec = bert_model_spec.BertModelSpec
|
BertClassifierSpec = bert_model_spec.BertModelSpec
|
||||||
|
|
||||||
MOBILEBERT_TINY_PATH = 'mediapipe/model_maker/models/text_classifier/mobilebert_tiny/'
|
MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
|
||||||
|
'text_classifier/mobilebert_tiny',
|
||||||
|
'https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny.tar.gz',
|
||||||
|
is_folder=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
@ -51,11 +55,11 @@ average_word_embedding_classifier_spec = functools.partial(
|
||||||
|
|
||||||
mobilebert_classifier_spec = functools.partial(
|
mobilebert_classifier_spec = functools.partial(
|
||||||
BertClassifierSpec,
|
BertClassifierSpec,
|
||||||
|
downloaded_files=MOBILEBERT_TINY_FILES,
|
||||||
hparams=hp.BaseHParams(
|
hparams=hp.BaseHParams(
|
||||||
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
|
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
|
||||||
),
|
),
|
||||||
name='MobileBert',
|
name='MobileBert',
|
||||||
uri=file_util.get_absolute_path(MOBILEBERT_TINY_PATH),
|
|
||||||
tflite_input_name={
|
tflite_input_name={
|
||||||
'ids': 'serving_default_input_1:0',
|
'ids': 'serving_default_input_1:0',
|
||||||
'mask': 'serving_default_input_3:0',
|
'mask': 'serving_default_input_3:0',
|
||||||
|
|
|
@ -14,6 +14,8 @@
|
||||||
"""Tests for model_spec."""
|
"""Tests for model_spec."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
|
from unittest import mock as unittest_mock
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
@ -24,14 +26,24 @@ from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||||
|
|
||||||
class ModelSpecTest(tf.test.TestCase):
|
class ModelSpecTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||||
|
# condition when downloading model since these tests may run in parallel.
|
||||||
|
mock_gettempdir = unittest_mock.patch.object(
|
||||||
|
tempfile,
|
||||||
|
'gettempdir',
|
||||||
|
return_value=self.create_tempdir(),
|
||||||
|
autospec=True,
|
||||||
|
)
|
||||||
|
self.mock_gettempdir = mock_gettempdir.start()
|
||||||
|
self.addCleanup(mock_gettempdir.stop)
|
||||||
|
|
||||||
def test_predefined_bert_spec(self):
|
def test_predefined_bert_spec(self):
|
||||||
model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||||
self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec)
|
self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec)
|
||||||
self.assertEqual(model_spec_obj.name, 'MobileBert')
|
self.assertEqual(model_spec_obj.name, 'MobileBert')
|
||||||
self.assertIn(
|
self.assertTrue(os.path.exists(model_spec_obj.downloaded_files.get_path()))
|
||||||
'mediapipe/model_maker/models/text_classifier/mobilebert_tiny',
|
|
||||||
model_spec_obj.uri,
|
|
||||||
)
|
|
||||||
self.assertTrue(model_spec_obj.do_lower_case)
|
self.assertTrue(model_spec_obj.do_lower_case)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
model_spec_obj.tflite_input_name, {
|
model_spec_obj.tflite_input_name, {
|
||||||
|
|
|
@ -14,6 +14,8 @@
|
||||||
|
|
||||||
import csv
|
import csv
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
|
from unittest import mock as unittest_mock
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.testing as npt
|
import numpy.testing as npt
|
||||||
|
@ -28,6 +30,19 @@ class PreprocessorTest(tf.test.TestCase):
|
||||||
CSV_PARAMS_ = text_classifier_ds.CSVParameters(
|
CSV_PARAMS_ = text_classifier_ds.CSVParameters(
|
||||||
text_column='text', label_column='label')
|
text_column='text', label_column='label')
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||||
|
# condition when downloading model since these tests may run in parallel.
|
||||||
|
mock_gettempdir = unittest_mock.patch.object(
|
||||||
|
tempfile,
|
||||||
|
'gettempdir',
|
||||||
|
return_value=self.create_tempdir(),
|
||||||
|
autospec=True,
|
||||||
|
)
|
||||||
|
self.mock_gettempdir = mock_gettempdir.start()
|
||||||
|
self.addCleanup(mock_gettempdir.stop)
|
||||||
|
|
||||||
def _get_csv_file(self):
|
def _get_csv_file(self):
|
||||||
labels_and_text = (('pos', 'super super super super good'),
|
labels_and_text = (('pos', 'super super super super good'),
|
||||||
(('neg', 'really bad')))
|
(('neg', 'really bad')))
|
||||||
|
@ -71,7 +86,10 @@ class PreprocessorTest(tf.test.TestCase):
|
||||||
filename=csv_file, csv_params=self.CSV_PARAMS_)
|
filename=csv_file, csv_params=self.CSV_PARAMS_)
|
||||||
bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||||
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
|
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
|
||||||
seq_len=5, do_lower_case=bert_spec.do_lower_case, uri=bert_spec.uri)
|
seq_len=5,
|
||||||
|
do_lower_case=bert_spec.do_lower_case,
|
||||||
|
uri=bert_spec.downloaded_files.get_path(),
|
||||||
|
)
|
||||||
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
|
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
|
||||||
labels = []
|
labels = []
|
||||||
input_masks = []
|
input_masks = []
|
||||||
|
|
|
@ -369,7 +369,8 @@ class _BertClassifier(TextClassifier):
|
||||||
self._text_preprocessor = preprocessor.BertClassifierPreprocessor(
|
self._text_preprocessor = preprocessor.BertClassifierPreprocessor(
|
||||||
seq_len=self._model_options.seq_len,
|
seq_len=self._model_options.seq_len,
|
||||||
do_lower_case=self._model_spec.do_lower_case,
|
do_lower_case=self._model_spec.do_lower_case,
|
||||||
uri=self._model_spec.uri)
|
uri=self._model_spec.downloaded_files.get_path(),
|
||||||
|
)
|
||||||
return (self._text_preprocessor.preprocess(train_data),
|
return (self._text_preprocessor.preprocess(train_data),
|
||||||
self._text_preprocessor.preprocess(validation_data))
|
self._text_preprocessor.preprocess(validation_data))
|
||||||
|
|
||||||
|
@ -388,7 +389,9 @@ class _BertClassifier(TextClassifier):
|
||||||
shape=(self._model_options.seq_len,), dtype=tf.int32),
|
shape=(self._model_options.seq_len,), dtype=tf.int32),
|
||||||
)
|
)
|
||||||
encoder = hub.KerasLayer(
|
encoder = hub.KerasLayer(
|
||||||
self._model_spec.uri, trainable=self._model_options.do_fine_tuning)
|
self._model_spec.downloaded_files.get_path(),
|
||||||
|
trainable=self._model_options.do_fine_tuning,
|
||||||
|
)
|
||||||
encoder_outputs = encoder(encoder_inputs)
|
encoder_outputs = encoder(encoder_inputs)
|
||||||
pooled_output = encoder_outputs["pooled_output"]
|
pooled_output = encoder_outputs["pooled_output"]
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,8 @@
|
||||||
import csv
|
import csv
|
||||||
import filecmp
|
import filecmp
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
|
from unittest import mock as unittest_mock
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
@ -30,6 +32,19 @@ class TextClassifierTest(tf.test.TestCase):
|
||||||
'bert_metadata.json'
|
'bert_metadata.json'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||||
|
# condition when downloading model since these tests may run in parallel.
|
||||||
|
mock_gettempdir = unittest_mock.patch.object(
|
||||||
|
tempfile,
|
||||||
|
'gettempdir',
|
||||||
|
return_value=self.create_tempdir(),
|
||||||
|
autospec=True,
|
||||||
|
)
|
||||||
|
self.mock_gettempdir = mock_gettempdir.start()
|
||||||
|
self.addCleanup(mock_gettempdir.stop)
|
||||||
|
|
||||||
def _get_data(self):
|
def _get_data(self):
|
||||||
labels_and_text = (('pos', 'super good'), (('neg', 'really bad')))
|
labels_and_text = (('pos', 'super good'), (('neg', 'really bad')))
|
||||||
csv_file = os.path.join(self.get_temp_dir(), 'data.csv')
|
csv_file = os.path.join(self.get_temp_dir(), 'data.csv')
|
||||||
|
|
|
@ -18,8 +18,6 @@ Setup for Mediapipe-Model-Maker package with setuptools.
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
import setuptools
|
import setuptools
|
||||||
|
|
||||||
|
|
||||||
|
@ -40,16 +38,6 @@ def _parse_requirements(path):
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _copy_to_pip_src_dir(file):
|
|
||||||
"""Copy a file from bazel-bin to the pip_src dir."""
|
|
||||||
dst = file
|
|
||||||
dst_dir = os.path.dirname(dst)
|
|
||||||
if not os.path.exists(dst_dir):
|
|
||||||
os.makedirs(dst_dir)
|
|
||||||
src_file = os.path.join('../../bazel-bin/mediapipe/model_maker', file)
|
|
||||||
shutil.copyfile(src_file, file)
|
|
||||||
|
|
||||||
|
|
||||||
def _setup_build_dir():
|
def _setup_build_dir():
|
||||||
"""Setup the BUILD_DIR directory to build the mediapipe_model_maker package.
|
"""Setup the BUILD_DIR directory to build the mediapipe_model_maker package.
|
||||||
|
|
||||||
|
@ -80,33 +68,6 @@ def _setup_build_dir():
|
||||||
with open(build_target_file, 'w') as file:
|
with open(build_target_file, 'w') as file:
|
||||||
file.write(filedata)
|
file.write(filedata)
|
||||||
|
|
||||||
# Use bazel to download GCS model files
|
|
||||||
model_build_files = [
|
|
||||||
'models/text_classifier/BUILD',
|
|
||||||
]
|
|
||||||
for model_build_file in model_build_files:
|
|
||||||
build_target_file = os.path.join(BUILD_MM_DIR, model_build_file)
|
|
||||||
os.makedirs(os.path.dirname(build_target_file), exist_ok=True)
|
|
||||||
shutil.copy(model_build_file, build_target_file)
|
|
||||||
external_files = [
|
|
||||||
'models/text_classifier/mobilebert_tiny/keras_metadata.pb',
|
|
||||||
'models/text_classifier/mobilebert_tiny/saved_model.pb',
|
|
||||||
'models/text_classifier/mobilebert_tiny/assets/vocab.txt',
|
|
||||||
'models/text_classifier/mobilebert_tiny/variables/variables.data-00000-of-00001',
|
|
||||||
'models/text_classifier/mobilebert_tiny/variables/variables.index',
|
|
||||||
]
|
|
||||||
for elem in external_files:
|
|
||||||
external_file = os.path.join(f'{SRC_NAME}/mediapipe_model_maker', elem)
|
|
||||||
sys.stderr.write('downloading file: %s\n' % external_file)
|
|
||||||
fetch_model_command = [
|
|
||||||
'bazel',
|
|
||||||
'build',
|
|
||||||
external_file,
|
|
||||||
]
|
|
||||||
if subprocess.call(fetch_model_command) != 0:
|
|
||||||
sys.exit(-1)
|
|
||||||
_copy_to_pip_src_dir(external_file)
|
|
||||||
|
|
||||||
_setup_build_dir()
|
_setup_build_dir()
|
||||||
|
|
||||||
setuptools.setup(
|
setuptools.setup(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user