diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD index 2dad9a617..77ed2e016 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD @@ -131,7 +131,7 @@ py_library( srcs = ["metadata_writer.py"], deps = [ "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer", - "//mediapipe/tasks/python/metadata/metadata_writers:writer_utils", + "//mediapipe/tasks/python/metadata/metadata_writers:model_asset_bundle_utils", ], ) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py index b2e851afe..d6dc3ec2c 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py @@ -21,7 +21,7 @@ from typing import Union import tensorflow as tf from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer -from mediapipe.tasks.python.metadata.metadata_writers import writer_utils +from mediapipe.tasks.python.metadata.metadata_writers import model_asset_bundle_utils _HAND_DETECTOR_TFLITE_NAME = "hand_detector.tflite" _HAND_LANDMARKS_DETECTOR_TFLITE_NAME = "hand_landmarks_detector.tflite" @@ -100,8 +100,9 @@ class HandLandmarkerMetadataWriter: } output_hand_landmarker_path = os.path.join(self._temp_folder.name, _HAND_LANDMARKER_BUNDLE_NAME) - writer_utils.create_model_asset_bundle(landmark_models, - output_hand_landmarker_path) + model_asset_bundle_utils.create_model_asset_bundle( + landmark_models, output_hand_landmarker_path + ) hand_landmarker_model_buffer = read_file(output_hand_landmarker_path) return hand_landmarker_model_buffer @@ -208,8 +209,9 @@ class MetadataWriter: } output_hand_gesture_recognizer_path = os.path.join( self._temp_folder.name, _HAND_GESTURE_RECOGNIZER_BUNDLE_NAME) - writer_utils.create_model_asset_bundle(hand_gesture_recognizer_models, - output_hand_gesture_recognizer_path) + model_asset_bundle_utils.create_model_asset_bundle( + hand_gesture_recognizer_models, output_hand_gesture_recognizer_path + ) # Creates the model asset bundle for end-to-end hand gesture recognizer # graph. @@ -222,8 +224,9 @@ class MetadataWriter: output_file_path = os.path.join(self._temp_folder.name, "gesture_recognizer.task") - writer_utils.create_model_asset_bundle(gesture_recognizer_models, - output_file_path) + model_asset_bundle_utils.create_model_asset_bundle( + gesture_recognizer_models, output_file_path + ) with open(output_file_path, "rb") as f: gesture_recognizer_model_buffer = f.read() return gesture_recognizer_model_buffer, custom_gesture_classifier_metadata_json diff --git a/mediapipe/tasks/python/metadata/metadata_writers/BUILD b/mediapipe/tasks/python/metadata/metadata_writers/BUILD index 1177939bd..69d952998 100644 --- a/mediapipe/tasks/python/metadata/metadata_writers/BUILD +++ b/mediapipe/tasks/python/metadata/metadata_writers/BUILD @@ -49,3 +49,8 @@ py_library( srcs = ["text_classifier.py"], deps = [":metadata_writer"], ) + +py_library( + name = "model_asset_bundle_utils", + srcs = ["model_asset_bundle_utils.py"], +) diff --git a/mediapipe/tasks/python/metadata/metadata_writers/model_asset_bundle_utils.py b/mediapipe/tasks/python/metadata/metadata_writers/model_asset_bundle_utils.py new file mode 100644 index 000000000..b626dc7b1 --- /dev/null +++ b/mediapipe/tasks/python/metadata/metadata_writers/model_asset_bundle_utils.py @@ -0,0 +1,71 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility methods for creating the model asset bundles.""" + +from typing import Dict +import zipfile + +# Alignment that ensures that all uncompressed files in the model bundle file +# are aligned relative to the start of the file. This lets the files be +# accessed directly via mmap. +_ALIGNMENT = 4 + + +class AlignZipFile(zipfile.ZipFile): + """ZipFile that stores uncompressed files at particular alignment.""" + + def __init__(self, *args, alignment: int, **kwargs) -> None: + super().__init__(*args, **kwargs) + assert alignment > 0 + self._alignment = alignment + + def _writecheck(self, zinfo: zipfile.ZipInfo) -> None: + # Aligned the uncompressed files. + if zinfo.compress_type == zipfile.ZIP_STORED: + offset = self.fp.tell() + header_length = len(zinfo.FileHeader()) + padding_length = ( + self._alignment - (offset + header_length) % self._alignment + ) + if padding_length: + offset += padding_length + self.fp.write(b"\x00" * padding_length) + assert self.fp.tell() == offset + zinfo.header_offset = offset + else: + raise ValueError( + "Only support the uncompressed file (compress_type ==" + " zipfile.ZIP_STORED) in zip. The current file compress type is " + + str(zinfo.compress_type) + ) + super()._writecheck(zinfo) + + +def create_model_asset_bundle( + input_models: Dict[str, bytes], output_path: str +) -> None: + """Creates the model asset bundle. + + Args: + input_models: A dict of input models with key as the model file name and + value as the model content. + output_path: The output file path to save the model asset bundle. + """ + if not input_models or len(input_models) < 2: + raise ValueError("Needs at least two input models for model asset bundle.") + + with AlignZipFile(output_path, mode="w", alignment=_ALIGNMENT) as zf: + for file_name, file_buffer in input_models.items(): + zf.writestr(file_name, file_buffer) diff --git a/mediapipe/tasks/python/metadata/metadata_writers/writer_utils.py b/mediapipe/tasks/python/metadata/metadata_writers/writer_utils.py index eff5f553e..0a054812b 100644 --- a/mediapipe/tasks/python/metadata/metadata_writers/writer_utils.py +++ b/mediapipe/tasks/python/metadata/metadata_writers/writer_utils.py @@ -14,8 +14,7 @@ # ============================================================================== """Helper methods for writing metadata into TFLite models.""" -from typing import Dict, List -import zipfile +from typing import List from mediapipe.tasks.metadata import schema_py_generated as _schema_fb @@ -84,20 +83,3 @@ def get_subgraph(model_buffer: bytearray) -> _schema_fb.SubGraph: # multiple subgraphs yet, but models with mini-benchmark may have multiple # subgraphs for acceleration evaluation purpose. return model.Subgraphs(0) - - -def create_model_asset_bundle(input_models: Dict[str, bytes], - output_path: str) -> None: - """Creates the model asset bundle. - - Args: - input_models: A dict of input models with key as the model file name and - value as the model content. - output_path: The output file path to save the model asset bundle. - """ - if not input_models or len(input_models) < 2: - raise ValueError("Needs at least two input models for model asset bundle.") - - with zipfile.ZipFile(output_path, mode="w") as zf: - for file_name, file_buffer in input_models.items(): - zf.writestr(file_name, file_buffer) diff --git a/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD b/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD index 948b3f8d9..539b3903b 100644 --- a/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD +++ b/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD @@ -67,3 +67,9 @@ py_test( "//mediapipe/tasks/python/test:test_utils", ], ) + +py_test( + name = "model_asset_bundle_utils_test", + srcs = ["model_asset_bundle_utils_test.py"], + deps = ["//mediapipe/tasks/python/metadata/metadata_writers:model_asset_bundle_utils"], +) diff --git a/mediapipe/tasks/python/test/metadata/metadata_writers/model_asset_bundle_utils_test.py b/mediapipe/tasks/python/test/metadata/metadata_writers/model_asset_bundle_utils_test.py new file mode 100644 index 000000000..e42613932 --- /dev/null +++ b/mediapipe/tasks/python/test/metadata/metadata_writers/model_asset_bundle_utils_test.py @@ -0,0 +1,43 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for model asset bundle utilities.""" + +import os +import tempfile +import zipfile + +from absl.testing import absltest +from mediapipe.tasks.python.metadata.metadata_writers import model_asset_bundle_utils + + +class ModelAssetBundleUtilsTest(absltest.TestCase): + + def test_create_model_asset_bundle(self): + with tempfile.TemporaryDirectory() as temp_dir: + bundle_file = os.path.join(temp_dir, 'test.task') + input_models = {'1.tflite': b'\x11\x22', '2.tflite': b'\x33'} + model_asset_bundle_utils.create_model_asset_bundle( + input_models, bundle_file + ) + with zipfile.ZipFile(bundle_file) as zf: + for info in zf.infolist(): + # Each file should be aligned. + header_length = len(info.FileHeader()) + offset = info.header_offset + header_length + self.assertEqual(offset % model_asset_bundle_utils._ALIGNMENT, 0) + + +if __name__ == '__main__': + absltest.main()