Make each model file in the model asset bundle to be aligned relative to the start of the file (alignment = 4).
PiperOrigin-RevId: 511624410
This commit is contained in:
parent
40b0dc960a
commit
9e7950a69a
|
@ -131,7 +131,7 @@ py_library(
|
||||||
srcs = ["metadata_writer.py"],
|
srcs = ["metadata_writer.py"],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||||
"//mediapipe/tasks/python/metadata/metadata_writers:writer_utils",
|
"//mediapipe/tasks/python/metadata/metadata_writers:model_asset_bundle_utils",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,7 @@ from typing import Union
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||||
from mediapipe.tasks.python.metadata.metadata_writers import writer_utils
|
from mediapipe.tasks.python.metadata.metadata_writers import model_asset_bundle_utils
|
||||||
|
|
||||||
_HAND_DETECTOR_TFLITE_NAME = "hand_detector.tflite"
|
_HAND_DETECTOR_TFLITE_NAME = "hand_detector.tflite"
|
||||||
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME = "hand_landmarks_detector.tflite"
|
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME = "hand_landmarks_detector.tflite"
|
||||||
|
@ -100,8 +100,9 @@ class HandLandmarkerMetadataWriter:
|
||||||
}
|
}
|
||||||
output_hand_landmarker_path = os.path.join(self._temp_folder.name,
|
output_hand_landmarker_path = os.path.join(self._temp_folder.name,
|
||||||
_HAND_LANDMARKER_BUNDLE_NAME)
|
_HAND_LANDMARKER_BUNDLE_NAME)
|
||||||
writer_utils.create_model_asset_bundle(landmark_models,
|
model_asset_bundle_utils.create_model_asset_bundle(
|
||||||
output_hand_landmarker_path)
|
landmark_models, output_hand_landmarker_path
|
||||||
|
)
|
||||||
hand_landmarker_model_buffer = read_file(output_hand_landmarker_path)
|
hand_landmarker_model_buffer = read_file(output_hand_landmarker_path)
|
||||||
return hand_landmarker_model_buffer
|
return hand_landmarker_model_buffer
|
||||||
|
|
||||||
|
@ -208,8 +209,9 @@ class MetadataWriter:
|
||||||
}
|
}
|
||||||
output_hand_gesture_recognizer_path = os.path.join(
|
output_hand_gesture_recognizer_path = os.path.join(
|
||||||
self._temp_folder.name, _HAND_GESTURE_RECOGNIZER_BUNDLE_NAME)
|
self._temp_folder.name, _HAND_GESTURE_RECOGNIZER_BUNDLE_NAME)
|
||||||
writer_utils.create_model_asset_bundle(hand_gesture_recognizer_models,
|
model_asset_bundle_utils.create_model_asset_bundle(
|
||||||
output_hand_gesture_recognizer_path)
|
hand_gesture_recognizer_models, output_hand_gesture_recognizer_path
|
||||||
|
)
|
||||||
|
|
||||||
# Creates the model asset bundle for end-to-end hand gesture recognizer
|
# Creates the model asset bundle for end-to-end hand gesture recognizer
|
||||||
# graph.
|
# graph.
|
||||||
|
@ -222,8 +224,9 @@ class MetadataWriter:
|
||||||
|
|
||||||
output_file_path = os.path.join(self._temp_folder.name,
|
output_file_path = os.path.join(self._temp_folder.name,
|
||||||
"gesture_recognizer.task")
|
"gesture_recognizer.task")
|
||||||
writer_utils.create_model_asset_bundle(gesture_recognizer_models,
|
model_asset_bundle_utils.create_model_asset_bundle(
|
||||||
output_file_path)
|
gesture_recognizer_models, output_file_path
|
||||||
|
)
|
||||||
with open(output_file_path, "rb") as f:
|
with open(output_file_path, "rb") as f:
|
||||||
gesture_recognizer_model_buffer = f.read()
|
gesture_recognizer_model_buffer = f.read()
|
||||||
return gesture_recognizer_model_buffer, custom_gesture_classifier_metadata_json
|
return gesture_recognizer_model_buffer, custom_gesture_classifier_metadata_json
|
||||||
|
|
|
@ -49,3 +49,8 @@ py_library(
|
||||||
srcs = ["text_classifier.py"],
|
srcs = ["text_classifier.py"],
|
||||||
deps = [":metadata_writer"],
|
deps = [":metadata_writer"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "model_asset_bundle_utils",
|
||||||
|
srcs = ["model_asset_bundle_utils.py"],
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,71 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Utility methods for creating the model asset bundles."""
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
# Alignment that ensures that all uncompressed files in the model bundle file
|
||||||
|
# are aligned relative to the start of the file. This lets the files be
|
||||||
|
# accessed directly via mmap.
|
||||||
|
_ALIGNMENT = 4
|
||||||
|
|
||||||
|
|
||||||
|
class AlignZipFile(zipfile.ZipFile):
|
||||||
|
"""ZipFile that stores uncompressed files at particular alignment."""
|
||||||
|
|
||||||
|
def __init__(self, *args, alignment: int, **kwargs) -> None:
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
assert alignment > 0
|
||||||
|
self._alignment = alignment
|
||||||
|
|
||||||
|
def _writecheck(self, zinfo: zipfile.ZipInfo) -> None:
|
||||||
|
# Aligned the uncompressed files.
|
||||||
|
if zinfo.compress_type == zipfile.ZIP_STORED:
|
||||||
|
offset = self.fp.tell()
|
||||||
|
header_length = len(zinfo.FileHeader())
|
||||||
|
padding_length = (
|
||||||
|
self._alignment - (offset + header_length) % self._alignment
|
||||||
|
)
|
||||||
|
if padding_length:
|
||||||
|
offset += padding_length
|
||||||
|
self.fp.write(b"\x00" * padding_length)
|
||||||
|
assert self.fp.tell() == offset
|
||||||
|
zinfo.header_offset = offset
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Only support the uncompressed file (compress_type =="
|
||||||
|
" zipfile.ZIP_STORED) in zip. The current file compress type is "
|
||||||
|
+ str(zinfo.compress_type)
|
||||||
|
)
|
||||||
|
super()._writecheck(zinfo)
|
||||||
|
|
||||||
|
|
||||||
|
def create_model_asset_bundle(
|
||||||
|
input_models: Dict[str, bytes], output_path: str
|
||||||
|
) -> None:
|
||||||
|
"""Creates the model asset bundle.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_models: A dict of input models with key as the model file name and
|
||||||
|
value as the model content.
|
||||||
|
output_path: The output file path to save the model asset bundle.
|
||||||
|
"""
|
||||||
|
if not input_models or len(input_models) < 2:
|
||||||
|
raise ValueError("Needs at least two input models for model asset bundle.")
|
||||||
|
|
||||||
|
with AlignZipFile(output_path, mode="w", alignment=_ALIGNMENT) as zf:
|
||||||
|
for file_name, file_buffer in input_models.items():
|
||||||
|
zf.writestr(file_name, file_buffer)
|
|
@ -14,8 +14,7 @@
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Helper methods for writing metadata into TFLite models."""
|
"""Helper methods for writing metadata into TFLite models."""
|
||||||
|
|
||||||
from typing import Dict, List
|
from typing import List
|
||||||
import zipfile
|
|
||||||
|
|
||||||
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
|
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
|
||||||
|
|
||||||
|
@ -84,20 +83,3 @@ def get_subgraph(model_buffer: bytearray) -> _schema_fb.SubGraph:
|
||||||
# multiple subgraphs yet, but models with mini-benchmark may have multiple
|
# multiple subgraphs yet, but models with mini-benchmark may have multiple
|
||||||
# subgraphs for acceleration evaluation purpose.
|
# subgraphs for acceleration evaluation purpose.
|
||||||
return model.Subgraphs(0)
|
return model.Subgraphs(0)
|
||||||
|
|
||||||
|
|
||||||
def create_model_asset_bundle(input_models: Dict[str, bytes],
|
|
||||||
output_path: str) -> None:
|
|
||||||
"""Creates the model asset bundle.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_models: A dict of input models with key as the model file name and
|
|
||||||
value as the model content.
|
|
||||||
output_path: The output file path to save the model asset bundle.
|
|
||||||
"""
|
|
||||||
if not input_models or len(input_models) < 2:
|
|
||||||
raise ValueError("Needs at least two input models for model asset bundle.")
|
|
||||||
|
|
||||||
with zipfile.ZipFile(output_path, mode="w") as zf:
|
|
||||||
for file_name, file_buffer in input_models.items():
|
|
||||||
zf.writestr(file_name, file_buffer)
|
|
||||||
|
|
|
@ -67,3 +67,9 @@ py_test(
|
||||||
"//mediapipe/tasks/python/test:test_utils",
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "model_asset_bundle_utils_test",
|
||||||
|
srcs = ["model_asset_bundle_utils_test.py"],
|
||||||
|
deps = ["//mediapipe/tasks/python/metadata/metadata_writers:model_asset_bundle_utils"],
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for model asset bundle utilities."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
from absl.testing import absltest
|
||||||
|
from mediapipe.tasks.python.metadata.metadata_writers import model_asset_bundle_utils
|
||||||
|
|
||||||
|
|
||||||
|
class ModelAssetBundleUtilsTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def test_create_model_asset_bundle(self):
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
bundle_file = os.path.join(temp_dir, 'test.task')
|
||||||
|
input_models = {'1.tflite': b'\x11\x22', '2.tflite': b'\x33'}
|
||||||
|
model_asset_bundle_utils.create_model_asset_bundle(
|
||||||
|
input_models, bundle_file
|
||||||
|
)
|
||||||
|
with zipfile.ZipFile(bundle_file) as zf:
|
||||||
|
for info in zf.infolist():
|
||||||
|
# Each file should be aligned.
|
||||||
|
header_length = len(info.FileHeader())
|
||||||
|
offset = info.header_offset + header_length
|
||||||
|
self.assertEqual(offset % model_asset_bundle_utils._ALIGNMENT, 0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
absltest.main()
|
Loading…
Reference in New Issue
Block a user