Add download model on demand to text classifier

PiperOrigin-RevId: 508441452
MediaPipe Team 2023-02-09 12:13:05 -08:00 committed by Copybara-Service
parent 28f728bed5
commit 99fc975f49
12 changed files with 67 additions and 105 deletions
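This commit swaps the bazel-bundled MobileBERT files for lazy downloads via file_util.DownloadedFiles. As a minimal sketch of the new pattern, assuming only the constructor arguments and get_path() behavior visible in the hunks below (download and extract on first call, cached afterwards):

    from mediapipe.model_maker.python.core.utils import file_util

    # Hypothetical standalone usage; the real definition appears as
    # MOBILEBERT_TINY_FILES in the text_classifier model spec hunk below.
    mobilebert_tiny = file_util.DownloadedFiles(
        'text_classifier/mobilebert_tiny',
        'https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny.tar.gz',
        is_folder=True,
    )

    # The first call downloads and extracts the archive into a temp directory;
    # later calls return the cached local path (compare the
    # mock_get.call_count == 1 assertion in the file_util test hunk below).
    model_dir = mobilebert_tiny.get_path()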

View File

@ -32,14 +32,3 @@ mediapipe_files(
"mobilebert_tiny/variables/variables.index", "mobilebert_tiny/variables/variables.index",
], ],
) )
filegroup(
name = "mobilebert_tiny",
srcs = [
"mobilebert_tiny/assets/vocab.txt",
"mobilebert_tiny/keras_metadata.pb",
"mobilebert_tiny/saved_model.pb",
"mobilebert_tiny/variables/variables.data-00000-of-00001",
"mobilebert_tiny/variables/variables.index",
],
)

View File

@ -21,8 +21,6 @@ import tarfile
import tempfile
import requests
# resources dependency
_TEMPDIR_FOLDER = 'model_maker'
@ -97,29 +95,3 @@ class DownloadedFiles:
with open(absolute_path, 'wb') as f:
f.write(r.content)
return str(absolute_path)
# TODO Remove after text_classifier supports downloading on demand.
def get_absolute_path(file_path: str) -> str:
"""Gets the absolute path of a file in the model_maker directory.
Args:
file_path: The path to a file relative to the `mediapipe` dir
Returns:
The full path of the file
"""
# Extract the file path before and including 'model_maker' as the
# `mm_base_dir`. By joining it with the `path` after 'model_maker/', it
# yields to the absolute path of the model files directory. We must join
# on 'model_maker' because in the pypi package, the 'model_maker' directory
# is renamed to 'mediapipe_model_maker'. So we have to join on model_maker
# to ensure that the `mm_base_dir` path includes the renamed
# 'mediapipe_model_maker' directory.
cwd = os.path.dirname(__file__)
cwd_stop_idx = cwd.rfind('model_maker') + len('model_maker')
mm_base_dir = cwd[:cwd_stop_idx]
file_path_start_idx = file_path.find('model_maker') + len('model_maker') + 1
mm_relative_path = file_path[file_path_start_idx:]
absolute_path = os.path.join(mm_base_dir, mm_relative_path)
return absolute_path
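The removed helper's comment explains why it splits on 'model_maker'; as a concrete illustration of that path arithmetic, with hypothetical paths that are not taken from the diff:

    import os

    # Simulated __file__ directory inside the renamed pypi package.
    cwd = '/site-packages/mediapipe_model_maker/python/core/utils'
    file_path = 'mediapipe/model_maker/python/core/utils/testdata/test.txt'

    cwd_stop_idx = cwd.rfind('model_maker') + len('model_maker')
    mm_base_dir = cwd[:cwd_stop_idx]
    # -> '/site-packages/mediapipe_model_maker'; rfind still matches because
    #    'model_maker' is a suffix of the renamed 'mediapipe_model_maker' dir.

    file_path_start_idx = file_path.find('model_maker') + len('model_maker') + 1
    mm_relative_path = file_path[file_path_start_idx:]
    # -> 'python/core/utils/testdata/test.txt'

    absolute_path = os.path.join(mm_base_dir, mm_relative_path)
    # -> '/site-packages/mediapipe_model_maker/python/core/utils/testdata/test.txt'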

View File

@ -74,11 +74,6 @@ class FileUtilTest(absltest.TestCase):
self.assertEqual(model_path, model_path_2)
self.assertEqual(mock_get.call_count, 1)
def test_get_absolute_path(self):
test_file = 'mediapipe/model_maker/python/core/utils/testdata/test.txt'
absolute_path = file_util.get_absolute_path(test_file)
self.assertTrue(os.path.exists(absolute_path))
if __name__ == '__main__':
absltest.main()

View File

@ -31,5 +31,6 @@ py_library(
deps = [
":bert_model_options",
"//mediapipe/model_maker/python/core:hyperparameters",
"//mediapipe/model_maker/python/core/utils:file_util",
],
)

View File

@ -17,6 +17,7 @@ import dataclasses
from typing import Dict
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.core.utils import file_util
from mediapipe.model_maker.python.text.core import bert_model_options
_DEFAULT_TFLITE_INPUT_NAME = {
@ -34,16 +35,17 @@ class BertModelSpec:
Transformers for Language Understanding) for more details.
Attributes:
downloaded_files: A DownloadedFiles object of the model files
hparams: Hyperparameters used for training.
model_options: Configurable options for a BERT model.
do_lower_case: boolean, whether to lower case the input text. Should be
True / False for uncased / cased models respectively, where the models
are specified by the `uri`.
are specified by `downloaded_files`.
tflite_input_name: Dict, input names for the TFLite model.
uri: URI for the BERT module.
name: The name of the object.
"""
downloaded_files: file_util.DownloadedFiles
hparams: hp.BaseHParams = hp.BaseHParams(
epochs=3,
batch_size=32,
@ -54,5 +56,4 @@ class BertModelSpec:
do_lower_case: bool = True
tflite_input_name: Dict[str, str] = dataclasses.field(
default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME)
uri: str = 'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1'
name: str = 'Bert'

View File

@ -61,6 +61,7 @@ py_library(
py_test(
name = "model_spec_test",
srcs = ["model_spec_test.py"],
tags = ["requires-net:external"],
deps = [
":model_options",
":model_spec",
@ -89,9 +90,6 @@ py_library(
py_test(
name = "preprocessor_test",
srcs = ["preprocessor_test.py"],
data = [
"//mediapipe/model_maker/models/text_classifier:mobilebert_tiny",
],
tags = ["requires-net:external"], tags = ["requires-net:external"],
deps = [ deps = [
":dataset", ":dataset",
@ -113,9 +111,6 @@ py_library(
py_library(
name = "text_classifier",
srcs = ["text_classifier.py"],
data = [
"//mediapipe/model_maker/models/text_classifier:mobilebert_tiny",
],
deps = [
":dataset",
":model_options",
@ -137,7 +132,6 @@ py_test(
size = "large", size = "large",
srcs = ["text_classifier_test.py"], srcs = ["text_classifier_test.py"],
data = [ data = [
"//mediapipe/model_maker/models/text_classifier:mobilebert_tiny",
"//mediapipe/model_maker/python/text/text_classifier/testdata", "//mediapipe/model_maker/python/text/text_classifier/testdata",
], ],
tags = [ tags = [
@ -163,9 +157,6 @@ py_library(
py_binary(
name = "text_classifier_demo",
srcs = ["text_classifier_demo.py"],
data = [
"//mediapipe/model_maker/models/text_classifier:mobilebert_tiny",
],
deps = [
":text_classifier_demo_lib",
],

View File

@ -25,7 +25,11 @@ from mediapipe.model_maker.python.text.text_classifier import model_options as m
# BERT-based text classifier spec inherited from BertModelSpec
BertClassifierSpec = bert_model_spec.BertModelSpec
MOBILEBERT_TINY_PATH = 'mediapipe/model_maker/models/text_classifier/mobilebert_tiny/'
MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
'text_classifier/mobilebert_tiny',
'https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny.tar.gz',
is_folder=True,
)
@dataclasses.dataclass
@ -51,11 +55,11 @@ average_word_embedding_classifier_spec = functools.partial(
mobilebert_classifier_spec = functools.partial(
BertClassifierSpec,
downloaded_files=MOBILEBERT_TINY_FILES,
hparams=hp.BaseHParams(
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
),
name='MobileBert',
uri=file_util.get_absolute_path(MOBILEBERT_TINY_PATH),
tflite_input_name={
'ids': 'serving_default_input_1:0',
'mask': 'serving_default_input_3:0',

View File

@ -14,6 +14,8 @@
"""Tests for model_spec.""" """Tests for model_spec."""
import os import os
import tempfile
from unittest import mock as unittest_mock
import tensorflow as tf
@ -24,14 +26,24 @@ from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
class ModelSpecTest(tf.test.TestCase):
def setUp(self):
super().setUp()
# Mock tempfile.gettempdir() to be unique for each test to avoid race
# condition when downloading model since these tests may run in parallel.
mock_gettempdir = unittest_mock.patch.object(
tempfile,
'gettempdir',
return_value=self.create_tempdir(),
autospec=True,
)
self.mock_gettempdir = mock_gettempdir.start()
self.addCleanup(mock_gettempdir.stop)
def test_predefined_bert_spec(self):
model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value()
self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec)
self.assertEqual(model_spec_obj.name, 'MobileBert')
self.assertIn(
'mediapipe/model_maker/models/text_classifier/mobilebert_tiny',
model_spec_obj.uri,
)
self.assertTrue(os.path.exists(model_spec_obj.downloaded_files.get_path()))
self.assertTrue(model_spec_obj.do_lower_case)
self.assertEqual(
model_spec_obj.tflite_input_name, {

View File

@ -14,6 +14,8 @@
import csv
import os
import tempfile
from unittest import mock as unittest_mock
import numpy as np
import numpy.testing as npt
@ -28,6 +30,19 @@ class PreprocessorTest(tf.test.TestCase):
CSV_PARAMS_ = text_classifier_ds.CSVParameters(
text_column='text', label_column='label')
def setUp(self):
super().setUp()
# Mock tempfile.gettempdir() to be unique for each test to avoid race
# condition when downloading model since these tests may run in parallel.
mock_gettempdir = unittest_mock.patch.object(
tempfile,
'gettempdir',
return_value=self.create_tempdir(),
autospec=True,
)
self.mock_gettempdir = mock_gettempdir.start()
self.addCleanup(mock_gettempdir.stop)
def _get_csv_file(self):
labels_and_text = (('pos', 'super super super super good'),
(('neg', 'really bad')))
@ -71,7 +86,10 @@ class PreprocessorTest(tf.test.TestCase):
filename=csv_file, csv_params=self.CSV_PARAMS_)
bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=5, do_lower_case=bert_spec.do_lower_case, uri=bert_spec.uri)
seq_len=5,
do_lower_case=bert_spec.do_lower_case,
uri=bert_spec.downloaded_files.get_path(),
)
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
labels = []
input_masks = []

View File

@ -369,7 +369,8 @@ class _BertClassifier(TextClassifier):
self._text_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=self._model_options.seq_len,
do_lower_case=self._model_spec.do_lower_case,
uri=self._model_spec.uri)
uri=self._model_spec.downloaded_files.get_path(),
)
return (self._text_preprocessor.preprocess(train_data),
self._text_preprocessor.preprocess(validation_data))
@ -388,7 +389,9 @@ class _BertClassifier(TextClassifier):
shape=(self._model_options.seq_len,), dtype=tf.int32),
)
encoder = hub.KerasLayer(
self._model_spec.uri, trainable=self._model_options.do_fine_tuning)
self._model_spec.downloaded_files.get_path(),
trainable=self._model_options.do_fine_tuning,
)
encoder_outputs = encoder(encoder_inputs)
pooled_output = encoder_outputs["pooled_output"]

View File

@ -15,6 +15,8 @@
import csv
import filecmp
import os
import tempfile
from unittest import mock as unittest_mock
import tensorflow as tf
@ -30,6 +32,19 @@ class TextClassifierTest(tf.test.TestCase):
'bert_metadata.json'
)
def setUp(self):
super().setUp()
# Mock tempfile.gettempdir() to be unique for each test to avoid race
# condition when downloading model since these tests may run in parallel.
mock_gettempdir = unittest_mock.patch.object(
tempfile,
'gettempdir',
return_value=self.create_tempdir(),
autospec=True,
)
self.mock_gettempdir = mock_gettempdir.start()
self.addCleanup(mock_gettempdir.stop)
def _get_data(self):
labels_and_text = (('pos', 'super good'), (('neg', 'really bad')))
csv_file = os.path.join(self.get_temp_dir(), 'data.csv')

View File

@ -18,8 +18,6 @@ Setup for Mediapipe-Model-Maker package with setuptools.
import glob
import os
import shutil
import subprocess
import sys
import setuptools
@ -40,16 +38,6 @@ def _parse_requirements(path):
]
def _copy_to_pip_src_dir(file):
"""Copy a file from bazel-bin to the pip_src dir."""
dst = file
dst_dir = os.path.dirname(dst)
if not os.path.exists(dst_dir):
os.makedirs(dst_dir)
src_file = os.path.join('../../bazel-bin/mediapipe/model_maker', file)
shutil.copyfile(src_file, file)
def _setup_build_dir():
"""Setup the BUILD_DIR directory to build the mediapipe_model_maker package.
@ -80,33 +68,6 @@ def _setup_build_dir():
with open(build_target_file, 'w') as file:
file.write(filedata)
# Use bazel to download GCS model files
model_build_files = [
'models/text_classifier/BUILD',
]
for model_build_file in model_build_files:
build_target_file = os.path.join(BUILD_MM_DIR, model_build_file)
os.makedirs(os.path.dirname(build_target_file), exist_ok=True)
shutil.copy(model_build_file, build_target_file)
external_files = [
'models/text_classifier/mobilebert_tiny/keras_metadata.pb',
'models/text_classifier/mobilebert_tiny/saved_model.pb',
'models/text_classifier/mobilebert_tiny/assets/vocab.txt',
'models/text_classifier/mobilebert_tiny/variables/variables.data-00000-of-00001',
'models/text_classifier/mobilebert_tiny/variables/variables.index',
]
for elem in external_files:
external_file = os.path.join(f'{SRC_NAME}/mediapipe_model_maker', elem)
sys.stderr.write('downloading file: %s\n' % external_file)
fetch_model_command = [
'bazel',
'build',
external_file,
]
if subprocess.call(fetch_model_command) != 0:
sys.exit(-1)
_copy_to_pip_src_dir(external_file)
_setup_build_dir()
setuptools.setup(