Change gesture_recognizer to download model files on-demand from GCS.

PiperOrigin-RevId: 508424508
This commit is contained in:
MediaPipe Team 2023-02-09 11:11:04 -08:00 committed by Copybara-Service
parent c3907229fe
commit 28f728bed5
11 changed files with 89 additions and 77 deletions

View File

@ -35,17 +35,3 @@ mediapipe_files(
"palm_detection_full.tflite", "palm_detection_full.tflite",
], ],
) )
filegroup(
name = "models",
srcs = [
"canned_gesture_classifier.tflite",
"gesture_embedder.tflite",
"gesture_embedder/keras_metadata.pb",
"gesture_embedder/saved_model.pb",
"gesture_embedder/variables/variables.data-00000-of-00001",
"gesture_embedder/variables/variables.index",
"hand_landmark_full.tflite",
"palm_detection_full.tflite",
],
)

View File

@ -35,7 +35,6 @@ py_library(
name = "model_util", name = "model_util",
srcs = ["model_util.py"], srcs = ["model_util.py"],
deps = [ deps = [
":file_util",
":quantization", ":quantization",
"//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/data:dataset",
], ],
@ -45,7 +44,6 @@ py_test(
name = "model_util_test", name = "model_util_test",
srcs = ["model_util_test.py"], srcs = ["model_util_test.py"],
deps = [ deps = [
":file_util",
":model_util", ":model_util",
":quantization", ":quantization",
":test_util", ":test_util",

View File

@ -27,7 +27,6 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset from mediapipe.model_maker.python.core.data import dataset
from mediapipe.model_maker.python.core.utils import file_util
from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.core.utils import quantization
DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0 DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0
@ -53,8 +52,8 @@ def load_keras_model(model_path: str,
"""Loads a tensorflow Keras model from file and returns the Keras model. """Loads a tensorflow Keras model from file and returns the Keras model.
Args: Args:
model_path: Relative path to a directory containing model data, such as model_path: Absolute path to a directory containing model data, such as
<parent_path>/saved_model/. /<parent_path>/saved_model/.
compile_on_load: Whether the model should be compiled while loading. If compile_on_load: Whether the model should be compiled while loading. If
False, the model returned has to be compiled with the appropriate loss False, the model returned has to be compiled with the appropriate loss
function and custom metrics before running for inference on a test function and custom metrics before running for inference on a test
@ -63,22 +62,22 @@ def load_keras_model(model_path: str,
Returns: Returns:
A tensorflow Keras model. A tensorflow Keras model.
""" """
absolute_path = file_util.get_absolute_path(model_path)
return tf.keras.models.load_model( return tf.keras.models.load_model(
absolute_path, custom_objects={'tf': tf}, compile=compile_on_load) model_path, custom_objects={'tf': tf}, compile=compile_on_load
)
def load_tflite_model_buffer(model_path: str) -> bytearray: def load_tflite_model_buffer(model_path: str) -> bytearray:
"""Loads a TFLite model buffer from file. """Loads a TFLite model buffer from file.
Args: Args:
model_path: Relative path to a TFLite file model_path: Absolute path to a TFLite file, such as
/<parent_path>/<model_file>.tflite.
Returns: Returns:
A TFLite model buffer A TFLite model buffer
""" """
absolute_path = file_util.get_absolute_path(model_path) with tf.io.gfile.GFile(model_path, 'rb') as f:
with tf.io.gfile.GFile(absolute_path, 'rb') as f:
tflite_model_buffer = f.read() tflite_model_buffer = f.read()
return tflite_model_buffer return tflite_model_buffer

View File

@ -14,12 +14,10 @@
import os import os
from typing import Optional from typing import Optional
from unittest import mock as unittest_mock
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.core.utils import file_util
from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.core.utils import test_util from mediapipe.model_maker.python.core.utils import test_util
@ -27,15 +25,11 @@ from mediapipe.model_maker.python.core.utils import test_util
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
@unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True) def test_load_keras_model(self):
def test_load_keras_model(self, mock_get_absolute_path):
input_dim = 4 input_dim = 4
model = test_util.build_model(input_shape=[input_dim], num_classes=2) model = test_util.build_model(input_shape=[input_dim], num_classes=2)
saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model') saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model')
model.save(saved_model_path) model.save(saved_model_path)
# model_util.load_keras_model takes in a relative path to files within the
# model_maker dir, so we patch the function for testing
mock_get_absolute_path.return_value = saved_model_path
loaded_model = model_util.load_keras_model(saved_model_path) loaded_model = model_util.load_keras_model(saved_model_path)
input_tensors = test_util.create_random_sample(size=[1, input_dim]) input_tensors = test_util.create_random_sample(size=[1, input_dim])
@ -43,16 +37,12 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
loaded_model_output = loaded_model.predict_on_batch(input_tensors) loaded_model_output = loaded_model.predict_on_batch(input_tensors)
self.assertTrue((model_output == loaded_model_output).all()) self.assertTrue((model_output == loaded_model_output).all())
@unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True) def test_load_tflite_model_buffer(self):
def test_load_tflite_model_buffer(self, mock_get_absolute_path):
input_dim = 4 input_dim = 4
model = test_util.build_model(input_shape=[input_dim], num_classes=2) model = test_util.build_model(input_shape=[input_dim], num_classes=2)
tflite_model = model_util.convert_to_tflite(model) tflite_model = model_util.convert_to_tflite(model)
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite') tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file) model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file)
# model_util.load_tflite_model_buffer takes in a relative path to files
# within the model_maker dir, so we patch the function for testing
mock_get_absolute_path.return_value = tflite_file
tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file) tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file)
test_util.test_tflite( test_util.test_tflite(
keras_model=model, keras_model=model,

View File

@ -33,6 +33,9 @@ filegroup(
py_library( py_library(
name = "constants", name = "constants",
srcs = ["constants.py"], srcs = ["constants.py"],
deps = [
"//mediapipe/model_maker/python/core/utils:file_util",
],
) )
py_library( py_library(
@ -53,11 +56,11 @@ py_library(
py_test( py_test(
name = "dataset_test", name = "dataset_test",
srcs = ["dataset_test.py"], srcs = ["dataset_test.py"],
data = [ data = [":testdata"],
":testdata", tags = [
"//mediapipe/model_maker/models/gesture_recognizer:models", "notsan",
"requires-net:external",
], ],
tags = ["notsan"],
deps = [ deps = [
":dataset", ":dataset",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
@ -90,7 +93,6 @@ py_library(
py_library( py_library(
name = "gesture_recognizer", name = "gesture_recognizer",
srcs = ["gesture_recognizer.py"], srcs = ["gesture_recognizer.py"],
data = ["//mediapipe/model_maker/models/gesture_recognizer:models"],
deps = [ deps = [
":constants", ":constants",
":gesture_recognizer_options", ":gesture_recognizer_options",
@ -138,12 +140,12 @@ py_test(
name = "gesture_recognizer_test", name = "gesture_recognizer_test",
size = "large", size = "large",
srcs = ["gesture_recognizer_test.py"], srcs = ["gesture_recognizer_test.py"],
data = [ data = [":testdata"],
":testdata",
"//mediapipe/model_maker/models/gesture_recognizer:models",
],
shard_count = 2, shard_count = 2,
tags = ["notsan"], tags = [
"notsan",
"requires-net:external",
],
deps = [ deps = [
":gesture_recognizer_import", ":gesture_recognizer_import",
":hyperparameters", ":hyperparameters",
@ -156,9 +158,7 @@ py_test(
py_test( py_test(
name = "metadata_writer_test", name = "metadata_writer_test",
srcs = ["metadata_writer_test.py"], srcs = ["metadata_writer_test.py"],
data = [ data = [":testdata"],
":testdata",
],
deps = [ deps = [
":metadata_writer", ":metadata_writer",
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer", "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
@ -169,10 +169,8 @@ py_test(
py_binary( py_binary(
name = "gesture_recognizer_demo", name = "gesture_recognizer_demo",
srcs = ["gesture_recognizer_demo.py"], srcs = ["gesture_recognizer_demo.py"],
data = [ data = [":testdata"],
":testdata",
"//mediapipe/model_maker/models/gesture_recognizer:models",
],
python_version = "PY3", python_version = "PY3",
tags = ["requires-net:external"],
deps = [":gesture_recognizer_import"], deps = [":gesture_recognizer_import"],
) )

View File

@ -13,8 +13,26 @@
# limitations under the License. # limitations under the License.
"""Gesture recognition constants.""" """Gesture recognition constants."""
GESTURE_EMBEDDER_KERAS_MODEL_PATH = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder' from mediapipe.model_maker.python.core.utils import file_util
GESTURE_EMBEDDER_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder.tflite'
HAND_DETECTOR_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/palm_detection_full.tflite' GESTURE_EMBEDDER_KERAS_MODEL_FILES = file_util.DownloadedFiles(
HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/hand_landmark_full.tflite' 'gesture_recognizer/gesture_embedder',
CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/canned_gesture_classifier.tflite' 'https://storage.googleapis.com/mediapipe-assets/gesture_embedder.tar.gz',
is_folder=True,
)
GESTURE_EMBEDDER_TFLITE_MODEL_FILE = file_util.DownloadedFiles(
'gesture_recognizer/gesture_embedder.tflite',
'https://storage.googleapis.com/mediapipe-assets/gesture_embedder.tflite',
)
HAND_DETECTOR_TFLITE_MODEL_FILE = file_util.DownloadedFiles(
'gesture_recognizer/palm_detection_full.tflite',
'https://storage.googleapis.com/mediapipe-assets/palm_detection_full.tflite',
)
HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE = file_util.DownloadedFiles(
'gesture_recognizer/hand_landmark_full.tflite',
'https://storage.googleapis.com/mediapipe-assets/hand_landmark_full.tflite',
)
CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE = file_util.DownloadedFiles(
'gesture_recognizer/canned_gesture_classifier.tflite',
'https://storage.googleapis.com/mediapipe-assets/canned_gesture_classifier.tflite',
)

View File

@ -98,9 +98,11 @@ def _get_hand_data(all_image_paths: List[str],
""" """
hand_data_result = [] hand_data_result = []
hand_detector_model_buffer = model_util.load_tflite_model_buffer( hand_detector_model_buffer = model_util.load_tflite_model_buffer(
constants.HAND_DETECTOR_TFLITE_MODEL_FILE) constants.HAND_DETECTOR_TFLITE_MODEL_FILE.get_path()
)
hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer( hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE) constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE.get_path()
)
hand_landmarker_writer = metadata_writer.HandLandmarkerMetadataWriter( hand_landmarker_writer = metadata_writer.HandLandmarkerMetadataWriter(
hand_detector_model_buffer, hand_landmarks_detector_model_buffer) hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
hand_landmarker_options = _HandLandmarkerOptions( hand_landmarker_options = _HandLandmarkerOptions(
@ -221,7 +223,8 @@ class Dataset(classification_dataset.ClassificationDataset):
hand_ds = tf.data.Dataset.from_tensor_slices(hand_data_dict) hand_ds = tf.data.Dataset.from_tensor_slices(hand_data_dict)
embedder_model = model_util.load_keras_model( embedder_model = model_util.load_keras_model(
constants.GESTURE_EMBEDDER_KERAS_MODEL_PATH) constants.GESTURE_EMBEDDER_KERAS_MODEL_FILES.get_path()
)
hand_ds = hand_ds.batch(batch_size=1) hand_ds = hand_ds.batch(batch_size=1)
hand_embedding_ds = hand_ds.map( hand_embedding_ds = hand_ds.map(

View File

@ -14,8 +14,9 @@
import os import os
import shutil import shutil
import tempfile
from typing import NamedTuple from typing import NamedTuple
import unittest from unittest import mock as unittest_mock
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
@ -29,6 +30,19 @@ _TEST_DATA_DIRNAME = 'raw_data'
class DatasetTest(tf.test.TestCase, parameterized.TestCase): class DatasetTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
# Mock tempfile.gettempdir() to be unique for each test to avoid race
# condition when downloading model since these tests may run in parallel.
mock_gettempdir = unittest_mock.patch.object(
tempfile,
'gettempdir',
return_value=self.create_tempdir(),
autospec=True,
)
self.mock_gettempdir = mock_gettempdir.start()
self.addCleanup(mock_gettempdir.stop)
def test_split(self): def test_split(self):
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME) input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
data = dataset.Dataset.from_folder( data = dataset.Dataset.from_folder(
@ -135,8 +149,9 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
handedness=[[1]], hand_landmarks=[[2]], hand_world_landmarks=[])), handedness=[[1]], hand_landmarks=[[2]], hand_world_landmarks=[])),
) )
def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple): def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple):
with unittest.mock.patch.object( with unittest_mock.patch.object(
hand_landmarker.HandLandmarker, 'detect', return_value=hand): hand_landmarker.HandLandmarker, 'detect', return_value=hand
):
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME) input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
with self.assertRaisesRegex(ValueError, 'No valid hand is detected'): with self.assertRaisesRegex(ValueError, 'No valid hand is detected'):
dataset.Dataset.from_folder( dataset.Dataset.from_folder(

View File

@ -197,13 +197,17 @@ class GestureRecognizer(classifier.Classifier):
""" """
# TODO: Convert keras embedder model instead of using tflite # TODO: Convert keras embedder model instead of using tflite
gesture_embedding_model_buffer = model_util.load_tflite_model_buffer( gesture_embedding_model_buffer = model_util.load_tflite_model_buffer(
constants.GESTURE_EMBEDDER_TFLITE_MODEL_FILE) constants.GESTURE_EMBEDDER_TFLITE_MODEL_FILE.get_path()
)
hand_detector_model_buffer = model_util.load_tflite_model_buffer( hand_detector_model_buffer = model_util.load_tflite_model_buffer(
constants.HAND_DETECTOR_TFLITE_MODEL_FILE) constants.HAND_DETECTOR_TFLITE_MODEL_FILE.get_path()
)
hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer( hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE) constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE.get_path()
)
canned_gesture_model_buffer = model_util.load_tflite_model_buffer( canned_gesture_model_buffer = model_util.load_tflite_model_buffer(
constants.CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE) constants.CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE.get_path()
)
if not tf.io.gfile.exists(self._hparams.export_dir): if not tf.io.gfile.exists(self._hparams.export_dir):
tf.io.gfile.makedirs(self._hparams.export_dir) tf.io.gfile.makedirs(self._hparams.export_dir)

View File

@ -48,6 +48,16 @@ class GestureRecognizerTest(tf.test.TestCase):
all_data = self._load_data() all_data = self._load_data()
# Splits data, 90% data for training, 10% for validation # Splits data, 90% data for training, 10% for validation
self._train_data, self._validation_data = all_data.split(0.9) self._train_data, self._validation_data = all_data.split(0.9)
# Mock tempfile.gettempdir() to be unique for each test to avoid race
# condition when downloading model since these tests may run in parallel.
mock_gettempdir = unittest_mock.patch.object(
tempfile,
'gettempdir',
return_value=self.create_tempdir(),
autospec=True,
)
self.mock_gettempdir = mock_gettempdir.start()
self.addCleanup(mock_gettempdir.stop)
def test_gesture_recognizer_model(self): def test_gesture_recognizer_model(self):
mo = gesture_recognizer.ModelOptions() mo = gesture_recognizer.ModelOptions()

View File

@ -82,7 +82,6 @@ def _setup_build_dir():
# Use bazel to download GCS model files # Use bazel to download GCS model files
model_build_files = [ model_build_files = [
'models/gesture_recognizer/BUILD',
'models/text_classifier/BUILD', 'models/text_classifier/BUILD',
] ]
for model_build_file in model_build_files: for model_build_file in model_build_files:
@ -90,14 +89,6 @@ def _setup_build_dir():
os.makedirs(os.path.dirname(build_target_file), exist_ok=True) os.makedirs(os.path.dirname(build_target_file), exist_ok=True)
shutil.copy(model_build_file, build_target_file) shutil.copy(model_build_file, build_target_file)
external_files = [ external_files = [
'models/gesture_recognizer/canned_gesture_classifier.tflite',
'models/gesture_recognizer/gesture_embedder.tflite',
'models/gesture_recognizer/hand_landmark_full.tflite',
'models/gesture_recognizer/palm_detection_full.tflite',
'models/gesture_recognizer/gesture_embedder/keras_metadata.pb',
'models/gesture_recognizer/gesture_embedder/saved_model.pb',
'models/gesture_recognizer/gesture_embedder/variables/variables.data-00000-of-00001',
'models/gesture_recognizer/gesture_embedder/variables/variables.index',
'models/text_classifier/mobilebert_tiny/keras_metadata.pb', 'models/text_classifier/mobilebert_tiny/keras_metadata.pb',
'models/text_classifier/mobilebert_tiny/saved_model.pb', 'models/text_classifier/mobilebert_tiny/saved_model.pb',
'models/text_classifier/mobilebert_tiny/assets/vocab.txt', 'models/text_classifier/mobilebert_tiny/assets/vocab.txt',