From b3d19fa1af3b23a57993bf3a006e390184459e9f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 8 Nov 2022 13:50:28 -0800 Subject: [PATCH] Use model bundle writer when exporting models in gesture recognizer PiperOrigin-RevId: 487042776 --- mediapipe/model_maker/python/core/utils/BUILD | 13 +++++++ .../python/core/utils/file_util.py | 36 +++++++++++++++++++ .../python/core/utils/file_util_test.py | 29 +++++++++++++++ .../python/core/utils/model_util.py | 26 +++++++++----- .../python/core/utils/model_util_test.py | 15 +++++++- .../python/core/utils/testdata/BUILD | 23 ++++++++++++ .../python/core/utils/testdata/test.txt | 0 mediapipe/tasks/testdata/vision/BUILD | 5 +++ 8 files changed, 138 insertions(+), 9 deletions(-) create mode 100644 mediapipe/model_maker/python/core/utils/file_util.py create mode 100644 mediapipe/model_maker/python/core/utils/file_util_test.py create mode 100644 mediapipe/model_maker/python/core/utils/testdata/BUILD create mode 100644 mediapipe/model_maker/python/core/utils/testdata/test.txt diff --git a/mediapipe/model_maker/python/core/utils/BUILD b/mediapipe/model_maker/python/core/utils/BUILD index a2ec52044..12fef631f 100644 --- a/mediapipe/model_maker/python/core/utils/BUILD +++ b/mediapipe/model_maker/python/core/utils/BUILD @@ -35,6 +35,7 @@ py_library( name = "model_util", srcs = ["model_util.py"], deps = [ + ":file_util", ":quantization", "//mediapipe/model_maker/python/core/data:dataset", ], @@ -50,6 +51,18 @@ py_test( ], ) +py_library( + name = "file_util", + srcs = ["file_util.py"], +) + +py_test( + name = "file_util_test", + srcs = ["file_util_test.py"], + data = ["//mediapipe/model_maker/python/core/utils/testdata"], + deps = [":file_util"], +) + py_library( name = "loss_functions", srcs = ["loss_functions.py"], diff --git a/mediapipe/model_maker/python/core/utils/file_util.py b/mediapipe/model_maker/python/core/utils/file_util.py new file mode 100644 index 000000000..bccf928e2 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/file_util.py @@ -0,0 +1,36 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for files.""" + +import os + +# resources dependency + + +def get_absolute_path(file_path: str) -> str: + """Gets the absolute path of a file. + + Args: + file_path: The path to a file relative to the `mediapipe` dir + + Returns: + The full path of the file + """ + # Extract the file path before mediapipe/ as the `base_dir`. By joining it + # with the `path` which defines the relative path under mediapipe/, it + # yields to the absolute path of the model files directory. + cwd = os.path.dirname(__file__) + base_dir = cwd[:cwd.rfind('mediapipe')] + absolute_path = os.path.join(base_dir, file_path) + return absolute_path diff --git a/mediapipe/model_maker/python/core/utils/file_util_test.py b/mediapipe/model_maker/python/core/utils/file_util_test.py new file mode 100644 index 000000000..4a2d6dcfb --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/file_util_test.py @@ -0,0 +1,29 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from absl.testing import absltest +from mediapipe.model_maker.python.core.utils import file_util + + +class FileUtilTest(absltest.TestCase): + + def test_get_absolute_path(self): + test_file = 'mediapipe/model_maker/python/core/utils/testdata/test.txt' + absolute_path = file_util.get_absolute_path(test_file) + self.assertTrue(os.path.exists(absolute_path)) + + +if __name__ == '__main__': + absltest.main() diff --git a/mediapipe/model_maker/python/core/utils/model_util.py b/mediapipe/model_maker/python/core/utils/model_util.py index ada0a61e3..01e301e43 100644 --- a/mediapipe/model_maker/python/core/utils/model_util.py +++ b/mediapipe/model_maker/python/core/utils/model_util.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Utilities for keras models.""" +"""Utilities for models.""" from __future__ import absolute_import from __future__ import division @@ -26,8 +26,8 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import tensorflow as tf -# resources dependency from mediapipe.model_maker.python.core.data import dataset +from mediapipe.model_maker.python.core.utils import file_util from mediapipe.model_maker.python.core.utils import quantization DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0 @@ -62,16 +62,26 @@ def load_keras_model(model_path: str, Returns: A tensorflow Keras model. """ - # Extract the file path before mediapipe/ as the `base_dir`. By joining it - # with the `model_path` which defines the relative path under mediapipe/, it - # yields to the aboslution path of the model files directory. - cwd = os.path.dirname(__file__) - base_dir = cwd[:cwd.rfind('mediapipe')] - absolute_path = os.path.join(base_dir, model_path) + absolute_path = file_util.get_absolute_path(model_path) return tf.keras.models.load_model( absolute_path, custom_objects={'tf': tf}, compile=compile_on_load) +def load_tflite_model_buffer(model_path: str) -> bytearray: + """Loads a TFLite model buffer from file. + + Args: + model_path: Relative path to a TFLite file + + Returns: + A TFLite model buffer + """ + absolute_path = file_util.get_absolute_path(model_path) + with tf.io.gfile.GFile(absolute_path, 'rb') as f: + tflite_model_buffer = f.read() + return tflite_model_buffer + + def get_steps_per_epoch(steps_per_epoch: Optional[int] = None, batch_size: Optional[int] = None, train_data: Optional[dataset.Dataset] = None) -> int: diff --git a/mediapipe/model_maker/python/core/utils/model_util_test.py b/mediapipe/model_maker/python/core/utils/model_util_test.py index 1f9e0f1db..bef9c8a97 100644 --- a/mediapipe/model_maker/python/core/utils/model_util_test.py +++ b/mediapipe/model_maker/python/core/utils/model_util_test.py @@ -24,7 +24,7 @@ from mediapipe.model_maker.python.core.utils import test_util class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): - def test_load_model(self): + def test_load_keras_model(self): input_dim = 4 model = test_util.build_model(input_shape=[input_dim], num_classes=2) saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model') @@ -36,6 +36,19 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): loaded_model_output = loaded_model.predict_on_batch(input_tensors) self.assertTrue((model_output == loaded_model_output).all()) + def test_load_tflite_model_buffer(self): + input_dim = 4 + model = test_util.build_model(input_shape=[input_dim], num_classes=2) + tflite_model = model_util.convert_to_tflite(model) + tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite') + model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file) + + tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file) + test_util.test_tflite( + keras_model=model, + tflite_model=tflite_model_buffer, + size=[1, input_dim]) + @parameterized.named_parameters( dict( testcase_name='input_only_steps_per_epoch', diff --git a/mediapipe/model_maker/python/core/utils/testdata/BUILD b/mediapipe/model_maker/python/core/utils/testdata/BUILD new file mode 100644 index 000000000..8eed72f78 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/testdata/BUILD @@ -0,0 +1,23 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + default_visibility = ["//mediapipe/model_maker/python/core/utils:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "testdata", + srcs = ["test.txt"], +) diff --git a/mediapipe/model_maker/python/core/utils/testdata/test.txt b/mediapipe/model_maker/python/core/utils/testdata/test.txt new file mode 100644 index 000000000..e69de29bb diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index e23c4a66c..55d386185 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -144,8 +144,13 @@ filegroup( ) # Gestures related models. Visible to model_maker. +# TODO: Upload canned gesture model and gesture embedding model to GCS after Model Card approval filegroup( name = "test_gesture_models", + srcs = [ + "hand_landmark_full.tflite", + "palm_detection_full.tflite", + ], visibility = [ "//mediapipe/model_maker:__subpackages__", "//mediapipe/tasks:internal",