Change gesture_recognizer to download model files on-demand from GCS.
PiperOrigin-RevId: 508424508
This commit is contained in:
parent
c3907229fe
commit
28f728bed5
|
@ -35,17 +35,3 @@ mediapipe_files(
|
|||
"palm_detection_full.tflite",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "models",
|
||||
srcs = [
|
||||
"canned_gesture_classifier.tflite",
|
||||
"gesture_embedder.tflite",
|
||||
"gesture_embedder/keras_metadata.pb",
|
||||
"gesture_embedder/saved_model.pb",
|
||||
"gesture_embedder/variables/variables.data-00000-of-00001",
|
||||
"gesture_embedder/variables/variables.index",
|
||||
"hand_landmark_full.tflite",
|
||||
"palm_detection_full.tflite",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -35,7 +35,6 @@ py_library(
|
|||
name = "model_util",
|
||||
srcs = ["model_util.py"],
|
||||
deps = [
|
||||
":file_util",
|
||||
":quantization",
|
||||
"//mediapipe/model_maker/python/core/data:dataset",
|
||||
],
|
||||
|
@ -45,7 +44,6 @@ py_test(
|
|||
name = "model_util_test",
|
||||
srcs = ["model_util_test.py"],
|
||||
deps = [
|
||||
":file_util",
|
||||
":model_util",
|
||||
":quantization",
|
||||
":test_util",
|
||||
|
|
|
@ -27,7 +27,6 @@ import numpy as np
|
|||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import dataset
|
||||
from mediapipe.model_maker.python.core.utils import file_util
|
||||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
|
||||
DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0
|
||||
|
@ -53,8 +52,8 @@ def load_keras_model(model_path: str,
|
|||
"""Loads a tensorflow Keras model from file and returns the Keras model.
|
||||
|
||||
Args:
|
||||
model_path: Relative path to a directory containing model data, such as
|
||||
<parent_path>/saved_model/.
|
||||
model_path: Absolute path to a directory containing model data, such as
|
||||
/<parent_path>/saved_model/.
|
||||
compile_on_load: Whether the model should be compiled while loading. If
|
||||
False, the model returned has to be compiled with the appropriate loss
|
||||
function and custom metrics before running for inference on a test
|
||||
|
@ -63,22 +62,22 @@ def load_keras_model(model_path: str,
|
|||
Returns:
|
||||
A tensorflow Keras model.
|
||||
"""
|
||||
absolute_path = file_util.get_absolute_path(model_path)
|
||||
return tf.keras.models.load_model(
|
||||
absolute_path, custom_objects={'tf': tf}, compile=compile_on_load)
|
||||
model_path, custom_objects={'tf': tf}, compile=compile_on_load
|
||||
)
|
||||
|
||||
|
||||
def load_tflite_model_buffer(model_path: str) -> bytearray:
|
||||
"""Loads a TFLite model buffer from file.
|
||||
|
||||
Args:
|
||||
model_path: Relative path to a TFLite file
|
||||
model_path: Absolute path to a TFLite file, such as
|
||||
/<parent_path>/<model_file>.tflite.
|
||||
|
||||
Returns:
|
||||
A TFLite model buffer
|
||||
"""
|
||||
absolute_path = file_util.get_absolute_path(model_path)
|
||||
with tf.io.gfile.GFile(absolute_path, 'rb') as f:
|
||||
with tf.io.gfile.GFile(model_path, 'rb') as f:
|
||||
tflite_model_buffer = f.read()
|
||||
return tflite_model_buffer
|
||||
|
||||
|
|
|
@ -14,12 +14,10 @@
|
|||
|
||||
import os
|
||||
from typing import Optional
|
||||
from unittest import mock as unittest_mock
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import file_util
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
from mediapipe.model_maker.python.core.utils import test_util
|
||||
|
@ -27,15 +25,11 @@ from mediapipe.model_maker.python.core.utils import test_util
|
|||
|
||||
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
@unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True)
|
||||
def test_load_keras_model(self, mock_get_absolute_path):
|
||||
def test_load_keras_model(self):
|
||||
input_dim = 4
|
||||
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||
saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model')
|
||||
model.save(saved_model_path)
|
||||
# model_util.load_keras_model takes in a relative path to files within the
|
||||
# model_maker dir, so we patch the function for testing
|
||||
mock_get_absolute_path.return_value = saved_model_path
|
||||
loaded_model = model_util.load_keras_model(saved_model_path)
|
||||
|
||||
input_tensors = test_util.create_random_sample(size=[1, input_dim])
|
||||
|
@ -43,16 +37,12 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
|||
loaded_model_output = loaded_model.predict_on_batch(input_tensors)
|
||||
self.assertTrue((model_output == loaded_model_output).all())
|
||||
|
||||
@unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True)
|
||||
def test_load_tflite_model_buffer(self, mock_get_absolute_path):
|
||||
def test_load_tflite_model_buffer(self):
|
||||
input_dim = 4
|
||||
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||
tflite_model = model_util.convert_to_tflite(model)
|
||||
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
|
||||
model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file)
|
||||
# model_util.load_tflite_model_buffer takes in a relative path to files
|
||||
# within the model_maker dir, so we patch the function for testing
|
||||
mock_get_absolute_path.return_value = tflite_file
|
||||
tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file)
|
||||
test_util.test_tflite(
|
||||
keras_model=model,
|
||||
|
|
|
@ -33,6 +33,9 @@ filegroup(
|
|||
py_library(
|
||||
name = "constants",
|
||||
srcs = ["constants.py"],
|
||||
deps = [
|
||||
"//mediapipe/model_maker/python/core/utils:file_util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
|
@ -53,11 +56,11 @@ py_library(
|
|||
py_test(
|
||||
name = "dataset_test",
|
||||
srcs = ["dataset_test.py"],
|
||||
data = [
|
||||
":testdata",
|
||||
"//mediapipe/model_maker/models/gesture_recognizer:models",
|
||||
data = [":testdata"],
|
||||
tags = [
|
||||
"notsan",
|
||||
"requires-net:external",
|
||||
],
|
||||
tags = ["notsan"],
|
||||
deps = [
|
||||
":dataset",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
|
@ -90,7 +93,6 @@ py_library(
|
|||
py_library(
|
||||
name = "gesture_recognizer",
|
||||
srcs = ["gesture_recognizer.py"],
|
||||
data = ["//mediapipe/model_maker/models/gesture_recognizer:models"],
|
||||
deps = [
|
||||
":constants",
|
||||
":gesture_recognizer_options",
|
||||
|
@ -138,12 +140,12 @@ py_test(
|
|||
name = "gesture_recognizer_test",
|
||||
size = "large",
|
||||
srcs = ["gesture_recognizer_test.py"],
|
||||
data = [
|
||||
":testdata",
|
||||
"//mediapipe/model_maker/models/gesture_recognizer:models",
|
||||
],
|
||||
data = [":testdata"],
|
||||
shard_count = 2,
|
||||
tags = ["notsan"],
|
||||
tags = [
|
||||
"notsan",
|
||||
"requires-net:external",
|
||||
],
|
||||
deps = [
|
||||
":gesture_recognizer_import",
|
||||
":hyperparameters",
|
||||
|
@ -156,9 +158,7 @@ py_test(
|
|||
py_test(
|
||||
name = "metadata_writer_test",
|
||||
srcs = ["metadata_writer_test.py"],
|
||||
data = [
|
||||
":testdata",
|
||||
],
|
||||
data = [":testdata"],
|
||||
deps = [
|
||||
":metadata_writer",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||
|
@ -169,10 +169,8 @@ py_test(
|
|||
py_binary(
|
||||
name = "gesture_recognizer_demo",
|
||||
srcs = ["gesture_recognizer_demo.py"],
|
||||
data = [
|
||||
":testdata",
|
||||
"//mediapipe/model_maker/models/gesture_recognizer:models",
|
||||
],
|
||||
data = [":testdata"],
|
||||
python_version = "PY3",
|
||||
tags = ["requires-net:external"],
|
||||
deps = [":gesture_recognizer_import"],
|
||||
)
|
||||
|
|
|
@ -13,8 +13,26 @@
|
|||
# limitations under the License.
|
||||
"""Gesture recognition constants."""
|
||||
|
||||
GESTURE_EMBEDDER_KERAS_MODEL_PATH = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder'
|
||||
GESTURE_EMBEDDER_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder.tflite'
|
||||
HAND_DETECTOR_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/palm_detection_full.tflite'
|
||||
HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/hand_landmark_full.tflite'
|
||||
CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/canned_gesture_classifier.tflite'
|
||||
from mediapipe.model_maker.python.core.utils import file_util
|
||||
|
||||
GESTURE_EMBEDDER_KERAS_MODEL_FILES = file_util.DownloadedFiles(
|
||||
'gesture_recognizer/gesture_embedder',
|
||||
'https://storage.googleapis.com/mediapipe-assets/gesture_embedder.tar.gz',
|
||||
is_folder=True,
|
||||
)
|
||||
GESTURE_EMBEDDER_TFLITE_MODEL_FILE = file_util.DownloadedFiles(
|
||||
'gesture_recognizer/gesture_embedder.tflite',
|
||||
'https://storage.googleapis.com/mediapipe-assets/gesture_embedder.tflite',
|
||||
)
|
||||
HAND_DETECTOR_TFLITE_MODEL_FILE = file_util.DownloadedFiles(
|
||||
'gesture_recognizer/palm_detection_full.tflite',
|
||||
'https://storage.googleapis.com/mediapipe-assets/palm_detection_full.tflite',
|
||||
)
|
||||
HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE = file_util.DownloadedFiles(
|
||||
'gesture_recognizer/hand_landmark_full.tflite',
|
||||
'https://storage.googleapis.com/mediapipe-assets/hand_landmark_full.tflite',
|
||||
)
|
||||
CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE = file_util.DownloadedFiles(
|
||||
'gesture_recognizer/canned_gesture_classifier.tflite',
|
||||
'https://storage.googleapis.com/mediapipe-assets/canned_gesture_classifier.tflite',
|
||||
)
|
||||
|
|
|
@ -98,9 +98,11 @@ def _get_hand_data(all_image_paths: List[str],
|
|||
"""
|
||||
hand_data_result = []
|
||||
hand_detector_model_buffer = model_util.load_tflite_model_buffer(
|
||||
constants.HAND_DETECTOR_TFLITE_MODEL_FILE)
|
||||
constants.HAND_DETECTOR_TFLITE_MODEL_FILE.get_path()
|
||||
)
|
||||
hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
|
||||
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE)
|
||||
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE.get_path()
|
||||
)
|
||||
hand_landmarker_writer = metadata_writer.HandLandmarkerMetadataWriter(
|
||||
hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
|
||||
hand_landmarker_options = _HandLandmarkerOptions(
|
||||
|
@ -221,7 +223,8 @@ class Dataset(classification_dataset.ClassificationDataset):
|
|||
hand_ds = tf.data.Dataset.from_tensor_slices(hand_data_dict)
|
||||
|
||||
embedder_model = model_util.load_keras_model(
|
||||
constants.GESTURE_EMBEDDER_KERAS_MODEL_PATH)
|
||||
constants.GESTURE_EMBEDDER_KERAS_MODEL_FILES.get_path()
|
||||
)
|
||||
|
||||
hand_ds = hand_ds.batch(batch_size=1)
|
||||
hand_embedding_ds = hand_ds.map(
|
||||
|
|
|
@ -14,8 +14,9 @@
|
|||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from typing import NamedTuple
|
||||
import unittest
|
||||
from unittest import mock as unittest_mock
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
@ -29,6 +30,19 @@ _TEST_DATA_DIRNAME = 'raw_data'
|
|||
|
||||
class DatasetTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||
# condition when downloading model since these tests may run in parallel.
|
||||
mock_gettempdir = unittest_mock.patch.object(
|
||||
tempfile,
|
||||
'gettempdir',
|
||||
return_value=self.create_tempdir(),
|
||||
autospec=True,
|
||||
)
|
||||
self.mock_gettempdir = mock_gettempdir.start()
|
||||
self.addCleanup(mock_gettempdir.stop)
|
||||
|
||||
def test_split(self):
|
||||
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
|
||||
data = dataset.Dataset.from_folder(
|
||||
|
@ -135,8 +149,9 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
|
|||
handedness=[[1]], hand_landmarks=[[2]], hand_world_landmarks=[])),
|
||||
)
|
||||
def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple):
|
||||
with unittest.mock.patch.object(
|
||||
hand_landmarker.HandLandmarker, 'detect', return_value=hand):
|
||||
with unittest_mock.patch.object(
|
||||
hand_landmarker.HandLandmarker, 'detect', return_value=hand
|
||||
):
|
||||
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
|
||||
with self.assertRaisesRegex(ValueError, 'No valid hand is detected'):
|
||||
dataset.Dataset.from_folder(
|
||||
|
|
|
@ -197,13 +197,17 @@ class GestureRecognizer(classifier.Classifier):
|
|||
"""
|
||||
# TODO: Convert keras embedder model instead of using tflite
|
||||
gesture_embedding_model_buffer = model_util.load_tflite_model_buffer(
|
||||
constants.GESTURE_EMBEDDER_TFLITE_MODEL_FILE)
|
||||
constants.GESTURE_EMBEDDER_TFLITE_MODEL_FILE.get_path()
|
||||
)
|
||||
hand_detector_model_buffer = model_util.load_tflite_model_buffer(
|
||||
constants.HAND_DETECTOR_TFLITE_MODEL_FILE)
|
||||
constants.HAND_DETECTOR_TFLITE_MODEL_FILE.get_path()
|
||||
)
|
||||
hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
|
||||
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE)
|
||||
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE.get_path()
|
||||
)
|
||||
canned_gesture_model_buffer = model_util.load_tflite_model_buffer(
|
||||
constants.CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE)
|
||||
constants.CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE.get_path()
|
||||
)
|
||||
|
||||
if not tf.io.gfile.exists(self._hparams.export_dir):
|
||||
tf.io.gfile.makedirs(self._hparams.export_dir)
|
||||
|
|
|
@ -48,6 +48,16 @@ class GestureRecognizerTest(tf.test.TestCase):
|
|||
all_data = self._load_data()
|
||||
# Splits data, 90% data for training, 10% for validation
|
||||
self._train_data, self._validation_data = all_data.split(0.9)
|
||||
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||
# condition when downloading model since these tests may run in parallel.
|
||||
mock_gettempdir = unittest_mock.patch.object(
|
||||
tempfile,
|
||||
'gettempdir',
|
||||
return_value=self.create_tempdir(),
|
||||
autospec=True,
|
||||
)
|
||||
self.mock_gettempdir = mock_gettempdir.start()
|
||||
self.addCleanup(mock_gettempdir.stop)
|
||||
|
||||
def test_gesture_recognizer_model(self):
|
||||
mo = gesture_recognizer.ModelOptions()
|
||||
|
|
|
@ -82,7 +82,6 @@ def _setup_build_dir():
|
|||
|
||||
# Use bazel to download GCS model files
|
||||
model_build_files = [
|
||||
'models/gesture_recognizer/BUILD',
|
||||
'models/text_classifier/BUILD',
|
||||
]
|
||||
for model_build_file in model_build_files:
|
||||
|
@ -90,14 +89,6 @@ def _setup_build_dir():
|
|||
os.makedirs(os.path.dirname(build_target_file), exist_ok=True)
|
||||
shutil.copy(model_build_file, build_target_file)
|
||||
external_files = [
|
||||
'models/gesture_recognizer/canned_gesture_classifier.tflite',
|
||||
'models/gesture_recognizer/gesture_embedder.tflite',
|
||||
'models/gesture_recognizer/hand_landmark_full.tflite',
|
||||
'models/gesture_recognizer/palm_detection_full.tflite',
|
||||
'models/gesture_recognizer/gesture_embedder/keras_metadata.pb',
|
||||
'models/gesture_recognizer/gesture_embedder/saved_model.pb',
|
||||
'models/gesture_recognizer/gesture_embedder/variables/variables.data-00000-of-00001',
|
||||
'models/gesture_recognizer/gesture_embedder/variables/variables.index',
|
||||
'models/text_classifier/mobilebert_tiny/keras_metadata.pb',
|
||||
'models/text_classifier/mobilebert_tiny/saved_model.pb',
|
||||
'models/text_classifier/mobilebert_tiny/assets/vocab.txt',
|
||||
|
|
Loading…
Reference in New Issue
Block a user