From d283e6a05abcba303884f1f7232c1ac64597554b Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 31 Jan 2023 18:41:42 -0800
Subject: [PATCH] Support downloading model files on-demand from GCS in model_maker

PiperOrigin-RevId: 506174708
---
 mediapipe/model_maker/python/core/utils/BUILD |   1 +
 .../python/core/utils/file_util.py            | 105 +++++++++++++++++++
 .../python/core/utils/file_util_test.py       |  55 ++++++++++
 3 files changed, 161 insertions(+)

diff --git a/mediapipe/model_maker/python/core/utils/BUILD b/mediapipe/model_maker/python/core/utils/BUILD
index 492bba0a9..3c9107dba 100644
--- a/mediapipe/model_maker/python/core/utils/BUILD
+++ b/mediapipe/model_maker/python/core/utils/BUILD
@@ -61,6 +61,7 @@ py_test(
     name = "file_util_test",
     srcs = ["file_util_test.py"],
     data = ["//mediapipe/model_maker/python/core/utils/testdata"],
+    tags = ["requires-net:external"],
     deps = [":file_util"],
 )
 
diff --git a/mediapipe/model_maker/python/core/utils/file_util.py b/mediapipe/model_maker/python/core/utils/file_util.py
index 66addad54..29d11ebbe 100644
--- a/mediapipe/model_maker/python/core/utils/file_util.py
+++ b/mediapipe/model_maker/python/core/utils/file_util.py
@@ -13,11 +13,116 @@
 # limitations under the License.
 """Utilities for files."""
 
+import dataclasses
+import io
 import os
+import pathlib
+import shutil
+import tarfile
+import tempfile
 
+import requests
 # resources dependency
 
 
+_TEMPDIR_FOLDER = 'model_maker'
+
+
+@dataclasses.dataclass
+class DownloadedFiles:
+  """File(s) that are downloaded from a url into a local directory.
+
+  If `is_folder` is True:
+    1. `path` should be a folder
+    2. `url` should point to a .tar.gz file which contains a single folder at
+      the root level.
+
+  Attributes:
+    path: Relative path in local directory.
+    url: GCS url to download the file(s).
+    is_folder: Whether the path and url represent a folder.
+  """
+
+  path: str
+  url: str
+  is_folder: bool = False
+
+  def get_path(self) -> str:
+    """Gets the path of files saved in a local directory.
+
+    If the path doesn't exist, this method will download the file(s) from the
+    provided url. The path is not cleaned up so it can be reused for subsequent
+    calls to the same path.
+    Folders are expected to be zipped in a .tar.gz file which will be extracted
+    into self.path in the local directory.
+
+    Raises:
+      RuntimeError: If the extracted folder does not have a singular root
+        directory.
+      requests.HTTPError: If the download returns an HTTP error status.
+
+    Returns:
+      The absolute path to the downloaded file(s)
+    """
+    absolute_path = pathlib.Path(
+        os.path.join(tempfile.gettempdir(), _TEMPDIR_FOLDER, self.path)
+    )
+    if absolute_path.exists():
+      # Already downloaded by an earlier call; reuse it.
+      return str(absolute_path)
+
+    print(f'Downloading {self.url} to {absolute_path}')
+    response = requests.get(self.url, allow_redirects=True)
+    # Fail loudly on 4xx/5xx. Without this the error page would be saved at
+    # absolute_path and returned as the "model" by every later call.
+    response.raise_for_status()
+    if self.is_folder:
+      self._extract_folder(response.content, absolute_path)
+    else:
+      absolute_path.parent.mkdir(parents=True, exist_ok=True)
+      with open(absolute_path, 'wb') as f:
+        f.write(response.content)
+    return str(absolute_path)
+
+  def _extract_folder(
+      self, archive: bytes, destination: pathlib.Path
+  ) -> None:
+    """Unpacks a .tar.gz payload and copies its root folder to `destination`.
+
+    Args:
+      archive: Raw bytes of the downloaded .tar.gz file.
+      destination: Folder to create from the archive's root folder.
+
+    Raises:
+      RuntimeError: If the archive does not contain exactly one root directory.
+    """
+    with tempfile.TemporaryDirectory() as extract_dir:
+      # Read the archive from memory. Writing it to a NamedTemporaryFile and
+      # reopening it by name can see unflushed (truncated) data and does not
+      # work on Windows.
+      with tarfile.open(fileobj=io.BytesIO(archive)) as tarf:
+        if hasattr(tarfile, 'data_filter'):
+          # Refuse members that would be written outside extract_dir.
+          tarf.extractall(extract_dir, filter='data')
+        else:
+          tarf.extractall(extract_dir)
+      subdirs = os.listdir(extract_dir)
+      # Exactly one entry, and it must be a directory. Checking `!= 1` also
+      # covers an empty archive, which would otherwise hit subdirs[0].
+      if len(subdirs) != 1 or not os.path.isdir(
+          os.path.join(extract_dir, subdirs[0])
+      ):
+        raise RuntimeError(
+            f"Extracted folder from {self.url} doesn't contain a "
+            f'single root directory: {subdirs}'
+        )
+      # Create the parent dir of destination and copy the contents of the
+      # top level dir in the .tar.gz file into destination.
+      destination.parent.mkdir(parents=True, exist_ok=True)
+      shutil.copytree(os.path.join(extract_dir, subdirs[0]), destination)
+
+
+# TODO Remove after text_classifier supports downloading on demand.
 def get_absolute_path(file_path: str) -> str:
   """Gets the absolute path of a file in the model_maker directory.
 
diff --git a/mediapipe/model_maker/python/core/utils/file_util_test.py b/mediapipe/model_maker/python/core/utils/file_util_test.py
index 4a2d6dcfb..f9f4a5954 100644
--- a/mediapipe/model_maker/python/core/utils/file_util_test.py
+++ b/mediapipe/model_maker/python/core/utils/file_util_test.py
@@ -12,13 +12,68 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import tempfile
+from unittest import mock as unittest_mock
 
 from absl.testing import absltest
+import requests
+
 from mediapipe.model_maker.python.core.utils import file_util
 
 
 class FileUtilTest(absltest.TestCase):
 
+  def setUp(self):
+    super().setUp()
+    # Redirect gettempdir() so downloads land in a per-test directory.
+    gettempdir_patcher = unittest_mock.patch.object(
+        tempfile, 'gettempdir', autospec=True,
+        return_value=self.create_tempdir(),
+    )
+    self.mock_gettempdir = gettempdir_patcher.start()
+    self.addCleanup(gettempdir_patcher.stop)
+
+  def test_get_path(self):
+    tflite_file = file_util.DownloadedFiles(
+        path='gesture_recognizer/hand_landmark_full.tflite',
+        url='https://storage.googleapis.com/mediapipe-assets/hand_landmark_full.tflite',
+    )
+    local_path = tflite_file.get_path()
+    self.assertTrue(os.path.exists(local_path))
+    self.assertGreater(os.path.getsize(local_path), 0)
+
+  def test_get_path_folder(self):
+    saved_model = file_util.DownloadedFiles(
+        path='text_classifier/mobilebert_tiny',
+        url='https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny.tar.gz',
+        is_folder=True,
+    )
+    local_dir = saved_model.get_path()
+    self.assertTrue(os.path.exists(local_dir))
+    for member in (
+        'keras_metadata.pb',
+        'saved_model.pb',
+        'assets/vocab.txt',
+        'variables/variables.data-00000-of-00001',
+        'variables/variables.index',
+    ):
+      member_path = os.path.join(local_dir, member)
+      self.assertTrue(os.path.exists(member_path))
+      self.assertGreater(os.path.getsize(member_path), 0)
+
+  @unittest_mock.patch.object(requests, 'get', wraps=requests.get)
+  def test_get_path_multiple_calls(self, mock_get):
+    tflite_file = file_util.DownloadedFiles(
+        path='gesture_recognizer/hand_landmark_full.tflite',
+        url='https://storage.googleapis.com/mediapipe-assets/hand_landmark_full.tflite',
+    )
+    first_path = tflite_file.get_path()
+    self.assertTrue(os.path.exists(first_path))
+    self.assertGreater(os.path.getsize(first_path), 0)
+    # A repeated call must reuse the saved file instead of downloading again.
+    self.assertEqual(first_path, tflite_file.get_path())
+    self.assertEqual(mock_get.call_count, 1)
+
   def test_get_absolute_path(self):
     test_file = 'mediapipe/model_maker/python/core/utils/testdata/test.txt'
     absolute_path = file_util.get_absolute_path(test_file)