Change gesture_recognizer to download model files on-demand from GCS.

PiperOrigin-RevId: 508424508
This commit is contained in:
MediaPipe Team 2023-02-09 11:11:04 -08:00 committed by Copybara-Service
parent c3907229fe
commit 28f728bed5
11 changed files with 89 additions and 77 deletions

View File

@ -35,17 +35,3 @@ mediapipe_files(
"palm_detection_full.tflite",
],
)
filegroup(
name = "models",
srcs = [
"canned_gesture_classifier.tflite",
"gesture_embedder.tflite",
"gesture_embedder/keras_metadata.pb",
"gesture_embedder/saved_model.pb",
"gesture_embedder/variables/variables.data-00000-of-00001",
"gesture_embedder/variables/variables.index",
"hand_landmark_full.tflite",
"palm_detection_full.tflite",
],
)

View File

@ -35,7 +35,6 @@ py_library(
name = "model_util",
srcs = ["model_util.py"],
deps = [
":file_util",
":quantization",
"//mediapipe/model_maker/python/core/data:dataset",
],
@ -45,7 +44,6 @@ py_test(
name = "model_util_test",
srcs = ["model_util_test.py"],
deps = [
":file_util",
":model_util",
":quantization",
":test_util",

View File

@ -27,7 +27,6 @@ import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset
from mediapipe.model_maker.python.core.utils import file_util
from mediapipe.model_maker.python.core.utils import quantization
DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0
@ -53,8 +52,8 @@ def load_keras_model(model_path: str,
"""Loads a tensorflow Keras model from file and returns the Keras model.
Args:
model_path: Relative path to a directory containing model data, such as
<parent_path>/saved_model/.
model_path: Absolute path to a directory containing model data, such as
/<parent_path>/saved_model/.
compile_on_load: Whether the model should be compiled while loading. If
False, the model returned has to be compiled with the appropriate loss
function and custom metrics before running for inference on a test
@ -63,22 +62,22 @@ def load_keras_model(model_path: str,
Returns:
A tensorflow Keras model.
"""
absolute_path = file_util.get_absolute_path(model_path)
return tf.keras.models.load_model(
absolute_path, custom_objects={'tf': tf}, compile=compile_on_load)
model_path, custom_objects={'tf': tf}, compile=compile_on_load
)
def load_tflite_model_buffer(model_path: str) -> bytearray:
"""Loads a TFLite model buffer from file.
Args:
model_path: Relative path to a TFLite file
model_path: Absolute path to a TFLite file, such as
/<parent_path>/<model_file>.tflite.
Returns:
A TFLite model buffer
"""
absolute_path = file_util.get_absolute_path(model_path)
with tf.io.gfile.GFile(absolute_path, 'rb') as f:
with tf.io.gfile.GFile(model_path, 'rb') as f:
tflite_model_buffer = f.read()
return tflite_model_buffer

View File

@ -14,12 +14,10 @@
import os
from typing import Optional
from unittest import mock as unittest_mock
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import file_util
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.core.utils import test_util
@ -27,15 +25,11 @@ from mediapipe.model_maker.python.core.utils import test_util
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
@unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True)
def test_load_keras_model(self, mock_get_absolute_path):
def test_load_keras_model(self):
input_dim = 4
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model')
model.save(saved_model_path)
# model_util.load_keras_model takes in a relative path to files within the
# model_maker dir, so we patch the function for testing
mock_get_absolute_path.return_value = saved_model_path
loaded_model = model_util.load_keras_model(saved_model_path)
input_tensors = test_util.create_random_sample(size=[1, input_dim])
@ -43,16 +37,12 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
loaded_model_output = loaded_model.predict_on_batch(input_tensors)
self.assertTrue((model_output == loaded_model_output).all())
@unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True)
def test_load_tflite_model_buffer(self, mock_get_absolute_path):
def test_load_tflite_model_buffer(self):
input_dim = 4
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
tflite_model = model_util.convert_to_tflite(model)
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file)
# model_util.load_tflite_model_buffer takes in a relative path to files
# within the model_maker dir, so we patch the function for testing
mock_get_absolute_path.return_value = tflite_file
tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file)
test_util.test_tflite(
keras_model=model,

View File

@ -33,6 +33,9 @@ filegroup(
py_library(
name = "constants",
srcs = ["constants.py"],
deps = [
"//mediapipe/model_maker/python/core/utils:file_util",
],
)
py_library(
@ -53,11 +56,11 @@ py_library(
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
data = [
":testdata",
"//mediapipe/model_maker/models/gesture_recognizer:models",
data = [":testdata"],
tags = [
"notsan",
"requires-net:external",
],
tags = ["notsan"],
deps = [
":dataset",
"//mediapipe/tasks/python/test:test_utils",
@ -90,7 +93,6 @@ py_library(
py_library(
name = "gesture_recognizer",
srcs = ["gesture_recognizer.py"],
data = ["//mediapipe/model_maker/models/gesture_recognizer:models"],
deps = [
":constants",
":gesture_recognizer_options",
@ -138,12 +140,12 @@ py_test(
name = "gesture_recognizer_test",
size = "large",
srcs = ["gesture_recognizer_test.py"],
data = [
":testdata",
"//mediapipe/model_maker/models/gesture_recognizer:models",
],
data = [":testdata"],
shard_count = 2,
tags = ["notsan"],
tags = [
"notsan",
"requires-net:external",
],
deps = [
":gesture_recognizer_import",
":hyperparameters",
@ -156,9 +158,7 @@ py_test(
py_test(
name = "metadata_writer_test",
srcs = ["metadata_writer_test.py"],
data = [
":testdata",
],
data = [":testdata"],
deps = [
":metadata_writer",
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
@ -169,10 +169,8 @@ py_test(
py_binary(
name = "gesture_recognizer_demo",
srcs = ["gesture_recognizer_demo.py"],
data = [
":testdata",
"//mediapipe/model_maker/models/gesture_recognizer:models",
],
data = [":testdata"],
python_version = "PY3",
tags = ["requires-net:external"],
deps = [":gesture_recognizer_import"],
)

View File

@ -13,8 +13,26 @@
# limitations under the License.
"""Gesture recognition constants."""
GESTURE_EMBEDDER_KERAS_MODEL_PATH = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder'
GESTURE_EMBEDDER_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder.tflite'
HAND_DETECTOR_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/palm_detection_full.tflite'
HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/hand_landmark_full.tflite'
CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/canned_gesture_classifier.tflite'
from mediapipe.model_maker.python.core.utils import file_util
GESTURE_EMBEDDER_KERAS_MODEL_FILES = file_util.DownloadedFiles(
'gesture_recognizer/gesture_embedder',
'https://storage.googleapis.com/mediapipe-assets/gesture_embedder.tar.gz',
is_folder=True,
)
GESTURE_EMBEDDER_TFLITE_MODEL_FILE = file_util.DownloadedFiles(
'gesture_recognizer/gesture_embedder.tflite',
'https://storage.googleapis.com/mediapipe-assets/gesture_embedder.tflite',
)
HAND_DETECTOR_TFLITE_MODEL_FILE = file_util.DownloadedFiles(
'gesture_recognizer/palm_detection_full.tflite',
'https://storage.googleapis.com/mediapipe-assets/palm_detection_full.tflite',
)
HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE = file_util.DownloadedFiles(
'gesture_recognizer/hand_landmark_full.tflite',
'https://storage.googleapis.com/mediapipe-assets/hand_landmark_full.tflite',
)
CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE = file_util.DownloadedFiles(
'gesture_recognizer/canned_gesture_classifier.tflite',
'https://storage.googleapis.com/mediapipe-assets/canned_gesture_classifier.tflite',
)

View File

@ -98,9 +98,11 @@ def _get_hand_data(all_image_paths: List[str],
"""
hand_data_result = []
hand_detector_model_buffer = model_util.load_tflite_model_buffer(
constants.HAND_DETECTOR_TFLITE_MODEL_FILE)
constants.HAND_DETECTOR_TFLITE_MODEL_FILE.get_path()
)
hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE)
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE.get_path()
)
hand_landmarker_writer = metadata_writer.HandLandmarkerMetadataWriter(
hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
hand_landmarker_options = _HandLandmarkerOptions(
@ -221,7 +223,8 @@ class Dataset(classification_dataset.ClassificationDataset):
hand_ds = tf.data.Dataset.from_tensor_slices(hand_data_dict)
embedder_model = model_util.load_keras_model(
constants.GESTURE_EMBEDDER_KERAS_MODEL_PATH)
constants.GESTURE_EMBEDDER_KERAS_MODEL_FILES.get_path()
)
hand_ds = hand_ds.batch(batch_size=1)
hand_embedding_ds = hand_ds.map(

View File

@ -14,8 +14,9 @@
import os
import shutil
import tempfile
from typing import NamedTuple
import unittest
from unittest import mock as unittest_mock
from absl.testing import parameterized
import tensorflow as tf
@ -29,6 +30,19 @@ _TEST_DATA_DIRNAME = 'raw_data'
class DatasetTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
# Mock tempfile.gettempdir() to be unique for each test to avoid race
# condition when downloading model since these tests may run in parallel.
mock_gettempdir = unittest_mock.patch.object(
tempfile,
'gettempdir',
return_value=self.create_tempdir(),
autospec=True,
)
self.mock_gettempdir = mock_gettempdir.start()
self.addCleanup(mock_gettempdir.stop)
def test_split(self):
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
data = dataset.Dataset.from_folder(
@ -135,8 +149,9 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
handedness=[[1]], hand_landmarks=[[2]], hand_world_landmarks=[])),
)
def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple):
with unittest.mock.patch.object(
hand_landmarker.HandLandmarker, 'detect', return_value=hand):
with unittest_mock.patch.object(
hand_landmarker.HandLandmarker, 'detect', return_value=hand
):
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
with self.assertRaisesRegex(ValueError, 'No valid hand is detected'):
dataset.Dataset.from_folder(

View File

@ -197,13 +197,17 @@ class GestureRecognizer(classifier.Classifier):
"""
# TODO: Convert keras embedder model instead of using tflite
gesture_embedding_model_buffer = model_util.load_tflite_model_buffer(
constants.GESTURE_EMBEDDER_TFLITE_MODEL_FILE)
constants.GESTURE_EMBEDDER_TFLITE_MODEL_FILE.get_path()
)
hand_detector_model_buffer = model_util.load_tflite_model_buffer(
constants.HAND_DETECTOR_TFLITE_MODEL_FILE)
constants.HAND_DETECTOR_TFLITE_MODEL_FILE.get_path()
)
hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE)
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE.get_path()
)
canned_gesture_model_buffer = model_util.load_tflite_model_buffer(
constants.CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE)
constants.CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE.get_path()
)
if not tf.io.gfile.exists(self._hparams.export_dir):
tf.io.gfile.makedirs(self._hparams.export_dir)

View File

@ -48,6 +48,16 @@ class GestureRecognizerTest(tf.test.TestCase):
all_data = self._load_data()
# Splits data, 90% data for training, 10% for validation
self._train_data, self._validation_data = all_data.split(0.9)
# Mock tempfile.gettempdir() to be unique for each test to avoid race
# condition when downloading model since these tests may run in parallel.
mock_gettempdir = unittest_mock.patch.object(
tempfile,
'gettempdir',
return_value=self.create_tempdir(),
autospec=True,
)
self.mock_gettempdir = mock_gettempdir.start()
self.addCleanup(mock_gettempdir.stop)
def test_gesture_recognizer_model(self):
mo = gesture_recognizer.ModelOptions()

View File

@ -82,7 +82,6 @@ def _setup_build_dir():
# Use bazel to download GCS model files
model_build_files = [
'models/gesture_recognizer/BUILD',
'models/text_classifier/BUILD',
]
for model_build_file in model_build_files:
@ -90,14 +89,6 @@ def _setup_build_dir():
os.makedirs(os.path.dirname(build_target_file), exist_ok=True)
shutil.copy(model_build_file, build_target_file)
external_files = [
'models/gesture_recognizer/canned_gesture_classifier.tflite',
'models/gesture_recognizer/gesture_embedder.tflite',
'models/gesture_recognizer/hand_landmark_full.tflite',
'models/gesture_recognizer/palm_detection_full.tflite',
'models/gesture_recognizer/gesture_embedder/keras_metadata.pb',
'models/gesture_recognizer/gesture_embedder/saved_model.pb',
'models/gesture_recognizer/gesture_embedder/variables/variables.data-00000-of-00001',
'models/gesture_recognizer/gesture_embedder/variables/variables.index',
'models/text_classifier/mobilebert_tiny/keras_metadata.pb',
'models/text_classifier/mobilebert_tiny/saved_model.pb',
'models/text_classifier/mobilebert_tiny/assets/vocab.txt',