Use model bundle writer when exporting models in gesture recognizer
PiperOrigin-RevId: 487042776
This commit is contained in:
parent
0363d60511
commit
b3d19fa1af
|
@ -35,6 +35,7 @@ py_library(
|
||||||
name = "model_util",
|
name = "model_util",
|
||||||
srcs = ["model_util.py"],
|
srcs = ["model_util.py"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":file_util",
|
||||||
":quantization",
|
":quantization",
|
||||||
"//mediapipe/model_maker/python/core/data:dataset",
|
"//mediapipe/model_maker/python/core/data:dataset",
|
||||||
],
|
],
|
||||||
|
@ -50,6 +51,18 @@ py_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "file_util",
|
||||||
|
srcs = ["file_util.py"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "file_util_test",
|
||||||
|
srcs = ["file_util_test.py"],
|
||||||
|
data = ["//mediapipe/model_maker/python/core/utils/testdata"],
|
||||||
|
deps = [":file_util"],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "loss_functions",
|
name = "loss_functions",
|
||||||
srcs = ["loss_functions.py"],
|
srcs = ["loss_functions.py"],
|
||||||
|
|
36
mediapipe/model_maker/python/core/utils/file_util.py
Normal file
36
mediapipe/model_maker/python/core/utils/file_util.py
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Utilities for files."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
# resources dependency
|
||||||
|
|
||||||
|
|
||||||
|
def get_absolute_path(file_path: str) -> str:
|
||||||
|
"""Gets the absolute path of a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: The path to a file relative to the `mediapipe` dir
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The full path of the file
|
||||||
|
"""
|
||||||
|
# Extract the file path before mediapipe/ as the `base_dir`. By joining it
|
||||||
|
# with the `path` which defines the relative path under mediapipe/, it
|
||||||
|
# yields to the absolute path of the model files directory.
|
||||||
|
cwd = os.path.dirname(__file__)
|
||||||
|
base_dir = cwd[:cwd.rfind('mediapipe')]
|
||||||
|
absolute_path = os.path.join(base_dir, file_path)
|
||||||
|
return absolute_path
|
29
mediapipe/model_maker/python/core/utils/file_util_test.py
Normal file
29
mediapipe/model_maker/python/core/utils/file_util_test.py
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
|
||||||
|
from absl.testing import absltest
|
||||||
|
from mediapipe.model_maker.python.core.utils import file_util
|
||||||
|
|
||||||
|
|
||||||
|
class FileUtilTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def test_get_absolute_path(self):
|
||||||
|
test_file = 'mediapipe/model_maker/python/core/utils/testdata/test.txt'
|
||||||
|
absolute_path = file_util.get_absolute_path(test_file)
|
||||||
|
self.assertTrue(os.path.exists(absolute_path))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
absltest.main()
|
|
@ -11,7 +11,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Utilities for keras models."""
|
"""Utilities for models."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
|
@ -26,8 +26,8 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
# resources dependency
|
|
||||||
from mediapipe.model_maker.python.core.data import dataset
|
from mediapipe.model_maker.python.core.data import dataset
|
||||||
|
from mediapipe.model_maker.python.core.utils import file_util
|
||||||
from mediapipe.model_maker.python.core.utils import quantization
|
from mediapipe.model_maker.python.core.utils import quantization
|
||||||
|
|
||||||
DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0
|
DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0
|
||||||
|
@ -62,16 +62,26 @@ def load_keras_model(model_path: str,
|
||||||
Returns:
|
Returns:
|
||||||
A tensorflow Keras model.
|
A tensorflow Keras model.
|
||||||
"""
|
"""
|
||||||
# Extract the file path before mediapipe/ as the `base_dir`. By joining it
|
absolute_path = file_util.get_absolute_path(model_path)
|
||||||
# with the `model_path` which defines the relative path under mediapipe/, it
|
|
||||||
# yields to the aboslution path of the model files directory.
|
|
||||||
cwd = os.path.dirname(__file__)
|
|
||||||
base_dir = cwd[:cwd.rfind('mediapipe')]
|
|
||||||
absolute_path = os.path.join(base_dir, model_path)
|
|
||||||
return tf.keras.models.load_model(
|
return tf.keras.models.load_model(
|
||||||
absolute_path, custom_objects={'tf': tf}, compile=compile_on_load)
|
absolute_path, custom_objects={'tf': tf}, compile=compile_on_load)
|
||||||
|
|
||||||
|
|
||||||
|
def load_tflite_model_buffer(model_path: str) -> bytearray:
|
||||||
|
"""Loads a TFLite model buffer from file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Relative path to a TFLite file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A TFLite model buffer
|
||||||
|
"""
|
||||||
|
absolute_path = file_util.get_absolute_path(model_path)
|
||||||
|
with tf.io.gfile.GFile(absolute_path, 'rb') as f:
|
||||||
|
tflite_model_buffer = f.read()
|
||||||
|
return tflite_model_buffer
|
||||||
|
|
||||||
|
|
||||||
def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
|
def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
|
||||||
batch_size: Optional[int] = None,
|
batch_size: Optional[int] = None,
|
||||||
train_data: Optional[dataset.Dataset] = None) -> int:
|
train_data: Optional[dataset.Dataset] = None) -> int:
|
||||||
|
|
|
@ -24,7 +24,7 @@ from mediapipe.model_maker.python.core.utils import test_util
|
||||||
|
|
||||||
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def test_load_model(self):
|
def test_load_keras_model(self):
|
||||||
input_dim = 4
|
input_dim = 4
|
||||||
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||||
saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model')
|
saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model')
|
||||||
|
@ -36,6 +36,19 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
loaded_model_output = loaded_model.predict_on_batch(input_tensors)
|
loaded_model_output = loaded_model.predict_on_batch(input_tensors)
|
||||||
self.assertTrue((model_output == loaded_model_output).all())
|
self.assertTrue((model_output == loaded_model_output).all())
|
||||||
|
|
||||||
|
def test_load_tflite_model_buffer(self):
|
||||||
|
input_dim = 4
|
||||||
|
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||||
|
tflite_model = model_util.convert_to_tflite(model)
|
||||||
|
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
|
||||||
|
model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file)
|
||||||
|
|
||||||
|
tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file)
|
||||||
|
test_util.test_tflite(
|
||||||
|
keras_model=model,
|
||||||
|
tflite_model=tflite_model_buffer,
|
||||||
|
size=[1, input_dim])
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
dict(
|
dict(
|
||||||
testcase_name='input_only_steps_per_epoch',
|
testcase_name='input_only_steps_per_epoch',
|
||||||
|
|
23
mediapipe/model_maker/python/core/utils/testdata/BUILD
vendored
Normal file
23
mediapipe/model_maker/python/core/utils/testdata/BUILD
vendored
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//mediapipe/model_maker/python/core/utils:__subpackages__"],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "testdata",
|
||||||
|
srcs = ["test.txt"],
|
||||||
|
)
|
0
mediapipe/model_maker/python/core/utils/testdata/test.txt
vendored
Normal file
0
mediapipe/model_maker/python/core/utils/testdata/test.txt
vendored
Normal file
5
mediapipe/tasks/testdata/vision/BUILD
vendored
5
mediapipe/tasks/testdata/vision/BUILD
vendored
|
@ -144,8 +144,13 @@ filegroup(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Gestures related models. Visible to model_maker.
|
# Gestures related models. Visible to model_maker.
|
||||||
|
# TODO: Upload canned gesture model and gesture embedding model to GCS after Model Card approval
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "test_gesture_models",
|
name = "test_gesture_models",
|
||||||
|
srcs = [
|
||||||
|
"hand_landmark_full.tflite",
|
||||||
|
"palm_detection_full.tflite",
|
||||||
|
],
|
||||||
visibility = [
|
visibility = [
|
||||||
"//mediapipe/model_maker:__subpackages__",
|
"//mediapipe/model_maker:__subpackages__",
|
||||||
"//mediapipe/tasks:internal",
|
"//mediapipe/tasks:internal",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user