Make each model file in the model asset bundle to be aligned relative to the start of the file (alignment = 4).
PiperOrigin-RevId: 511624410
This commit is contained in:
parent
40b0dc960a
commit
9e7950a69a
|
@ -131,7 +131,7 @@ py_library(
|
|||
srcs = ["metadata_writer.py"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:writer_utils",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:model_asset_bundle_utils",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ from typing import Union
|
|||
|
||||
import tensorflow as tf
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import writer_utils
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import model_asset_bundle_utils
|
||||
|
||||
_HAND_DETECTOR_TFLITE_NAME = "hand_detector.tflite"
|
||||
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME = "hand_landmarks_detector.tflite"
|
||||
|
@ -100,8 +100,9 @@ class HandLandmarkerMetadataWriter:
|
|||
}
|
||||
output_hand_landmarker_path = os.path.join(self._temp_folder.name,
|
||||
_HAND_LANDMARKER_BUNDLE_NAME)
|
||||
writer_utils.create_model_asset_bundle(landmark_models,
|
||||
output_hand_landmarker_path)
|
||||
model_asset_bundle_utils.create_model_asset_bundle(
|
||||
landmark_models, output_hand_landmarker_path
|
||||
)
|
||||
hand_landmarker_model_buffer = read_file(output_hand_landmarker_path)
|
||||
return hand_landmarker_model_buffer
|
||||
|
||||
|
@ -208,8 +209,9 @@ class MetadataWriter:
|
|||
}
|
||||
output_hand_gesture_recognizer_path = os.path.join(
|
||||
self._temp_folder.name, _HAND_GESTURE_RECOGNIZER_BUNDLE_NAME)
|
||||
writer_utils.create_model_asset_bundle(hand_gesture_recognizer_models,
|
||||
output_hand_gesture_recognizer_path)
|
||||
model_asset_bundle_utils.create_model_asset_bundle(
|
||||
hand_gesture_recognizer_models, output_hand_gesture_recognizer_path
|
||||
)
|
||||
|
||||
# Creates the model asset bundle for end-to-end hand gesture recognizer
|
||||
# graph.
|
||||
|
@ -222,8 +224,9 @@ class MetadataWriter:
|
|||
|
||||
output_file_path = os.path.join(self._temp_folder.name,
|
||||
"gesture_recognizer.task")
|
||||
writer_utils.create_model_asset_bundle(gesture_recognizer_models,
|
||||
output_file_path)
|
||||
model_asset_bundle_utils.create_model_asset_bundle(
|
||||
gesture_recognizer_models, output_file_path
|
||||
)
|
||||
with open(output_file_path, "rb") as f:
|
||||
gesture_recognizer_model_buffer = f.read()
|
||||
return gesture_recognizer_model_buffer, custom_gesture_classifier_metadata_json
|
||||
|
|
|
@ -49,3 +49,8 @@ py_library(
|
|||
srcs = ["text_classifier.py"],
|
||||
deps = [":metadata_writer"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "model_asset_bundle_utils",
|
||||
srcs = ["model_asset_bundle_utils.py"],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utility methods for creating the model asset bundles."""
|
||||
|
||||
from typing import Dict
|
||||
import zipfile
|
||||
|
||||
# Alignment that ensures that all uncompressed files in the model bundle file
|
||||
# are aligned relative to the start of the file. This lets the files be
|
||||
# accessed directly via mmap.
|
||||
_ALIGNMENT = 4
|
||||
|
||||
|
||||
class AlignZipFile(zipfile.ZipFile):
|
||||
"""ZipFile that stores uncompressed files at particular alignment."""
|
||||
|
||||
def __init__(self, *args, alignment: int, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
assert alignment > 0
|
||||
self._alignment = alignment
|
||||
|
||||
def _writecheck(self, zinfo: zipfile.ZipInfo) -> None:
|
||||
# Aligned the uncompressed files.
|
||||
if zinfo.compress_type == zipfile.ZIP_STORED:
|
||||
offset = self.fp.tell()
|
||||
header_length = len(zinfo.FileHeader())
|
||||
padding_length = (
|
||||
self._alignment - (offset + header_length) % self._alignment
|
||||
)
|
||||
if padding_length:
|
||||
offset += padding_length
|
||||
self.fp.write(b"\x00" * padding_length)
|
||||
assert self.fp.tell() == offset
|
||||
zinfo.header_offset = offset
|
||||
else:
|
||||
raise ValueError(
|
||||
"Only support the uncompressed file (compress_type =="
|
||||
" zipfile.ZIP_STORED) in zip. The current file compress type is "
|
||||
+ str(zinfo.compress_type)
|
||||
)
|
||||
super()._writecheck(zinfo)
|
||||
|
||||
|
||||
def create_model_asset_bundle(
|
||||
input_models: Dict[str, bytes], output_path: str
|
||||
) -> None:
|
||||
"""Creates the model asset bundle.
|
||||
|
||||
Args:
|
||||
input_models: A dict of input models with key as the model file name and
|
||||
value as the model content.
|
||||
output_path: The output file path to save the model asset bundle.
|
||||
"""
|
||||
if not input_models or len(input_models) < 2:
|
||||
raise ValueError("Needs at least two input models for model asset bundle.")
|
||||
|
||||
with AlignZipFile(output_path, mode="w", alignment=_ALIGNMENT) as zf:
|
||||
for file_name, file_buffer in input_models.items():
|
||||
zf.writestr(file_name, file_buffer)
|
|
@ -14,8 +14,7 @@
|
|||
# ==============================================================================
|
||||
"""Helper methods for writing metadata into TFLite models."""
|
||||
|
||||
from typing import Dict, List
|
||||
import zipfile
|
||||
from typing import List
|
||||
|
||||
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
|
||||
|
||||
|
@ -84,20 +83,3 @@ def get_subgraph(model_buffer: bytearray) -> _schema_fb.SubGraph:
|
|||
# multiple subgraphs yet, but models with mini-benchmark may have multiple
|
||||
# subgraphs for acceleration evaluation purpose.
|
||||
return model.Subgraphs(0)
|
||||
|
||||
|
||||
def create_model_asset_bundle(input_models: Dict[str, bytes],
|
||||
output_path: str) -> None:
|
||||
"""Creates the model asset bundle.
|
||||
|
||||
Args:
|
||||
input_models: A dict of input models with key as the model file name and
|
||||
value as the model content.
|
||||
output_path: The output file path to save the model asset bundle.
|
||||
"""
|
||||
if not input_models or len(input_models) < 2:
|
||||
raise ValueError("Needs at least two input models for model asset bundle.")
|
||||
|
||||
with zipfile.ZipFile(output_path, mode="w") as zf:
|
||||
for file_name, file_buffer in input_models.items():
|
||||
zf.writestr(file_name, file_buffer)
|
||||
|
|
|
@ -67,3 +67,9 @@ py_test(
|
|||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "model_asset_bundle_utils_test",
|
||||
srcs = ["model_asset_bundle_utils_test.py"],
|
||||
deps = ["//mediapipe/tasks/python/metadata/metadata_writers:model_asset_bundle_utils"],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for model asset bundle utilities."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
|
||||
from absl.testing import absltest
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import model_asset_bundle_utils
|
||||
|
||||
|
||||
class ModelAssetBundleUtilsTest(absltest.TestCase):
|
||||
|
||||
def test_create_model_asset_bundle(self):
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
bundle_file = os.path.join(temp_dir, 'test.task')
|
||||
input_models = {'1.tflite': b'\x11\x22', '2.tflite': b'\x33'}
|
||||
model_asset_bundle_utils.create_model_asset_bundle(
|
||||
input_models, bundle_file
|
||||
)
|
||||
with zipfile.ZipFile(bundle_file) as zf:
|
||||
for info in zf.infolist():
|
||||
# Each file should be aligned.
|
||||
header_length = len(info.FileHeader())
|
||||
offset = info.header_offset + header_length
|
||||
self.assertEqual(offset % model_asset_bundle_utils._ALIGNMENT, 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
Loading…
Reference in New Issue
Block a user