Use model bundle writer when exporting models in gesture recognizer

PiperOrigin-RevId: 487042776
This commit is contained in:
MediaPipe Team 2022-11-08 13:50:28 -08:00 committed by Copybara-Service
parent 0363d60511
commit b3d19fa1af
8 changed files with 138 additions and 9 deletions

View File

@ -35,6 +35,7 @@ py_library(
name = "model_util",
srcs = ["model_util.py"],
deps = [
":file_util",
":quantization",
"//mediapipe/model_maker/python/core/data:dataset",
],
@ -50,6 +51,18 @@ py_test(
],
)
py_library(
name = "file_util",
srcs = ["file_util.py"],
)
py_test(
name = "file_util_test",
srcs = ["file_util_test.py"],
data = ["//mediapipe/model_maker/python/core/utils/testdata"],
deps = [":file_util"],
)
py_library(
name = "loss_functions",
srcs = ["loss_functions.py"],

View File

@ -0,0 +1,36 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for files."""
import os
# resources dependency
def get_absolute_path(file_path: str) -> str:
"""Gets the absolute path of a file.
Args:
file_path: The path to a file relative to the `mediapipe` dir
Returns:
The full path of the file
"""
# Extract the file path before mediapipe/ as the `base_dir`. By joining it
# with the `path` which defines the relative path under mediapipe/, it
# yields to the absolute path of the model files directory.
cwd = os.path.dirname(__file__)
base_dir = cwd[:cwd.rfind('mediapipe')]
absolute_path = os.path.join(base_dir, file_path)
return absolute_path

View File

@ -0,0 +1,29 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from absl.testing import absltest
from mediapipe.model_maker.python.core.utils import file_util
class FileUtilTest(absltest.TestCase):
def test_get_absolute_path(self):
test_file = 'mediapipe/model_maker/python/core/utils/testdata/test.txt'
absolute_path = file_util.get_absolute_path(test_file)
self.assertTrue(os.path.exists(absolute_path))
if __name__ == '__main__':
absltest.main()

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for keras models."""
"""Utilities for models."""
from __future__ import absolute_import
from __future__ import division
@ -26,8 +26,8 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import tensorflow as tf
# resources dependency
from mediapipe.model_maker.python.core.data import dataset
from mediapipe.model_maker.python.core.utils import file_util
from mediapipe.model_maker.python.core.utils import quantization
DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0
@ -62,16 +62,26 @@ def load_keras_model(model_path: str,
Returns:
A tensorflow Keras model.
"""
# Extract the file path before mediapipe/ as the `base_dir`. By joining it
# with the `model_path` which defines the relative path under mediapipe/, it
# yields to the aboslution path of the model files directory.
cwd = os.path.dirname(__file__)
base_dir = cwd[:cwd.rfind('mediapipe')]
absolute_path = os.path.join(base_dir, model_path)
absolute_path = file_util.get_absolute_path(model_path)
return tf.keras.models.load_model(
absolute_path, custom_objects={'tf': tf}, compile=compile_on_load)
def load_tflite_model_buffer(model_path: str) -> bytearray:
"""Loads a TFLite model buffer from file.
Args:
model_path: Relative path to a TFLite file
Returns:
A TFLite model buffer
"""
absolute_path = file_util.get_absolute_path(model_path)
with tf.io.gfile.GFile(absolute_path, 'rb') as f:
tflite_model_buffer = f.read()
return tflite_model_buffer
def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
batch_size: Optional[int] = None,
train_data: Optional[dataset.Dataset] = None) -> int:

View File

@ -24,7 +24,7 @@ from mediapipe.model_maker.python.core.utils import test_util
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
def test_load_model(self):
def test_load_keras_model(self):
input_dim = 4
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model')
@ -36,6 +36,19 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
loaded_model_output = loaded_model.predict_on_batch(input_tensors)
self.assertTrue((model_output == loaded_model_output).all())
def test_load_tflite_model_buffer(self):
input_dim = 4
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
tflite_model = model_util.convert_to_tflite(model)
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file)
tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file)
test_util.test_tflite(
keras_model=model,
tflite_model=tflite_model_buffer,
size=[1, input_dim])
@parameterized.named_parameters(
dict(
testcase_name='input_only_steps_per_epoch',

View File

@ -0,0 +1,23 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(
default_visibility = ["//mediapipe/model_maker/python/core/utils:__subpackages__"],
licenses = ["notice"], # Apache 2.0
)
filegroup(
name = "testdata",
srcs = ["test.txt"],
)

View File

@ -144,8 +144,13 @@ filegroup(
)
# Gestures related models. Visible to model_maker.
# TODO: Upload canned gesture model and gesture embedding model to GCS after Model Card approval
filegroup(
name = "test_gesture_models",
srcs = [
"hand_landmark_full.tflite",
"palm_detection_full.tflite",
],
visibility = [
"//mediapipe/model_maker:__subpackages__",
"//mediapipe/tasks:internal",