Use model bundle writer when exporting models in gesture recognizer
PiperOrigin-RevId: 487042776
This commit is contained in:
parent
0363d60511
commit
b3d19fa1af
|
@ -35,6 +35,7 @@ py_library(
|
|||
name = "model_util",
|
||||
srcs = ["model_util.py"],
|
||||
deps = [
|
||||
":file_util",
|
||||
":quantization",
|
||||
"//mediapipe/model_maker/python/core/data:dataset",
|
||||
],
|
||||
|
@ -50,6 +51,18 @@ py_test(
|
|||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "file_util",
|
||||
srcs = ["file_util.py"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "file_util_test",
|
||||
srcs = ["file_util_test.py"],
|
||||
data = ["//mediapipe/model_maker/python/core/utils/testdata"],
|
||||
deps = [":file_util"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "loss_functions",
|
||||
srcs = ["loss_functions.py"],
|
||||
|
|
36
mediapipe/model_maker/python/core/utils/file_util.py
Normal file
36
mediapipe/model_maker/python/core/utils/file_util.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utilities for files."""
|
||||
|
||||
import os
|
||||
|
||||
# resources dependency
|
||||
|
||||
|
||||
def get_absolute_path(file_path: str) -> str:
|
||||
"""Gets the absolute path of a file.
|
||||
|
||||
Args:
|
||||
file_path: The path to a file relative to the `mediapipe` dir
|
||||
|
||||
Returns:
|
||||
The full path of the file
|
||||
"""
|
||||
# Extract the file path before mediapipe/ as the `base_dir`. By joining it
|
||||
# with the `path` which defines the relative path under mediapipe/, it
|
||||
# yields to the absolute path of the model files directory.
|
||||
cwd = os.path.dirname(__file__)
|
||||
base_dir = cwd[:cwd.rfind('mediapipe')]
|
||||
absolute_path = os.path.join(base_dir, file_path)
|
||||
return absolute_path
|
29
mediapipe/model_maker/python/core/utils/file_util_test.py
Normal file
29
mediapipe/model_maker/python/core/utils/file_util_test.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
|
||||
from absl.testing import absltest
|
||||
from mediapipe.model_maker.python.core.utils import file_util
|
||||
|
||||
|
||||
class FileUtilTest(absltest.TestCase):
|
||||
|
||||
def test_get_absolute_path(self):
|
||||
test_file = 'mediapipe/model_maker/python/core/utils/testdata/test.txt'
|
||||
absolute_path = file_util.get_absolute_path(test_file)
|
||||
self.assertTrue(os.path.exists(absolute_path))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
|
@ -11,7 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utilities for keras models."""
|
||||
"""Utilities for models."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -26,8 +26,8 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
|||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
# resources dependency
|
||||
from mediapipe.model_maker.python.core.data import dataset
|
||||
from mediapipe.model_maker.python.core.utils import file_util
|
||||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
|
||||
DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0
|
||||
|
@ -62,16 +62,26 @@ def load_keras_model(model_path: str,
|
|||
Returns:
|
||||
A tensorflow Keras model.
|
||||
"""
|
||||
# Extract the file path before mediapipe/ as the `base_dir`. By joining it
|
||||
# with the `model_path` which defines the relative path under mediapipe/, it
|
||||
# yields to the aboslution path of the model files directory.
|
||||
cwd = os.path.dirname(__file__)
|
||||
base_dir = cwd[:cwd.rfind('mediapipe')]
|
||||
absolute_path = os.path.join(base_dir, model_path)
|
||||
absolute_path = file_util.get_absolute_path(model_path)
|
||||
return tf.keras.models.load_model(
|
||||
absolute_path, custom_objects={'tf': tf}, compile=compile_on_load)
|
||||
|
||||
|
||||
def load_tflite_model_buffer(model_path: str) -> bytearray:
|
||||
"""Loads a TFLite model buffer from file.
|
||||
|
||||
Args:
|
||||
model_path: Relative path to a TFLite file
|
||||
|
||||
Returns:
|
||||
A TFLite model buffer
|
||||
"""
|
||||
absolute_path = file_util.get_absolute_path(model_path)
|
||||
with tf.io.gfile.GFile(absolute_path, 'rb') as f:
|
||||
tflite_model_buffer = f.read()
|
||||
return tflite_model_buffer
|
||||
|
||||
|
||||
def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
train_data: Optional[dataset.Dataset] = None) -> int:
|
||||
|
|
|
@ -24,7 +24,7 @@ from mediapipe.model_maker.python.core.utils import test_util
|
|||
|
||||
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_load_model(self):
|
||||
def test_load_keras_model(self):
|
||||
input_dim = 4
|
||||
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||
saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model')
|
||||
|
@ -36,6 +36,19 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
|||
loaded_model_output = loaded_model.predict_on_batch(input_tensors)
|
||||
self.assertTrue((model_output == loaded_model_output).all())
|
||||
|
||||
def test_load_tflite_model_buffer(self):
|
||||
input_dim = 4
|
||||
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||
tflite_model = model_util.convert_to_tflite(model)
|
||||
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
|
||||
model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file)
|
||||
|
||||
tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file)
|
||||
test_util.test_tflite(
|
||||
keras_model=model,
|
||||
tflite_model=tflite_model_buffer,
|
||||
size=[1, input_dim])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name='input_only_steps_per_epoch',
|
||||
|
|
23
mediapipe/model_maker/python/core/utils/testdata/BUILD
vendored
Normal file
23
mediapipe/model_maker/python/core/utils/testdata/BUILD
vendored
Normal file
|
@ -0,0 +1,23 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe/model_maker/python/core/utils:__subpackages__"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "testdata",
|
||||
srcs = ["test.txt"],
|
||||
)
|
0
mediapipe/model_maker/python/core/utils/testdata/test.txt
vendored
Normal file
0
mediapipe/model_maker/python/core/utils/testdata/test.txt
vendored
Normal file
5
mediapipe/tasks/testdata/vision/BUILD
vendored
5
mediapipe/tasks/testdata/vision/BUILD
vendored
|
@ -144,8 +144,13 @@ filegroup(
|
|||
)
|
||||
|
||||
# Gestures related models. Visible to model_maker.
|
||||
# TODO: Upload canned gesture model and gesture embedding model to GCS after Model Card approval
|
||||
filegroup(
|
||||
name = "test_gesture_models",
|
||||
srcs = [
|
||||
"hand_landmark_full.tflite",
|
||||
"palm_detection_full.tflite",
|
||||
],
|
||||
visibility = [
|
||||
"//mediapipe/model_maker:__subpackages__",
|
||||
"//mediapipe/tasks:internal",
|
||||
|
|
Loading…
Reference in New Issue
Block a user