Align each model file in the model asset bundle relative to the start of the file (alignment = 4).

PiperOrigin-RevId: 511624410
This commit is contained in:
Yuqi Li 2023-02-22 16:13:23 -08:00 committed by Copybara-Service
parent 40b0dc960a
commit 9e7950a69a
7 changed files with 137 additions and 27 deletions

View File

@ -131,7 +131,7 @@ py_library(
srcs = ["metadata_writer.py"],
deps = [
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
"//mediapipe/tasks/python/metadata/metadata_writers:writer_utils",
"//mediapipe/tasks/python/metadata/metadata_writers:model_asset_bundle_utils",
],
)

View File

@ -21,7 +21,7 @@ from typing import Union
import tensorflow as tf
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.metadata.metadata_writers import writer_utils
from mediapipe.tasks.python.metadata.metadata_writers import model_asset_bundle_utils
_HAND_DETECTOR_TFLITE_NAME = "hand_detector.tflite"
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME = "hand_landmarks_detector.tflite"
@ -100,8 +100,9 @@ class HandLandmarkerMetadataWriter:
}
output_hand_landmarker_path = os.path.join(self._temp_folder.name,
_HAND_LANDMARKER_BUNDLE_NAME)
writer_utils.create_model_asset_bundle(landmark_models,
output_hand_landmarker_path)
model_asset_bundle_utils.create_model_asset_bundle(
landmark_models, output_hand_landmarker_path
)
hand_landmarker_model_buffer = read_file(output_hand_landmarker_path)
return hand_landmarker_model_buffer
@ -208,8 +209,9 @@ class MetadataWriter:
}
output_hand_gesture_recognizer_path = os.path.join(
self._temp_folder.name, _HAND_GESTURE_RECOGNIZER_BUNDLE_NAME)
writer_utils.create_model_asset_bundle(hand_gesture_recognizer_models,
output_hand_gesture_recognizer_path)
model_asset_bundle_utils.create_model_asset_bundle(
hand_gesture_recognizer_models, output_hand_gesture_recognizer_path
)
# Creates the model asset bundle for end-to-end hand gesture recognizer
# graph.
@ -222,8 +224,9 @@ class MetadataWriter:
output_file_path = os.path.join(self._temp_folder.name,
"gesture_recognizer.task")
writer_utils.create_model_asset_bundle(gesture_recognizer_models,
output_file_path)
model_asset_bundle_utils.create_model_asset_bundle(
gesture_recognizer_models, output_file_path
)
with open(output_file_path, "rb") as f:
gesture_recognizer_model_buffer = f.read()
return gesture_recognizer_model_buffer, custom_gesture_classifier_metadata_json

View File

@ -49,3 +49,8 @@ py_library(
srcs = ["text_classifier.py"],
deps = [":metadata_writer"],
)
# Library target exposing model_asset_bundle_utils.py.
py_library(
    name = "model_asset_bundle_utils",
    srcs = ["model_asset_bundle_utils.py"],
)

View File

@ -0,0 +1,71 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility methods for creating the model asset bundles."""
from typing import Dict
import zipfile
# Alignment, in bytes, applied to every uncompressed file stored in the model
# bundle: each file's data starts at an offset, relative to the start of the
# bundle file, that is a multiple of this value. This lets the files be
# accessed directly via mmap.
_ALIGNMENT = 4
class AlignZipFile(zipfile.ZipFile):
  """ZipFile that stores uncompressed files at a particular alignment.

  Before the local file header of an entry is written, zero bytes are inserted
  so that the entry's data (which starts right after that header) begins at an
  offset, as reported by the underlying file object, that is a multiple of
  `alignment`. No padding is inserted when the data is already aligned.

  Only entries stored with `zipfile.ZIP_STORED` are supported.
  """

  def __init__(self, *args, alignment: int, **kwargs) -> None:
    """Initializes the zip file.

    Args:
      *args: Positional arguments forwarded to `zipfile.ZipFile`.
      alignment: Required alignment of the stored data, in bytes. Must be
        positive.
      **kwargs: Keyword arguments forwarded to `zipfile.ZipFile`.

    Raises:
      ValueError: If `alignment` is not positive.
    """
    # Validate before the base constructor runs so that a bad argument neither
    # creates/truncates the output file nor leaves an open archive behind.
    # A real exception is used because asserts are stripped under `python -O`.
    if alignment <= 0:
      raise ValueError(f"alignment must be positive, got {alignment}")
    super().__init__(*args, **kwargs)
    self._alignment = alignment

  def _writecheck(self, zinfo: zipfile.ZipInfo) -> None:
    """Pads the archive so the data of `zinfo` is aligned, then runs base checks.

    Args:
      zinfo: The entry that is about to be written.

    Raises:
      ValueError: If the entry is not stored uncompressed.
    """
    if zinfo.compress_type != zipfile.ZIP_STORED:
      raise ValueError(
          "Only support the uncompressed file (compress_type =="
          " zipfile.ZIP_STORED) in zip. The current file compress type is "
          + str(zinfo.compress_type)
      )

    # Align the uncompressed file: its data starts right after the local file
    # header, i.e. at `offset + header_length`.
    offset = self.fp.tell()
    header_length = len(zinfo.FileHeader())
    # Bytes needed to round the data start up to the next multiple of the
    # alignment. This is 0 (not `alignment`) when it is already aligned.
    padding_length = -(offset + header_length) % self._alignment
    if padding_length:
      offset += padding_length
      self.fp.write(b"\x00" * padding_length)
      assert self.fp.tell() == offset
      # The header is written after the padding, so record its new position.
      zinfo.header_offset = offset
    super()._writecheck(zinfo)
def create_model_asset_bundle(
    input_models: Dict[str, bytes], output_path: str
) -> None:
  """Packs the given models into one aligned model asset bundle.

  Args:
    input_models: Mapping from model file name to the content of that model.
    output_path: Path of the bundle file to write.

  Raises:
    ValueError: If fewer than two models are provided.
  """
  has_enough_models = bool(input_models) and len(input_models) >= 2
  if not has_enough_models:
    raise ValueError("Needs at least two input models for model asset bundle.")

  with AlignZipFile(output_path, mode="w", alignment=_ALIGNMENT) as bundle:
    for model_name in input_models:
      bundle.writestr(model_name, input_models[model_name])

View File

@ -14,8 +14,7 @@
# ==============================================================================
"""Helper methods for writing metadata into TFLite models."""
from typing import Dict, List
import zipfile
from typing import List
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
@ -84,20 +83,3 @@ def get_subgraph(model_buffer: bytearray) -> _schema_fb.SubGraph:
# multiple subgraphs yet, but models with mini-benchmark may have multiple
# subgraphs for acceleration evaluation purpose.
return model.Subgraphs(0)
def create_model_asset_bundle(input_models: Dict[str, bytes],
                              output_path: str) -> None:
  """Packs the given models into one zip-based model asset bundle.

  Args:
    input_models: Mapping from model file name to the content of that model.
    output_path: Path of the bundle file to write.

  Raises:
    ValueError: If fewer than two models are provided.
  """
  has_enough_models = bool(input_models) and len(input_models) >= 2
  if not has_enough_models:
    raise ValueError("Needs at least two input models for model asset bundle.")

  with zipfile.ZipFile(output_path, mode="w") as bundle:
    for model_name in input_models:
      bundle.writestr(model_name, input_models[model_name])

View File

@ -67,3 +67,9 @@ py_test(
"//mediapipe/tasks/python/test:test_utils",
],
)
# Test target for model_asset_bundle_utils_test.py; depends on the
# model_asset_bundle_utils library under test.
py_test(
    name = "model_asset_bundle_utils_test",
    srcs = ["model_asset_bundle_utils_test.py"],
    deps = ["//mediapipe/tasks/python/metadata/metadata_writers:model_asset_bundle_utils"],
)

View File

@ -0,0 +1,43 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for model asset bundle utilities."""
import os
import tempfile
import zipfile
from absl.testing import absltest
from mediapipe.tasks.python.metadata.metadata_writers import model_asset_bundle_utils
class ModelAssetBundleUtilsTest(absltest.TestCase):
  """Tests for `model_asset_bundle_utils.create_model_asset_bundle`."""

  def test_create_model_asset_bundle(self):
    """Bundled files are aligned and keep their original names and content."""
    with tempfile.TemporaryDirectory() as temp_dir:
      bundle_file = os.path.join(temp_dir, 'test.task')
      input_models = {'1.tflite': b'\x11\x22', '2.tflite': b'\x33'}
      model_asset_bundle_utils.create_model_asset_bundle(
          input_models, bundle_file
      )
      with zipfile.ZipFile(bundle_file) as zf:
        self.assertEqual(sorted(zf.namelist()), sorted(input_models))
        for info in zf.infolist():
          # Each file should be aligned.
          header_length = len(info.FileHeader())
          offset = info.header_offset + header_length
          self.assertEqual(offset % model_asset_bundle_utils._ALIGNMENT, 0)
          # The offsets above come from the central directory only. Reading
          # the entry also goes through the local header at `header_offset`,
          # so this catches padding that leaves the recorded offset wrong.
          self.assertEqual(zf.read(info), input_models[info.filename])
# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()