Support downloading model files on-demand from GCS in model_maker
PiperOrigin-RevId: 506174708
This commit is contained in:
parent
b53acf6267
commit
d283e6a05a
|
@ -61,6 +61,7 @@ py_test(
|
|||
name = "file_util_test",
|
||||
srcs = ["file_util_test.py"],
|
||||
data = ["//mediapipe/model_maker/python/core/utils/testdata"],
|
||||
tags = ["requires-net:external"],
|
||||
deps = [":file_util"],
|
||||
)
|
||||
|
||||
|
|
|
@ -13,11 +13,93 @@
|
|||
# limitations under the License.
|
||||
"""Utilities for files."""
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
import pathlib
|
||||
import shutil
|
||||
import tarfile
|
||||
import tempfile
|
||||
import requests
|
||||
|
||||
# resources dependency
|
||||
|
||||
|
||||
_TEMPDIR_FOLDER = 'model_maker'
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DownloadedFiles:
|
||||
"""File(s) that are downloaded from a url into a local directory.
|
||||
|
||||
If `is_folder` is True:
|
||||
1. `path` should be a folder
|
||||
2. `url` should point to a .tar.gz file which contains a single folder at
|
||||
the root level.
|
||||
|
||||
Attributes:
|
||||
path: Relative path in local directory.
|
||||
url: GCS url to download the file(s).
|
||||
is_folder: Whether the path and url represents a folder.
|
||||
"""
|
||||
|
||||
path: str
|
||||
url: str
|
||||
is_folder: bool = False
|
||||
|
||||
def get_path(self) -> str:
|
||||
"""Gets the path of files saved in a local directory.
|
||||
|
||||
If the path doesn't exist, this method will download the file(s) from the
|
||||
provided url. The path is not cleaned up so it can be reused for subsequent
|
||||
calls to the same path.
|
||||
Folders are expected to be zipped in a .tar.gz file which will be extracted
|
||||
into self.path in the local directory.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the extracted folder does not have a singular root
|
||||
directory.
|
||||
|
||||
Returns:
|
||||
The absolute path to the downloaded file(s)
|
||||
"""
|
||||
tmpdir = tempfile.gettempdir()
|
||||
absolute_path = pathlib.Path(
|
||||
os.path.join(tmpdir, _TEMPDIR_FOLDER, self.path)
|
||||
)
|
||||
if not absolute_path.exists():
|
||||
print(f'Downloading {self.url} to {absolute_path}')
|
||||
r = requests.get(self.url, allow_redirects=True)
|
||||
if self.is_folder:
|
||||
# Use tempf to store the downloaded .tar.gz file
|
||||
tempf = tempfile.NamedTemporaryFile(suffix='.tar.gz', mode='wb')
|
||||
tempf.write(r.content)
|
||||
tarf = tarfile.open(tempf.name)
|
||||
# Use tmpdir to store the extracted contents of the .tar.gz file
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tarf.extractall(tmpdir)
|
||||
tarf.close()
|
||||
tempf.close()
|
||||
subdirs = os.listdir(tmpdir)
|
||||
# Make sure tmpdir only has one subdirectory
|
||||
if len(subdirs) > 1 or not os.path.isdir(
|
||||
os.path.join(tmpdir, subdirs[0])
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"Extracted folder from {self.url} doesn't contain a "
|
||||
f'single root directory: {subdirs}'
|
||||
)
|
||||
# Create the parent dir of absolute_path and copy the contents of the
|
||||
# top level dir in the .tar.gz file into absolute_path.
|
||||
pathlib.Path.mkdir(absolute_path.parent, parents=True, exist_ok=True)
|
||||
shutil.copytree(os.path.join(tmpdir, subdirs[0]), absolute_path)
|
||||
else:
|
||||
pathlib.Path.mkdir(absolute_path.parent, parents=True, exist_ok=True)
|
||||
with open(absolute_path, 'wb') as f:
|
||||
f.write(r.content)
|
||||
return str(absolute_path)
|
||||
|
||||
|
||||
# TODO Remove after text_classifier supports downloading on demand.
|
||||
def get_absolute_path(file_path: str) -> str:
|
||||
"""Gets the absolute path of a file in the model_maker directory.
|
||||
|
||||
|
|
|
@ -12,13 +12,68 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import tempfile
|
||||
from unittest import mock as unittest_mock
|
||||
|
||||
from absl.testing import absltest
|
||||
import requests
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import file_util
|
||||
|
||||
|
||||
class FileUtilTest(absltest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
mock_gettempdir = unittest_mock.patch.object(
|
||||
tempfile,
|
||||
'gettempdir',
|
||||
return_value=self.create_tempdir(),
|
||||
autospec=True,
|
||||
)
|
||||
self.mock_gettempdir = mock_gettempdir.start()
|
||||
self.addCleanup(mock_gettempdir.stop)
|
||||
|
||||
def test_get_path(self):
|
||||
path = 'gesture_recognizer/hand_landmark_full.tflite'
|
||||
url = 'https://storage.googleapis.com/mediapipe-assets/hand_landmark_full.tflite'
|
||||
downloaded_files = file_util.DownloadedFiles(path, url, is_folder=False)
|
||||
model_path = downloaded_files.get_path()
|
||||
self.assertTrue(os.path.exists(model_path))
|
||||
self.assertGreater(os.path.getsize(model_path), 0)
|
||||
|
||||
def test_get_path_folder(self):
|
||||
folder_contents = [
|
||||
'keras_metadata.pb',
|
||||
'saved_model.pb',
|
||||
'assets/vocab.txt',
|
||||
'variables/variables.data-00000-of-00001',
|
||||
'variables/variables.index',
|
||||
]
|
||||
path = 'text_classifier/mobilebert_tiny'
|
||||
url = (
|
||||
'https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny.tar.gz'
|
||||
)
|
||||
downloaded_files = file_util.DownloadedFiles(path, url, is_folder=True)
|
||||
model_path = downloaded_files.get_path()
|
||||
self.assertTrue(os.path.exists(model_path))
|
||||
for file_name in folder_contents:
|
||||
file_path = os.path.join(model_path, file_name)
|
||||
self.assertTrue(os.path.exists(file_path))
|
||||
self.assertGreater(os.path.getsize(file_path), 0)
|
||||
|
||||
@unittest_mock.patch.object(requests, 'get', wraps=requests.get)
|
||||
def test_get_path_multiple_calls(self, mock_get):
|
||||
path = 'gesture_recognizer/hand_landmark_full.tflite'
|
||||
url = 'https://storage.googleapis.com/mediapipe-assets/hand_landmark_full.tflite'
|
||||
downloaded_files = file_util.DownloadedFiles(path, url, is_folder=False)
|
||||
model_path = downloaded_files.get_path()
|
||||
self.assertTrue(os.path.exists(model_path))
|
||||
self.assertGreater(os.path.getsize(model_path), 0)
|
||||
model_path_2 = downloaded_files.get_path()
|
||||
self.assertEqual(model_path, model_path_2)
|
||||
self.assertEqual(mock_get.call_count, 1)
|
||||
|
||||
def test_get_absolute_path(self):
|
||||
test_file = 'mediapipe/model_maker/python/core/utils/testdata/test.txt'
|
||||
absolute_path = file_util.get_absolute_path(test_file)
|
||||
|
|
Loading…
Reference in New Issue
Block a user