Align each model file in the model asset bundle relative to the start of the bundle file (alignment = 4 bytes).

PiperOrigin-RevId: 511624410
This commit is contained in:
Yuqi Li 2023-02-22 16:13:23 -08:00 committed by Copybara-Service
parent 40b0dc960a
commit 9e7950a69a
7 changed files with 137 additions and 27 deletions

View File

@ -131,7 +131,7 @@ py_library(
srcs = ["metadata_writer.py"], srcs = ["metadata_writer.py"],
deps = [ deps = [
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer", "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
"//mediapipe/tasks/python/metadata/metadata_writers:writer_utils", "//mediapipe/tasks/python/metadata/metadata_writers:model_asset_bundle_utils",
], ],
) )

View File

@ -21,7 +21,7 @@ from typing import Union
import tensorflow as tf import tensorflow as tf
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.metadata.metadata_writers import writer_utils from mediapipe.tasks.python.metadata.metadata_writers import model_asset_bundle_utils
_HAND_DETECTOR_TFLITE_NAME = "hand_detector.tflite" _HAND_DETECTOR_TFLITE_NAME = "hand_detector.tflite"
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME = "hand_landmarks_detector.tflite" _HAND_LANDMARKS_DETECTOR_TFLITE_NAME = "hand_landmarks_detector.tflite"
@ -100,8 +100,9 @@ class HandLandmarkerMetadataWriter:
} }
output_hand_landmarker_path = os.path.join(self._temp_folder.name, output_hand_landmarker_path = os.path.join(self._temp_folder.name,
_HAND_LANDMARKER_BUNDLE_NAME) _HAND_LANDMARKER_BUNDLE_NAME)
writer_utils.create_model_asset_bundle(landmark_models, model_asset_bundle_utils.create_model_asset_bundle(
output_hand_landmarker_path) landmark_models, output_hand_landmarker_path
)
hand_landmarker_model_buffer = read_file(output_hand_landmarker_path) hand_landmarker_model_buffer = read_file(output_hand_landmarker_path)
return hand_landmarker_model_buffer return hand_landmarker_model_buffer
@ -208,8 +209,9 @@ class MetadataWriter:
} }
output_hand_gesture_recognizer_path = os.path.join( output_hand_gesture_recognizer_path = os.path.join(
self._temp_folder.name, _HAND_GESTURE_RECOGNIZER_BUNDLE_NAME) self._temp_folder.name, _HAND_GESTURE_RECOGNIZER_BUNDLE_NAME)
writer_utils.create_model_asset_bundle(hand_gesture_recognizer_models, model_asset_bundle_utils.create_model_asset_bundle(
output_hand_gesture_recognizer_path) hand_gesture_recognizer_models, output_hand_gesture_recognizer_path
)
# Creates the model asset bundle for end-to-end hand gesture recognizer # Creates the model asset bundle for end-to-end hand gesture recognizer
# graph. # graph.
@ -222,8 +224,9 @@ class MetadataWriter:
output_file_path = os.path.join(self._temp_folder.name, output_file_path = os.path.join(self._temp_folder.name,
"gesture_recognizer.task") "gesture_recognizer.task")
writer_utils.create_model_asset_bundle(gesture_recognizer_models, model_asset_bundle_utils.create_model_asset_bundle(
output_file_path) gesture_recognizer_models, output_file_path
)
with open(output_file_path, "rb") as f: with open(output_file_path, "rb") as f:
gesture_recognizer_model_buffer = f.read() gesture_recognizer_model_buffer = f.read()
return gesture_recognizer_model_buffer, custom_gesture_classifier_metadata_json return gesture_recognizer_model_buffer, custom_gesture_classifier_metadata_json

View File

@ -49,3 +49,8 @@ py_library(
srcs = ["text_classifier.py"], srcs = ["text_classifier.py"],
deps = [":metadata_writer"], deps = [":metadata_writer"],
) )
py_library(
name = "model_asset_bundle_utils",
srcs = ["model_asset_bundle_utils.py"],
)

View File

@ -0,0 +1,71 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility methods for creating the model asset bundles."""
from typing import Dict
import zipfile
# Alignment, in bytes, that ensures that all uncompressed files in the model
# bundle file are aligned relative to the start of the file. This lets the
# files be accessed directly via mmap.
_ALIGNMENT = 4
class AlignZipFile(zipfile.ZipFile):
  """ZipFile that stores uncompressed files at a particular alignment.

  Zero bytes are written in front of an entry's local file header whenever
  needed, so that the position right after the header (where the entry's data
  begins) is a multiple of `alignment`, measured from the start of the
  underlying file object.

  Only `zipfile.ZIP_STORED` (uncompressed) entries are supported.

  NOTE: this overrides the private `zipfile.ZipFile._writecheck` hook and uses
  the private `fp` attribute, so it depends on the standard library's
  implementation details (the hook is expected to run just before the local
  file header is written).
  """

  def __init__(self, *args, alignment: int, **kwargs) -> None:
    """Opens the archive.

    Args:
      *args: Positional arguments forwarded to `zipfile.ZipFile`.
      alignment: Required alignment, in bytes, of each entry's data. Must be
        positive.
      **kwargs: Keyword arguments forwarded to `zipfile.ZipFile`.

    Raises:
      ValueError: If `alignment` is not positive.
    """
    # Validate before opening: a bad argument must neither create/truncate the
    # target file nor leave an open handle behind. An explicit raise (rather
    # than `assert`) keeps the check active under `python -O`.
    if alignment <= 0:
      raise ValueError(f"alignment must be positive, got {alignment}.")
    super().__init__(*args, **kwargs)
    self._alignment = alignment

  def _writecheck(self, zinfo: zipfile.ZipInfo) -> None:
    """Pads the archive so the data of `zinfo` starts at an aligned offset.

    Args:
      zinfo: Metadata of the entry about to be written.

    Raises:
      ValueError: If the entry is not stored uncompressed.
    """
    if zinfo.compress_type != zipfile.ZIP_STORED:
      raise ValueError(
          "Only support the uncompressed file (compress_type =="
          " zipfile.ZIP_STORED) in zip. The current file compress type is "
          + str(zinfo.compress_type)
      )

    # Align the uncompressed file: its data starts right after the header.
    offset = self.fp.tell()
    header_length = len(zinfo.FileHeader())
    # `-x % a` is the distance from x up to the next multiple of a, and is 0
    # when x is already aligned (so no needless padding is emitted).
    padding_length = -(offset + header_length) % self._alignment
    if padding_length:
      offset += padding_length
      self.fp.write(b"\x00" * padding_length)
      assert self.fp.tell() == offset
      # The header is written after the padding; record its real position so
      # the central directory points at it.
      zinfo.header_offset = offset
    super()._writecheck(zinfo)
def create_model_asset_bundle(
    input_models: Dict[str, bytes], output_path: str
) -> None:
  """Packs the given models into one aligned model asset bundle file.

  Args:
    input_models: Mapping from model file name to the model content.
    output_path: Path of the model asset bundle file to write.

  Raises:
    ValueError: If fewer than two input models are given.
  """
  # A missing or empty mapping counts as zero models.
  model_count = len(input_models) if input_models else 0
  if model_count < 2:
    raise ValueError("Needs at least two input models for model asset bundle.")

  with AlignZipFile(output_path, mode="w", alignment=_ALIGNMENT) as bundle:
    for model_name, model_content in input_models.items():
      bundle.writestr(model_name, model_content)

View File

@ -14,8 +14,7 @@
# ============================================================================== # ==============================================================================
"""Helper methods for writing metadata into TFLite models.""" """Helper methods for writing metadata into TFLite models."""
from typing import Dict, List from typing import List
import zipfile
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
@ -84,20 +83,3 @@ def get_subgraph(model_buffer: bytearray) -> _schema_fb.SubGraph:
# multiple subgraphs yet, but models with mini-benchmark may have multiple # multiple subgraphs yet, but models with mini-benchmark may have multiple
# subgraphs for acceleration evaluation purpose. # subgraphs for acceleration evaluation purpose.
return model.Subgraphs(0) return model.Subgraphs(0)
def create_model_asset_bundle(input_models: Dict[str, bytes],
                              output_path: str) -> None:
  """Packs the given models into one model asset bundle (zip) file.

  Args:
    input_models: Mapping from model file name to the model content.
    output_path: Path of the model asset bundle file to write.

  Raises:
    ValueError: If fewer than two input models are given.
  """
  # A missing or empty mapping counts as zero models.
  model_count = len(input_models) if input_models else 0
  if model_count < 2:
    raise ValueError("Needs at least two input models for model asset bundle.")

  with zipfile.ZipFile(output_path, mode="w") as bundle:
    for model_name, model_content in input_models.items():
      bundle.writestr(model_name, model_content)

View File

@ -67,3 +67,9 @@ py_test(
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
], ],
) )
py_test(
name = "model_asset_bundle_utils_test",
srcs = ["model_asset_bundle_utils_test.py"],
deps = ["//mediapipe/tasks/python/metadata/metadata_writers:model_asset_bundle_utils"],
)

View File

@ -0,0 +1,43 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for model asset bundle utilities."""
import os
import tempfile
import zipfile
from absl.testing import absltest
from mediapipe.tasks.python.metadata.metadata_writers import model_asset_bundle_utils
class ModelAssetBundleUtilsTest(absltest.TestCase):
  """Tests for `model_asset_bundle_utils.create_model_asset_bundle`."""

  def test_create_model_asset_bundle(self):
    """Bundled files are aligned and keep their names and contents."""
    input_models = {'1.tflite': b'\x11\x22', '2.tflite': b'\x33'}
    with tempfile.TemporaryDirectory() as temp_dir:
      bundle_file = os.path.join(temp_dir, 'test.task')
      model_asset_bundle_utils.create_model_asset_bundle(
          input_models, bundle_file
      )
      with zipfile.ZipFile(bundle_file) as zf:
        self.assertCountEqual(zf.namelist(), list(input_models))
        for info in zf.infolist():
          # Each file should be aligned: its data starts right after the
          # local file header.
          header_length = len(info.FileHeader())
          offset = info.header_offset + header_length
          self.assertEqual(offset % model_asset_bundle_utils._ALIGNMENT, 0)
          # The padding must not break reading the entry back.
          self.assertEqual(zf.read(info), input_models[info.filename])

  def test_create_model_asset_bundle_requires_two_models(self):
    """Fewer than two input models is rejected with ValueError."""
    with tempfile.TemporaryDirectory() as temp_dir:
      bundle_file = os.path.join(temp_dir, 'test.task')
      with self.assertRaises(ValueError):
        model_asset_bundle_utils.create_model_asset_bundle(
            {'1.tflite': b'\x11\x22'}, bundle_file
        )
# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
absltest.main()