Support downloading model files on-demand from GCS in model_maker

PiperOrigin-RevId: 506174708
This commit is contained in:
MediaPipe Team 2023-01-31 18:41:42 -08:00 committed by Copybara-Service
parent b53acf6267
commit d283e6a05a
3 changed files with 138 additions and 0 deletions

View File

@ -61,6 +61,7 @@ py_test(
name = "file_util_test",
srcs = ["file_util_test.py"],
data = ["//mediapipe/model_maker/python/core/utils/testdata"],
tags = ["requires-net:external"],
deps = [":file_util"],
)

View File

@ -13,11 +13,93 @@
# limitations under the License.
"""Utilities for files."""
import dataclasses
import os
import pathlib
import shutil
import tarfile
import tempfile
import requests
# resources dependency
_TEMPDIR_FOLDER = 'model_maker'


@dataclasses.dataclass
class DownloadedFiles:
  """File(s) that are downloaded from a url into a local directory.

  If `is_folder` is True:
    1. `path` should be a folder
    2. `url` should point to a .tar.gz file which contains a single folder at
      the root level.

  Attributes:
    path: Relative path in local directory.
    url: GCS url to download the file(s).
    is_folder: Whether the path and url represents a folder.
  """

  path: str
  url: str
  is_folder: bool = False

  def get_path(self) -> str:
    """Gets the path of files saved in a local directory.

    If the path doesn't exist, this method will download the file(s) from the
    provided url. The path is not cleaned up so it can be reused for subsequent
    calls to the same path.
    Folders are expected to be zipped in a .tar.gz file which will be extracted
    into self.path in the local directory.

    Raises:
      RuntimeError: If the extracted folder does not have a singular root
        directory.
      requests.HTTPError: If the server answers the download request with an
        error status code.

    Returns:
      The absolute path to the downloaded file(s)
    """
    absolute_path = pathlib.Path(
        os.path.join(tempfile.gettempdir(), _TEMPDIR_FOLDER, self.path)
    )
    if not absolute_path.exists():
      print(f'Downloading {self.url} to {absolute_path}')
      response = requests.get(self.url, allow_redirects=True)
      # Fail before anything is written: otherwise the body of an error
      # response would be stored at absolute_path and every later call would
      # treat it as a valid cached download.
      response.raise_for_status()
      if self.is_folder:
        self._extract_folder(response.content, absolute_path)
      else:
        absolute_path.parent.mkdir(parents=True, exist_ok=True)
        with open(absolute_path, 'wb') as f:
          f.write(response.content)
    return str(absolute_path)

  def _extract_folder(self, archive: bytes, destination: pathlib.Path) -> None:
    """Unpacks a .tar.gz payload and copies its single root folder.

    Args:
      archive: Raw bytes of the downloaded .tar.gz file.
      destination: Directory path that receives the contents of the archive's
        root folder. It must not exist yet.

    Raises:
      RuntimeError: If the archive does not contain exactly one entry at the
        root level, or that entry is not a directory.
    """
    # Everything lives in one scratch directory that is removed on exit, even
    # when extraction or validation raises.
    with tempfile.TemporaryDirectory() as workdir:
      archive_path = os.path.join(workdir, 'archive.tar.gz')
      extract_dir = os.path.join(workdir, 'extracted')
      # The archive is fully written and closed before tarfile reads it by
      # name, so no buffered bytes can be missing from disk.
      with open(archive_path, 'wb') as f:
        f.write(archive)
      os.mkdir(extract_dir)
      # NOTE(review): extractall() trusts the member paths stored in the
      # archive. This is only safe while `url` points at a trusted source.
      with tarfile.open(archive_path) as tarf:
        tarf.extractall(extract_dir)
      subdirs = os.listdir(extract_dir)
      # Exactly one root entry is required; an empty archive is rejected here
      # as well instead of failing on subdirs[0].
      if len(subdirs) != 1 or not os.path.isdir(
          os.path.join(extract_dir, subdirs[0])
      ):
        raise RuntimeError(
            f"Extracted folder from {self.url} doesn't contain a "
            f'single root directory: {subdirs}'
        )
      # Create the parent dir of destination and copy the contents of the
      # top level dir in the .tar.gz file into destination.
      destination.parent.mkdir(parents=True, exist_ok=True)
      shutil.copytree(os.path.join(extract_dir, subdirs[0]), destination)
# TODO Remove after text_classifier supports downloading on demand.
def get_absolute_path(file_path: str) -> str:
"""Gets the absolute path of a file in the model_maker directory.

View File

@ -12,13 +12,68 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
from unittest import mock as unittest_mock
from absl.testing import absltest
import requests
from mediapipe.model_maker.python.core.utils import file_util
class FileUtilTest(absltest.TestCase):
def setUp(self):
  """Patches `tempfile.gettempdir` for the duration of each test.

  The patched function returns the value of `self.create_tempdir()`, and the
  patch is undone through a registered cleanup.
  """
  super().setUp()
  gettempdir_patcher = unittest_mock.patch.object(
      tempfile,
      'gettempdir',
      autospec=True,
      return_value=self.create_tempdir(),
  )
  # Keep the started mock on the instance so tests can inspect it.
  self.mock_gettempdir = gettempdir_patcher.start()
  self.addCleanup(gettempdir_patcher.stop)
def test_get_path(self):
  """Checks that a single-file download exists locally and is non-empty."""
  downloaded_files = file_util.DownloadedFiles(
      path='gesture_recognizer/hand_landmark_full.tflite',
      url='https://storage.googleapis.com/mediapipe-assets/hand_landmark_full.tflite',
      is_folder=False,
  )
  local_file = downloaded_files.get_path()
  self.assertTrue(os.path.exists(local_file))
  self.assertGreater(os.path.getsize(local_file), 0)
def test_get_path_folder(self):
  """Checks that a folder download contains every expected, non-empty file."""
  downloaded_files = file_util.DownloadedFiles(
      path='text_classifier/mobilebert_tiny',
      url='https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny.tar.gz',
      is_folder=True,
  )
  local_folder = downloaded_files.get_path()
  self.assertTrue(os.path.exists(local_folder))

  # Paths are relative to the root of the downloaded folder.
  expected_members = (
      'keras_metadata.pb',
      'saved_model.pb',
      'assets/vocab.txt',
      'variables/variables.data-00000-of-00001',
      'variables/variables.index',
  )
  for member in expected_members:
    member_path = os.path.join(local_folder, member)
    self.assertTrue(os.path.exists(member_path))
    self.assertGreater(os.path.getsize(member_path), 0)
@unittest_mock.patch.object(requests, 'get', wraps=requests.get)
def test_get_path_multiple_calls(self, mock_get):
  """Checks that a second `get_path` call reuses the first download.

  `requests.get` is wrapped rather than replaced, so the first call still
  performs the real request while the mock records how often it was invoked.
  """
  downloaded_files = file_util.DownloadedFiles(
      path='gesture_recognizer/hand_landmark_full.tflite',
      url='https://storage.googleapis.com/mediapipe-assets/hand_landmark_full.tflite',
      is_folder=False,
  )

  first_result = downloaded_files.get_path()
  self.assertTrue(os.path.exists(first_result))
  self.assertGreater(os.path.getsize(first_result), 0)

  second_result = downloaded_files.get_path()
  self.assertEqual(first_result, second_result)
  # Only the first call may have gone through requests.get.
  self.assertEqual(mock_get.call_count, 1)
def test_get_absolute_path(self):
test_file = 'mediapipe/model_maker/python/core/utils/testdata/test.txt'
absolute_path = file_util.get_absolute_path(test_file)