From 28f728bed5e27a1ee5a01de54bc9a4f04ef7b5f8 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Thu, 9 Feb 2023 11:11:04 -0800
Subject: [PATCH] Change gesture_recognizer to download model files on-demand
 from GCS.

PiperOrigin-RevId: 508424508
---
 .../models/gesture_recognizer/BUILD           | 14 --------
 mediapipe/model_maker/python/core/utils/BUILD |  2 --
 .../python/core/utils/model_util.py           | 15 ++++-----
 .../python/core/utils/model_util_test.py      | 14 ++------
 .../python/vision/gesture_recognizer/BUILD    | 32 +++++++++----------
 .../vision/gesture_recognizer/constants.py    | 28 +++++++++++++---
 .../vision/gesture_recognizer/dataset.py      |  9 ++++--
 .../vision/gesture_recognizer/dataset_test.py | 21 ++++++++++--
 .../gesture_recognizer/gesture_recognizer.py  | 12 ++++---
 .../gesture_recognizer_test.py                | 10 ++++++
 mediapipe/model_maker/setup.py                |  9 ------
 11 files changed, 89 insertions(+), 77 deletions(-)

diff --git a/mediapipe/model_maker/models/gesture_recognizer/BUILD b/mediapipe/model_maker/models/gesture_recognizer/BUILD
index f8e5cdd21..5ead0e618 100644
--- a/mediapipe/model_maker/models/gesture_recognizer/BUILD
+++ b/mediapipe/model_maker/models/gesture_recognizer/BUILD
@@ -35,17 +35,3 @@ mediapipe_files(
         "palm_detection_full.tflite",
     ],
 )
-
-filegroup(
-    name = "models",
-    srcs = [
-        "canned_gesture_classifier.tflite",
-        "gesture_embedder.tflite",
-        "gesture_embedder/keras_metadata.pb",
-        "gesture_embedder/saved_model.pb",
-        "gesture_embedder/variables/variables.data-00000-of-00001",
-        "gesture_embedder/variables/variables.index",
-        "hand_landmark_full.tflite",
-        "palm_detection_full.tflite",
-    ],
-)
diff --git a/mediapipe/model_maker/python/core/utils/BUILD b/mediapipe/model_maker/python/core/utils/BUILD
index 3c9107dba..43c3d42f9 100644
--- a/mediapipe/model_maker/python/core/utils/BUILD
+++ b/mediapipe/model_maker/python/core/utils/BUILD
@@ -35,7 +35,6 @@ py_library(
     name = "model_util",
     srcs = ["model_util.py"],
     deps = [
-        ":file_util",
         ":quantization",
         "//mediapipe/model_maker/python/core/data:dataset",
     ],
@@ -45,7 +44,6 @@ py_test(
     name = "model_util_test",
     srcs = ["model_util_test.py"],
     deps = [
-        ":file_util",
         ":model_util",
         ":quantization",
         ":test_util",
diff --git a/mediapipe/model_maker/python/core/utils/model_util.py b/mediapipe/model_maker/python/core/utils/model_util.py
index db02444df..69a8654ec 100644
--- a/mediapipe/model_maker/python/core/utils/model_util.py
+++ b/mediapipe/model_maker/python/core/utils/model_util.py
@@ -27,7 +27,6 @@ import numpy as np
 import tensorflow as tf
 
 from mediapipe.model_maker.python.core.data import dataset
-from mediapipe.model_maker.python.core.utils import file_util
 from mediapipe.model_maker.python.core.utils import quantization
 
 DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0
@@ -53,8 +52,8 @@ def load_keras_model(model_path: str,
   """Loads a tensorflow Keras model from file and returns the Keras model.
 
   Args:
-    model_path: Relative path to a directory containing model data, such as
-      <path>/saved_model/.
+    model_path: Absolute path to a directory containing model data, such as
+      /<parent_path>/saved_model/.
     compile_on_load: Whether the model should be compiled while loading. If
       False, the model returned has to be compiled with the appropriate loss
       function and custom metrics before running for inference on a test
       dataset.
 
   Returns:
     A tensorflow Keras model.
""" - absolute_path = file_util.get_absolute_path(model_path) return tf.keras.models.load_model( - absolute_path, custom_objects={'tf': tf}, compile=compile_on_load) + model_path, custom_objects={'tf': tf}, compile=compile_on_load + ) def load_tflite_model_buffer(model_path: str) -> bytearray: """Loads a TFLite model buffer from file. Args: - model_path: Relative path to a TFLite file + model_path: Absolute path to a TFLite file, such as + //.tflite. Returns: A TFLite model buffer """ - absolute_path = file_util.get_absolute_path(model_path) - with tf.io.gfile.GFile(absolute_path, 'rb') as f: + with tf.io.gfile.GFile(model_path, 'rb') as f: tflite_model_buffer = f.read() return tflite_model_buffer diff --git a/mediapipe/model_maker/python/core/utils/model_util_test.py b/mediapipe/model_maker/python/core/utils/model_util_test.py index f0020db25..6961a5fc7 100644 --- a/mediapipe/model_maker/python/core/utils/model_util_test.py +++ b/mediapipe/model_maker/python/core/utils/model_util_test.py @@ -14,12 +14,10 @@ import os from typing import Optional -from unittest import mock as unittest_mock from absl.testing import parameterized import tensorflow as tf -from mediapipe.model_maker.python.core.utils import file_util from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.core.utils import test_util @@ -27,15 +25,11 @@ from mediapipe.model_maker.python.core.utils import test_util class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): - @unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True) - def test_load_keras_model(self, mock_get_absolute_path): + def test_load_keras_model(self): input_dim = 4 model = test_util.build_model(input_shape=[input_dim], num_classes=2) saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model') model.save(saved_model_path) - # model_util.load_keras_model takes in a relative path to files within the - # model_maker dir, so we patch the function for testing - mock_get_absolute_path.return_value = saved_model_path loaded_model = model_util.load_keras_model(saved_model_path) input_tensors = test_util.create_random_sample(size=[1, input_dim]) @@ -43,16 +37,12 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): loaded_model_output = loaded_model.predict_on_batch(input_tensors) self.assertTrue((model_output == loaded_model_output).all()) - @unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True) - def test_load_tflite_model_buffer(self, mock_get_absolute_path): + def test_load_tflite_model_buffer(self): input_dim = 4 model = test_util.build_model(input_shape=[input_dim], num_classes=2) tflite_model = model_util.convert_to_tflite(model) tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite') model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file) - # model_util.load_tflite_model_buffer takes in a relative path to files - # within the model_maker dir, so we patch the function for testing - mock_get_absolute_path.return_value = tflite_file tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file) test_util.test_tflite( keras_model=model, diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD index cbdff7cf3..2dad9a617 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD @@ -33,6 +33,9 @@ filegroup( 
 py_library(
     name = "constants",
     srcs = ["constants.py"],
+    deps = [
+        "//mediapipe/model_maker/python/core/utils:file_util",
+    ],
 )

 py_library(
@@ -53,11 +56,11 @@ py_library(
 py_test(
     name = "dataset_test",
     srcs = ["dataset_test.py"],
-    data = [
-        ":testdata",
-        "//mediapipe/model_maker/models/gesture_recognizer:models",
+    data = [":testdata"],
+    tags = [
+        "notsan",
+        "requires-net:external",
     ],
-    tags = ["notsan"],
     deps = [
         ":dataset",
         "//mediapipe/tasks/python/test:test_utils",
@@ -90,7 +93,6 @@ py_library(
 py_library(
     name = "gesture_recognizer",
     srcs = ["gesture_recognizer.py"],
-    data = ["//mediapipe/model_maker/models/gesture_recognizer:models"],
     deps = [
         ":constants",
         ":gesture_recognizer_options",
@@ -138,12 +140,12 @@ py_test(
     name = "gesture_recognizer_test",
     size = "large",
     srcs = ["gesture_recognizer_test.py"],
-    data = [
-        ":testdata",
-        "//mediapipe/model_maker/models/gesture_recognizer:models",
-    ],
+    data = [":testdata"],
     shard_count = 2,
-    tags = ["notsan"],
+    tags = [
+        "notsan",
+        "requires-net:external",
+    ],
     deps = [
         ":gesture_recognizer_import",
         ":hyperparameters",
@@ -156,9 +158,7 @@ py_test(
     name = "metadata_writer_test",
     srcs = ["metadata_writer_test.py"],
-    data = [
-        ":testdata",
-    ],
+    data = [":testdata"],
     deps = [
         ":metadata_writer",
         "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
@@ -169,10 +169,8 @@ py_test(
 py_binary(
     name = "gesture_recognizer_demo",
     srcs = ["gesture_recognizer_demo.py"],
-    data = [
-        ":testdata",
-        "//mediapipe/model_maker/models/gesture_recognizer:models",
-    ],
+    data = [":testdata"],
     python_version = "PY3",
+    tags = ["requires-net:external"],
     deps = [":gesture_recognizer_import"],
 )
diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/constants.py b/mediapipe/model_maker/python/vision/gesture_recognizer/constants.py
index ac9bba12a..acd569d0e 100644
--- a/mediapipe/model_maker/python/vision/gesture_recognizer/constants.py
+++ b/mediapipe/model_maker/python/vision/gesture_recognizer/constants.py
@@ -13,8 +13,26 @@
 # limitations under the License.

"""Gesture recognition constants.""" -GESTURE_EMBEDDER_KERAS_MODEL_PATH = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder' -GESTURE_EMBEDDER_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder.tflite' -HAND_DETECTOR_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/palm_detection_full.tflite' -HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/hand_landmark_full.tflite' -CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/canned_gesture_classifier.tflite' +from mediapipe.model_maker.python.core.utils import file_util + +GESTURE_EMBEDDER_KERAS_MODEL_FILES = file_util.DownloadedFiles( + 'gesture_recognizer/gesture_embedder', + 'https://storage.googleapis.com/mediapipe-assets/gesture_embedder.tar.gz', + is_folder=True, +) +GESTURE_EMBEDDER_TFLITE_MODEL_FILE = file_util.DownloadedFiles( + 'gesture_recognizer/gesture_embedder.tflite', + 'https://storage.googleapis.com/mediapipe-assets/gesture_embedder.tflite', +) +HAND_DETECTOR_TFLITE_MODEL_FILE = file_util.DownloadedFiles( + 'gesture_recognizer/palm_detection_full.tflite', + 'https://storage.googleapis.com/mediapipe-assets/palm_detection_full.tflite', +) +HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE = file_util.DownloadedFiles( + 'gesture_recognizer/hand_landmark_full.tflite', + 'https://storage.googleapis.com/mediapipe-assets/hand_landmark_full.tflite', +) +CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE = file_util.DownloadedFiles( + 'gesture_recognizer/canned_gesture_classifier.tflite', + 'https://storage.googleapis.com/mediapipe-assets/canned_gesture_classifier.tflite', +) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py index 6a2c878c0..70a363f1a 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py @@ -98,9 +98,11 @@ def _get_hand_data(all_image_paths: List[str], """ hand_data_result = [] hand_detector_model_buffer = model_util.load_tflite_model_buffer( - constants.HAND_DETECTOR_TFLITE_MODEL_FILE) + constants.HAND_DETECTOR_TFLITE_MODEL_FILE.get_path() + ) hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer( - constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE) + constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE.get_path() + ) hand_landmarker_writer = metadata_writer.HandLandmarkerMetadataWriter( hand_detector_model_buffer, hand_landmarks_detector_model_buffer) hand_landmarker_options = _HandLandmarkerOptions( @@ -221,7 +223,8 @@ class Dataset(classification_dataset.ClassificationDataset): hand_ds = tf.data.Dataset.from_tensor_slices(hand_data_dict) embedder_model = model_util.load_keras_model( - constants.GESTURE_EMBEDDER_KERAS_MODEL_PATH) + constants.GESTURE_EMBEDDER_KERAS_MODEL_FILES.get_path() + ) hand_ds = hand_ds.batch(batch_size=1) hand_embedding_ds = hand_ds.map( diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py index 528d02edd..e9e7ddd06 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py @@ -14,8 +14,9 @@ import os import shutil +import tempfile from typing import NamedTuple -import unittest +from unittest import mock as unittest_mock from absl.testing 
 import tensorflow as tf
@@ -29,6 +30,19 @@ _TEST_DATA_DIRNAME = 'raw_data'

 class DatasetTest(tf.test.TestCase, parameterized.TestCase):

+  def setUp(self):
+    super().setUp()
+    # Mock tempfile.gettempdir() to be unique for each test to avoid race
+    # condition when downloading model since these tests may run in parallel.
+    mock_gettempdir = unittest_mock.patch.object(
+        tempfile,
+        'gettempdir',
+        return_value=self.create_tempdir(),
+        autospec=True,
+    )
+    self.mock_gettempdir = mock_gettempdir.start()
+    self.addCleanup(mock_gettempdir.stop)
+
   def test_split(self):
     input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
     data = dataset.Dataset.from_folder(
@@ -135,8 +149,9 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
               handedness=[[1]], hand_landmarks=[[2]], hand_world_landmarks=[])),
   )
   def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple):
-    with unittest.mock.patch.object(
-        hand_landmarker.HandLandmarker, 'detect', return_value=hand):
+    with unittest_mock.patch.object(
+        hand_landmarker.HandLandmarker, 'detect', return_value=hand
+    ):
       input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
       with self.assertRaisesRegex(ValueError, 'No valid hand is detected'):
         dataset.Dataset.from_folder(
diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py
index b27f7161f..f009ef281 100644
--- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py
+++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py
@@ -197,13 +197,17 @@ class GestureRecognizer(classifier.Classifier):
     """
     # TODO: Convert keras embedder model instead of using tflite
     gesture_embedding_model_buffer = model_util.load_tflite_model_buffer(
-        constants.GESTURE_EMBEDDER_TFLITE_MODEL_FILE)
+        constants.GESTURE_EMBEDDER_TFLITE_MODEL_FILE.get_path()
+    )
     hand_detector_model_buffer = model_util.load_tflite_model_buffer(
-        constants.HAND_DETECTOR_TFLITE_MODEL_FILE)
+        constants.HAND_DETECTOR_TFLITE_MODEL_FILE.get_path()
+    )
     hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
-        constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE)
+        constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE.get_path()
+    )
     canned_gesture_model_buffer = model_util.load_tflite_model_buffer(
-        constants.CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE)
+        constants.CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE.get_path()
+    )

     if not tf.io.gfile.exists(self._hparams.export_dir):
       tf.io.gfile.makedirs(self._hparams.export_dir)
diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py
index 4fdb74225..ce167df93 100644
--- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py
+++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py
@@ -48,6 +48,16 @@ class GestureRecognizerTest(tf.test.TestCase):
     all_data = self._load_data()
     # Splits data, 90% data for training, 10% for validation
     self._train_data, self._validation_data = all_data.split(0.9)
+    # Mock tempfile.gettempdir() to be unique for each test to avoid race
+    # condition when downloading model since these tests may run in parallel.
+    mock_gettempdir = unittest_mock.patch.object(
+        tempfile,
+        'gettempdir',
+        return_value=self.create_tempdir(),
+        autospec=True,
+    )
+    self.mock_gettempdir = mock_gettempdir.start()
+    self.addCleanup(mock_gettempdir.stop)

   def test_gesture_recognizer_model(self):
     mo = gesture_recognizer.ModelOptions()
diff --git a/mediapipe/model_maker/setup.py b/mediapipe/model_maker/setup.py
index 1dac6301a..63a1f2056 100644
--- a/mediapipe/model_maker/setup.py
+++ b/mediapipe/model_maker/setup.py
@@ -82,7 +82,6 @@ def _setup_build_dir():

   # Use bazel to download GCS model files
   model_build_files = [
-      'models/gesture_recognizer/BUILD',
       'models/text_classifier/BUILD',
   ]
   for model_build_file in model_build_files:
@@ -90,14 +89,6 @@ def _setup_build_dir():
     os.makedirs(os.path.dirname(build_target_file), exist_ok=True)
     shutil.copy(model_build_file, build_target_file)
   external_files = [
-      'models/gesture_recognizer/canned_gesture_classifier.tflite',
-      'models/gesture_recognizer/gesture_embedder.tflite',
-      'models/gesture_recognizer/hand_landmark_full.tflite',
-      'models/gesture_recognizer/palm_detection_full.tflite',
-      'models/gesture_recognizer/gesture_embedder/keras_metadata.pb',
-      'models/gesture_recognizer/gesture_embedder/saved_model.pb',
-      'models/gesture_recognizer/gesture_embedder/variables/variables.data-00000-of-00001',
-      'models/gesture_recognizer/gesture_embedder/variables/variables.index',
       'models/text_classifier/mobilebert_tiny/keras_metadata.pb',
       'models/text_classifier/mobilebert_tiny/saved_model.pb',
       'models/text_classifier/mobilebert_tiny/assets/vocab.txt',
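
Note (editor's commentary, not part of the patch): the core of this change is the move from Bazel-managed model paths to file_util.DownloadedFiles objects whose get_path() lazily fetches each asset from GCS on first use. The diff does not include file_util itself, so the sketch below only illustrates the assumed contract. The class name and constructor arguments are taken from the constants.py usage above, but everything else is an assumption: the 'model_maker' cache subdirectory name, the archive layout for is_folder assets, and the cache living under tempfile.gettempdir() (suggested, though not proven, by the tests patching gettempdir to avoid download races). This is not the actual MediaPipe implementation.

import os
import tarfile
import tempfile
import urllib.request


class DownloadedFiles:
  """Illustrative stand-in for file_util.DownloadedFiles (assumed behavior)."""

  def __init__(self, path: str, url: str, is_folder: bool = False):
    # Cache-relative destination, e.g. 'gesture_recognizer/gesture_embedder.tflite'.
    self.path = path
    # Source URL, e.g. 'https://storage.googleapis.com/mediapipe-assets/...'.
    self.url = url
    # When True, `url` is assumed to point at a tar.gz archive that is
    # extracted into a directory named `path`.
    self.is_folder = is_folder

  def get_path(self) -> str:
    """Downloads the asset on first use and returns its absolute local path."""
    # Assumed cache layout: <tempdir>/model_maker/<path>.
    cache_path = os.path.join(tempfile.gettempdir(), 'model_maker', self.path)
    if os.path.exists(cache_path):
      return cache_path
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    if self.is_folder:
      # Fetch the archive to a temporary file, then unpack it. Assumes the
      # archive contents are rooted at the target directory.
      archive_path, _ = urllib.request.urlretrieve(self.url)
      with tarfile.open(archive_path) as tar:
        tar.extractall(cache_path)
      os.remove(archive_path)
    else:
      urllib.request.urlretrieve(self.url, cache_path)
    return cache_path

Under this contract, every former string constant becomes a call site such as constants.HAND_DETECTOR_TFLITE_MODEL_FILE.get_path(), which is exactly the mechanical change applied throughout dataset.py and gesture_recognizer.py above. It also explains the new "requires-net:external" test tags: the first get_path() call in a clean temp directory performs a network download.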