diff --git a/mediapipe/model_maker/models/text_classifier/BUILD b/mediapipe/model_maker/models/text_classifier/BUILD index 4c54bbccc..dc6210a7d 100644 --- a/mediapipe/model_maker/models/text_classifier/BUILD +++ b/mediapipe/model_maker/models/text_classifier/BUILD @@ -32,14 +32,3 @@ mediapipe_files( "mobilebert_tiny/variables/variables.index", ], ) - -filegroup( - name = "mobilebert_tiny", - srcs = [ - "mobilebert_tiny/assets/vocab.txt", - "mobilebert_tiny/keras_metadata.pb", - "mobilebert_tiny/saved_model.pb", - "mobilebert_tiny/variables/variables.data-00000-of-00001", - "mobilebert_tiny/variables/variables.index", - ], -) diff --git a/mediapipe/model_maker/python/core/utils/file_util.py b/mediapipe/model_maker/python/core/utils/file_util.py index 29d11ebbe..7871d90cb 100644 --- a/mediapipe/model_maker/python/core/utils/file_util.py +++ b/mediapipe/model_maker/python/core/utils/file_util.py @@ -21,8 +21,6 @@ import tarfile import tempfile import requests -# resources dependency - _TEMPDIR_FOLDER = 'model_maker' @@ -97,29 +95,3 @@ class DownloadedFiles: with open(absolute_path, 'wb') as f: f.write(r.content) return str(absolute_path) - - -# TODO Remove after text_classifier supports downloading on demand. -def get_absolute_path(file_path: str) -> str: - """Gets the absolute path of a file in the model_maker directory. - - Args: - file_path: The path to a file relative to the `mediapipe` dir - - Returns: - The full path of the file - """ - # Extract the file path before and including 'model_maker' as the - # `mm_base_dir`. By joining it with the `path` after 'model_maker/', it - # yields to the absolute path of the model files directory. We must join - # on 'model_maker' because in the pypi package, the 'model_maker' directory - # is renamed to 'mediapipe_model_maker'. So we have to join on model_maker - # to ensure that the `mm_base_dir` path includes the renamed - # 'mediapipe_model_maker' directory. - cwd = os.path.dirname(__file__) - cwd_stop_idx = cwd.rfind('model_maker') + len('model_maker') - mm_base_dir = cwd[:cwd_stop_idx] - file_path_start_idx = file_path.find('model_maker') + len('model_maker') + 1 - mm_relative_path = file_path[file_path_start_idx:] - absolute_path = os.path.join(mm_base_dir, mm_relative_path) - return absolute_path diff --git a/mediapipe/model_maker/python/core/utils/file_util_test.py b/mediapipe/model_maker/python/core/utils/file_util_test.py index f9f4a5954..027756ff0 100644 --- a/mediapipe/model_maker/python/core/utils/file_util_test.py +++ b/mediapipe/model_maker/python/core/utils/file_util_test.py @@ -74,11 +74,6 @@ class FileUtilTest(absltest.TestCase): self.assertEqual(model_path, model_path_2) self.assertEqual(mock_get.call_count, 1) - def test_get_absolute_path(self): - test_file = 'mediapipe/model_maker/python/core/utils/testdata/test.txt' - absolute_path = file_util.get_absolute_path(test_file) - self.assertTrue(os.path.exists(absolute_path)) - if __name__ == '__main__': absltest.main() diff --git a/mediapipe/model_maker/python/text/core/BUILD b/mediapipe/model_maker/python/text/core/BUILD index db06c3a75..e0c53491a 100644 --- a/mediapipe/model_maker/python/text/core/BUILD +++ b/mediapipe/model_maker/python/text/core/BUILD @@ -31,5 +31,6 @@ py_library( deps = [ ":bert_model_options", "//mediapipe/model_maker/python/core:hyperparameters", + "//mediapipe/model_maker/python/core/utils:file_util", ], ) diff --git a/mediapipe/model_maker/python/text/core/bert_model_spec.py b/mediapipe/model_maker/python/text/core/bert_model_spec.py index 6c0085617..605435df0 100644 --- a/mediapipe/model_maker/python/text/core/bert_model_spec.py +++ b/mediapipe/model_maker/python/text/core/bert_model_spec.py @@ -17,6 +17,7 @@ import dataclasses from typing import Dict from mediapipe.model_maker.python.core import hyperparameters as hp +from mediapipe.model_maker.python.core.utils import file_util from mediapipe.model_maker.python.text.core import bert_model_options _DEFAULT_TFLITE_INPUT_NAME = { @@ -34,16 +35,17 @@ class BertModelSpec: Transformers for Language Understanding) for more details. Attributes: + downloaded_files: A DownloadedFiles object of the model files hparams: Hyperparameters used for training. model_options: Configurable options for a BERT model. do_lower_case: boolean, whether to lower case the input text. Should be True / False for uncased / cased models respectively, where the models - are specified by the `uri`. + are specified by `downloaded_files`. tflite_input_name: Dict, input names for the TFLite model. - uri: URI for the BERT module. name: The name of the object. """ + downloaded_files: file_util.DownloadedFiles hparams: hp.BaseHParams = hp.BaseHParams( epochs=3, batch_size=32, @@ -54,5 +56,4 @@ class BertModelSpec: do_lower_case: bool = True tflite_input_name: Dict[str, str] = dataclasses.field( default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME) - uri: str = 'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1' name: str = 'Bert' diff --git a/mediapipe/model_maker/python/text/text_classifier/BUILD b/mediapipe/model_maker/python/text/text_classifier/BUILD index ac5b04f20..8083e0cb0 100644 --- a/mediapipe/model_maker/python/text/text_classifier/BUILD +++ b/mediapipe/model_maker/python/text/text_classifier/BUILD @@ -61,6 +61,7 @@ py_library( py_test( name = "model_spec_test", srcs = ["model_spec_test.py"], + tags = ["requires-net:external"], deps = [ ":model_options", ":model_spec", @@ -89,9 +90,6 @@ py_library( py_test( name = "preprocessor_test", srcs = ["preprocessor_test.py"], - data = [ - "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny", - ], tags = ["requires-net:external"], deps = [ ":dataset", @@ -113,9 +111,6 @@ py_library( py_library( name = "text_classifier", srcs = ["text_classifier.py"], - data = [ - "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny", - ], deps = [ ":dataset", ":model_options", @@ -137,7 +132,6 @@ py_test( size = "large", srcs = ["text_classifier_test.py"], data = [ - "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny", "//mediapipe/model_maker/python/text/text_classifier/testdata", ], tags = [ @@ -163,9 +157,6 @@ py_library( py_binary( name = "text_classifier_demo", srcs = ["text_classifier_demo.py"], - data = [ - "//mediapipe/model_maker/models/text_classifier:mobilebert_tiny", - ], deps = [ ":text_classifier_demo_lib", ], diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec.py b/mediapipe/model_maker/python/text/text_classifier/model_spec.py index a6bdd9522..d999f6867 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_spec.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec.py @@ -25,7 +25,11 @@ from mediapipe.model_maker.python.text.text_classifier import model_options as m # BERT-based text classifier spec inherited from BertModelSpec BertClassifierSpec = bert_model_spec.BertModelSpec -MOBILEBERT_TINY_PATH = 'mediapipe/model_maker/models/text_classifier/mobilebert_tiny/' +MOBILEBERT_TINY_FILES = file_util.DownloadedFiles( + 'text_classifier/mobilebert_tiny', + 'https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny.tar.gz', + is_folder=True, +) @dataclasses.dataclass @@ -51,11 +55,11 @@ average_word_embedding_classifier_spec = functools.partial( mobilebert_classifier_spec = functools.partial( BertClassifierSpec, + downloaded_files=MOBILEBERT_TINY_FILES, hparams=hp.BaseHParams( epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off' ), name='MobileBert', - uri=file_util.get_absolute_path(MOBILEBERT_TINY_PATH), tflite_input_name={ 'ids': 'serving_default_input_1:0', 'mask': 'serving_default_input_3:0', diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py index 3ea019b44..c2d96bac4 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py @@ -14,6 +14,8 @@ """Tests for model_spec.""" import os +import tempfile +from unittest import mock as unittest_mock import tensorflow as tf @@ -24,14 +26,24 @@ from mediapipe.model_maker.python.text.text_classifier import model_spec as ms class ModelSpecTest(tf.test.TestCase): + def setUp(self): + super().setUp() + # Mock tempfile.gettempdir() to be unique for each test to avoid race + # condition when downloading model since these tests may run in parallel. + mock_gettempdir = unittest_mock.patch.object( + tempfile, + 'gettempdir', + return_value=self.create_tempdir(), + autospec=True, + ) + self.mock_gettempdir = mock_gettempdir.start() + self.addCleanup(mock_gettempdir.stop) + def test_predefined_bert_spec(self): model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value() self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec) self.assertEqual(model_spec_obj.name, 'MobileBert') - self.assertIn( - 'mediapipe/model_maker/models/text_classifier/mobilebert_tiny', - model_spec_obj.uri, - ) + self.assertTrue(os.path.exists(model_spec_obj.downloaded_files.get_path())) self.assertTrue(model_spec_obj.do_lower_case) self.assertEqual( model_spec_obj.tflite_input_name, { diff --git a/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py b/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py index b9558b2b5..2ddc4aea9 100644 --- a/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py @@ -14,6 +14,8 @@ import csv import os +import tempfile +from unittest import mock as unittest_mock import numpy as np import numpy.testing as npt @@ -28,6 +30,19 @@ class PreprocessorTest(tf.test.TestCase): CSV_PARAMS_ = text_classifier_ds.CSVParameters( text_column='text', label_column='label') + def setUp(self): + super().setUp() + # Mock tempfile.gettempdir() to be unique for each test to avoid race + # condition when downloading model since these tests may run in parallel. + mock_gettempdir = unittest_mock.patch.object( + tempfile, + 'gettempdir', + return_value=self.create_tempdir(), + autospec=True, + ) + self.mock_gettempdir = mock_gettempdir.start() + self.addCleanup(mock_gettempdir.stop) + def _get_csv_file(self): labels_and_text = (('pos', 'super super super super good'), (('neg', 'really bad'))) @@ -71,7 +86,10 @@ class PreprocessorTest(tf.test.TestCase): filename=csv_file, csv_params=self.CSV_PARAMS_) bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value() bert_preprocessor = preprocessor.BertClassifierPreprocessor( - seq_len=5, do_lower_case=bert_spec.do_lower_case, uri=bert_spec.uri) + seq_len=5, + do_lower_case=bert_spec.do_lower_case, + uri=bert_spec.downloaded_files.get_path(), + ) preprocessed_dataset = bert_preprocessor.preprocess(dataset) labels = [] input_masks = [] diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py index f6abc8bf0..3d932ce90 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py @@ -369,7 +369,8 @@ class _BertClassifier(TextClassifier): self._text_preprocessor = preprocessor.BertClassifierPreprocessor( seq_len=self._model_options.seq_len, do_lower_case=self._model_spec.do_lower_case, - uri=self._model_spec.uri) + uri=self._model_spec.downloaded_files.get_path(), + ) return (self._text_preprocessor.preprocess(train_data), self._text_preprocessor.preprocess(validation_data)) @@ -388,7 +389,9 @@ class _BertClassifier(TextClassifier): shape=(self._model_options.seq_len,), dtype=tf.int32), ) encoder = hub.KerasLayer( - self._model_spec.uri, trainable=self._model_options.do_fine_tuning) + self._model_spec.downloaded_files.get_path(), + trainable=self._model_options.do_fine_tuning, + ) encoder_outputs = encoder(encoder_inputs) pooled_output = encoder_outputs["pooled_output"] diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py index 1ae2bc553..638e557fb 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py @@ -15,6 +15,8 @@ import csv import filecmp import os +import tempfile +from unittest import mock as unittest_mock import tensorflow as tf @@ -30,6 +32,19 @@ class TextClassifierTest(tf.test.TestCase): 'bert_metadata.json' ) + def setUp(self): + super().setUp() + # Mock tempfile.gettempdir() to be unique for each test to avoid race + # condition when downloading model since these tests may run in parallel. + mock_gettempdir = unittest_mock.patch.object( + tempfile, + 'gettempdir', + return_value=self.create_tempdir(), + autospec=True, + ) + self.mock_gettempdir = mock_gettempdir.start() + self.addCleanup(mock_gettempdir.stop) + def _get_data(self): labels_and_text = (('pos', 'super good'), (('neg', 'really bad'))) csv_file = os.path.join(self.get_temp_dir(), 'data.csv') diff --git a/mediapipe/model_maker/setup.py b/mediapipe/model_maker/setup.py index 63a1f2056..ccf633909 100644 --- a/mediapipe/model_maker/setup.py +++ b/mediapipe/model_maker/setup.py @@ -18,8 +18,6 @@ Setup for Mediapipe-Model-Maker package with setuptools. import glob import os import shutil -import subprocess -import sys import setuptools @@ -40,16 +38,6 @@ def _parse_requirements(path): ] -def _copy_to_pip_src_dir(file): - """Copy a file from bazel-bin to the pip_src dir.""" - dst = file - dst_dir = os.path.dirname(dst) - if not os.path.exists(dst_dir): - os.makedirs(dst_dir) - src_file = os.path.join('../../bazel-bin/mediapipe/model_maker', file) - shutil.copyfile(src_file, file) - - def _setup_build_dir(): """Setup the BUILD_DIR directory to build the mediapipe_model_maker package. @@ -80,33 +68,6 @@ def _setup_build_dir(): with open(build_target_file, 'w') as file: file.write(filedata) - # Use bazel to download GCS model files - model_build_files = [ - 'models/text_classifier/BUILD', - ] - for model_build_file in model_build_files: - build_target_file = os.path.join(BUILD_MM_DIR, model_build_file) - os.makedirs(os.path.dirname(build_target_file), exist_ok=True) - shutil.copy(model_build_file, build_target_file) - external_files = [ - 'models/text_classifier/mobilebert_tiny/keras_metadata.pb', - 'models/text_classifier/mobilebert_tiny/saved_model.pb', - 'models/text_classifier/mobilebert_tiny/assets/vocab.txt', - 'models/text_classifier/mobilebert_tiny/variables/variables.data-00000-of-00001', - 'models/text_classifier/mobilebert_tiny/variables/variables.index', - ] - for elem in external_files: - external_file = os.path.join(f'{SRC_NAME}/mediapipe_model_maker', elem) - sys.stderr.write('downloading file: %s\n' % external_file) - fetch_model_command = [ - 'bazel', - 'build', - external_file, - ] - if subprocess.call(fetch_model_command) != 0: - sys.exit(-1) - _copy_to_pip_src_dir(external_file) - _setup_build_dir() setuptools.setup(