Change gesture_recognizer to download model files on-demand from GCS.
PiperOrigin-RevId: 508424508
This commit is contained in:
parent
c3907229fe
commit
28f728bed5
|
@ -35,17 +35,3 @@ mediapipe_files(
|
||||||
"palm_detection_full.tflite",
|
"palm_detection_full.tflite",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
filegroup(
|
|
||||||
name = "models",
|
|
||||||
srcs = [
|
|
||||||
"canned_gesture_classifier.tflite",
|
|
||||||
"gesture_embedder.tflite",
|
|
||||||
"gesture_embedder/keras_metadata.pb",
|
|
||||||
"gesture_embedder/saved_model.pb",
|
|
||||||
"gesture_embedder/variables/variables.data-00000-of-00001",
|
|
||||||
"gesture_embedder/variables/variables.index",
|
|
||||||
"hand_landmark_full.tflite",
|
|
||||||
"palm_detection_full.tflite",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
|
@ -35,7 +35,6 @@ py_library(
|
||||||
name = "model_util",
|
name = "model_util",
|
||||||
srcs = ["model_util.py"],
|
srcs = ["model_util.py"],
|
||||||
deps = [
|
deps = [
|
||||||
":file_util",
|
|
||||||
":quantization",
|
":quantization",
|
||||||
"//mediapipe/model_maker/python/core/data:dataset",
|
"//mediapipe/model_maker/python/core/data:dataset",
|
||||||
],
|
],
|
||||||
|
@ -45,7 +44,6 @@ py_test(
|
||||||
name = "model_util_test",
|
name = "model_util_test",
|
||||||
srcs = ["model_util_test.py"],
|
srcs = ["model_util_test.py"],
|
||||||
deps = [
|
deps = [
|
||||||
":file_util",
|
|
||||||
":model_util",
|
":model_util",
|
||||||
":quantization",
|
":quantization",
|
||||||
":test_util",
|
":test_util",
|
||||||
|
|
|
@ -27,7 +27,6 @@ import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core.data import dataset
|
from mediapipe.model_maker.python.core.data import dataset
|
||||||
from mediapipe.model_maker.python.core.utils import file_util
|
|
||||||
from mediapipe.model_maker.python.core.utils import quantization
|
from mediapipe.model_maker.python.core.utils import quantization
|
||||||
|
|
||||||
DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0
|
DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0
|
||||||
|
@ -53,8 +52,8 @@ def load_keras_model(model_path: str,
|
||||||
"""Loads a tensorflow Keras model from file and returns the Keras model.
|
"""Loads a tensorflow Keras model from file and returns the Keras model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_path: Relative path to a directory containing model data, such as
|
model_path: Absolute path to a directory containing model data, such as
|
||||||
<parent_path>/saved_model/.
|
/<parent_path>/saved_model/.
|
||||||
compile_on_load: Whether the model should be compiled while loading. If
|
compile_on_load: Whether the model should be compiled while loading. If
|
||||||
False, the model returned has to be compiled with the appropriate loss
|
False, the model returned has to be compiled with the appropriate loss
|
||||||
function and custom metrics before running for inference on a test
|
function and custom metrics before running for inference on a test
|
||||||
|
@ -63,22 +62,22 @@ def load_keras_model(model_path: str,
|
||||||
Returns:
|
Returns:
|
||||||
A tensorflow Keras model.
|
A tensorflow Keras model.
|
||||||
"""
|
"""
|
||||||
absolute_path = file_util.get_absolute_path(model_path)
|
|
||||||
return tf.keras.models.load_model(
|
return tf.keras.models.load_model(
|
||||||
absolute_path, custom_objects={'tf': tf}, compile=compile_on_load)
|
model_path, custom_objects={'tf': tf}, compile=compile_on_load
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_tflite_model_buffer(model_path: str) -> bytearray:
|
def load_tflite_model_buffer(model_path: str) -> bytearray:
|
||||||
"""Loads a TFLite model buffer from file.
|
"""Loads a TFLite model buffer from file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_path: Relative path to a TFLite file
|
model_path: Absolute path to a TFLite file, such as
|
||||||
|
/<parent_path>/<model_file>.tflite.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A TFLite model buffer
|
A TFLite model buffer
|
||||||
"""
|
"""
|
||||||
absolute_path = file_util.get_absolute_path(model_path)
|
with tf.io.gfile.GFile(model_path, 'rb') as f:
|
||||||
with tf.io.gfile.GFile(absolute_path, 'rb') as f:
|
|
||||||
tflite_model_buffer = f.read()
|
tflite_model_buffer = f.read()
|
||||||
return tflite_model_buffer
|
return tflite_model_buffer
|
||||||
|
|
||||||
|
|
|
@ -14,12 +14,10 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from unittest import mock as unittest_mock
|
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core.utils import file_util
|
|
||||||
from mediapipe.model_maker.python.core.utils import model_util
|
from mediapipe.model_maker.python.core.utils import model_util
|
||||||
from mediapipe.model_maker.python.core.utils import quantization
|
from mediapipe.model_maker.python.core.utils import quantization
|
||||||
from mediapipe.model_maker.python.core.utils import test_util
|
from mediapipe.model_maker.python.core.utils import test_util
|
||||||
|
@ -27,15 +25,11 @@ from mediapipe.model_maker.python.core.utils import test_util
|
||||||
|
|
||||||
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
@unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True)
|
def test_load_keras_model(self):
|
||||||
def test_load_keras_model(self, mock_get_absolute_path):
|
|
||||||
input_dim = 4
|
input_dim = 4
|
||||||
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||||
saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model')
|
saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model')
|
||||||
model.save(saved_model_path)
|
model.save(saved_model_path)
|
||||||
# model_util.load_keras_model takes in a relative path to files within the
|
|
||||||
# model_maker dir, so we patch the function for testing
|
|
||||||
mock_get_absolute_path.return_value = saved_model_path
|
|
||||||
loaded_model = model_util.load_keras_model(saved_model_path)
|
loaded_model = model_util.load_keras_model(saved_model_path)
|
||||||
|
|
||||||
input_tensors = test_util.create_random_sample(size=[1, input_dim])
|
input_tensors = test_util.create_random_sample(size=[1, input_dim])
|
||||||
|
@ -43,16 +37,12 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
loaded_model_output = loaded_model.predict_on_batch(input_tensors)
|
loaded_model_output = loaded_model.predict_on_batch(input_tensors)
|
||||||
self.assertTrue((model_output == loaded_model_output).all())
|
self.assertTrue((model_output == loaded_model_output).all())
|
||||||
|
|
||||||
@unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True)
|
def test_load_tflite_model_buffer(self):
|
||||||
def test_load_tflite_model_buffer(self, mock_get_absolute_path):
|
|
||||||
input_dim = 4
|
input_dim = 4
|
||||||
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||||
tflite_model = model_util.convert_to_tflite(model)
|
tflite_model = model_util.convert_to_tflite(model)
|
||||||
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
|
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
|
||||||
model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file)
|
model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file)
|
||||||
# model_util.load_tflite_model_buffer takes in a relative path to files
|
|
||||||
# within the model_maker dir, so we patch the function for testing
|
|
||||||
mock_get_absolute_path.return_value = tflite_file
|
|
||||||
tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file)
|
tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file)
|
||||||
test_util.test_tflite(
|
test_util.test_tflite(
|
||||||
keras_model=model,
|
keras_model=model,
|
||||||
|
|
|
@ -33,6 +33,9 @@ filegroup(
|
||||||
py_library(
|
py_library(
|
||||||
name = "constants",
|
name = "constants",
|
||||||
srcs = ["constants.py"],
|
srcs = ["constants.py"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/model_maker/python/core/utils:file_util",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
|
@ -53,11 +56,11 @@ py_library(
|
||||||
py_test(
|
py_test(
|
||||||
name = "dataset_test",
|
name = "dataset_test",
|
||||||
srcs = ["dataset_test.py"],
|
srcs = ["dataset_test.py"],
|
||||||
data = [
|
data = [":testdata"],
|
||||||
":testdata",
|
tags = [
|
||||||
"//mediapipe/model_maker/models/gesture_recognizer:models",
|
"notsan",
|
||||||
|
"requires-net:external",
|
||||||
],
|
],
|
||||||
tags = ["notsan"],
|
|
||||||
deps = [
|
deps = [
|
||||||
":dataset",
|
":dataset",
|
||||||
"//mediapipe/tasks/python/test:test_utils",
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
|
@ -90,7 +93,6 @@ py_library(
|
||||||
py_library(
|
py_library(
|
||||||
name = "gesture_recognizer",
|
name = "gesture_recognizer",
|
||||||
srcs = ["gesture_recognizer.py"],
|
srcs = ["gesture_recognizer.py"],
|
||||||
data = ["//mediapipe/model_maker/models/gesture_recognizer:models"],
|
|
||||||
deps = [
|
deps = [
|
||||||
":constants",
|
":constants",
|
||||||
":gesture_recognizer_options",
|
":gesture_recognizer_options",
|
||||||
|
@ -138,12 +140,12 @@ py_test(
|
||||||
name = "gesture_recognizer_test",
|
name = "gesture_recognizer_test",
|
||||||
size = "large",
|
size = "large",
|
||||||
srcs = ["gesture_recognizer_test.py"],
|
srcs = ["gesture_recognizer_test.py"],
|
||||||
data = [
|
data = [":testdata"],
|
||||||
":testdata",
|
|
||||||
"//mediapipe/model_maker/models/gesture_recognizer:models",
|
|
||||||
],
|
|
||||||
shard_count = 2,
|
shard_count = 2,
|
||||||
tags = ["notsan"],
|
tags = [
|
||||||
|
"notsan",
|
||||||
|
"requires-net:external",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":gesture_recognizer_import",
|
":gesture_recognizer_import",
|
||||||
":hyperparameters",
|
":hyperparameters",
|
||||||
|
@ -156,9 +158,7 @@ py_test(
|
||||||
py_test(
|
py_test(
|
||||||
name = "metadata_writer_test",
|
name = "metadata_writer_test",
|
||||||
srcs = ["metadata_writer_test.py"],
|
srcs = ["metadata_writer_test.py"],
|
||||||
data = [
|
data = [":testdata"],
|
||||||
":testdata",
|
|
||||||
],
|
|
||||||
deps = [
|
deps = [
|
||||||
":metadata_writer",
|
":metadata_writer",
|
||||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||||
|
@ -169,10 +169,8 @@ py_test(
|
||||||
py_binary(
|
py_binary(
|
||||||
name = "gesture_recognizer_demo",
|
name = "gesture_recognizer_demo",
|
||||||
srcs = ["gesture_recognizer_demo.py"],
|
srcs = ["gesture_recognizer_demo.py"],
|
||||||
data = [
|
data = [":testdata"],
|
||||||
":testdata",
|
|
||||||
"//mediapipe/model_maker/models/gesture_recognizer:models",
|
|
||||||
],
|
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
|
tags = ["requires-net:external"],
|
||||||
deps = [":gesture_recognizer_import"],
|
deps = [":gesture_recognizer_import"],
|
||||||
)
|
)
|
||||||
|
|
|
@ -13,8 +13,26 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Gesture recognition constants."""
|
"""Gesture recognition constants."""
|
||||||
|
|
||||||
GESTURE_EMBEDDER_KERAS_MODEL_PATH = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder'
|
from mediapipe.model_maker.python.core.utils import file_util
|
||||||
GESTURE_EMBEDDER_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder.tflite'
|
|
||||||
HAND_DETECTOR_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/palm_detection_full.tflite'
|
GESTURE_EMBEDDER_KERAS_MODEL_FILES = file_util.DownloadedFiles(
|
||||||
HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/hand_landmark_full.tflite'
|
'gesture_recognizer/gesture_embedder',
|
||||||
CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/canned_gesture_classifier.tflite'
|
'https://storage.googleapis.com/mediapipe-assets/gesture_embedder.tar.gz',
|
||||||
|
is_folder=True,
|
||||||
|
)
|
||||||
|
GESTURE_EMBEDDER_TFLITE_MODEL_FILE = file_util.DownloadedFiles(
|
||||||
|
'gesture_recognizer/gesture_embedder.tflite',
|
||||||
|
'https://storage.googleapis.com/mediapipe-assets/gesture_embedder.tflite',
|
||||||
|
)
|
||||||
|
HAND_DETECTOR_TFLITE_MODEL_FILE = file_util.DownloadedFiles(
|
||||||
|
'gesture_recognizer/palm_detection_full.tflite',
|
||||||
|
'https://storage.googleapis.com/mediapipe-assets/palm_detection_full.tflite',
|
||||||
|
)
|
||||||
|
HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE = file_util.DownloadedFiles(
|
||||||
|
'gesture_recognizer/hand_landmark_full.tflite',
|
||||||
|
'https://storage.googleapis.com/mediapipe-assets/hand_landmark_full.tflite',
|
||||||
|
)
|
||||||
|
CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE = file_util.DownloadedFiles(
|
||||||
|
'gesture_recognizer/canned_gesture_classifier.tflite',
|
||||||
|
'https://storage.googleapis.com/mediapipe-assets/canned_gesture_classifier.tflite',
|
||||||
|
)
|
||||||
|
|
|
@ -98,9 +98,11 @@ def _get_hand_data(all_image_paths: List[str],
|
||||||
"""
|
"""
|
||||||
hand_data_result = []
|
hand_data_result = []
|
||||||
hand_detector_model_buffer = model_util.load_tflite_model_buffer(
|
hand_detector_model_buffer = model_util.load_tflite_model_buffer(
|
||||||
constants.HAND_DETECTOR_TFLITE_MODEL_FILE)
|
constants.HAND_DETECTOR_TFLITE_MODEL_FILE.get_path()
|
||||||
|
)
|
||||||
hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
|
hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
|
||||||
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE)
|
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE.get_path()
|
||||||
|
)
|
||||||
hand_landmarker_writer = metadata_writer.HandLandmarkerMetadataWriter(
|
hand_landmarker_writer = metadata_writer.HandLandmarkerMetadataWriter(
|
||||||
hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
|
hand_detector_model_buffer, hand_landmarks_detector_model_buffer)
|
||||||
hand_landmarker_options = _HandLandmarkerOptions(
|
hand_landmarker_options = _HandLandmarkerOptions(
|
||||||
|
@ -221,7 +223,8 @@ class Dataset(classification_dataset.ClassificationDataset):
|
||||||
hand_ds = tf.data.Dataset.from_tensor_slices(hand_data_dict)
|
hand_ds = tf.data.Dataset.from_tensor_slices(hand_data_dict)
|
||||||
|
|
||||||
embedder_model = model_util.load_keras_model(
|
embedder_model = model_util.load_keras_model(
|
||||||
constants.GESTURE_EMBEDDER_KERAS_MODEL_PATH)
|
constants.GESTURE_EMBEDDER_KERAS_MODEL_FILES.get_path()
|
||||||
|
)
|
||||||
|
|
||||||
hand_ds = hand_ds.batch(batch_size=1)
|
hand_ds = hand_ds.batch(batch_size=1)
|
||||||
hand_embedding_ds = hand_ds.map(
|
hand_embedding_ds = hand_ds.map(
|
||||||
|
|
|
@ -14,8 +14,9 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
import tempfile
|
||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
import unittest
|
from unittest import mock as unittest_mock
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
@ -29,6 +30,19 @@ _TEST_DATA_DIRNAME = 'raw_data'
|
||||||
|
|
||||||
class DatasetTest(tf.test.TestCase, parameterized.TestCase):
|
class DatasetTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||||
|
# condition when downloading model since these tests may run in parallel.
|
||||||
|
mock_gettempdir = unittest_mock.patch.object(
|
||||||
|
tempfile,
|
||||||
|
'gettempdir',
|
||||||
|
return_value=self.create_tempdir(),
|
||||||
|
autospec=True,
|
||||||
|
)
|
||||||
|
self.mock_gettempdir = mock_gettempdir.start()
|
||||||
|
self.addCleanup(mock_gettempdir.stop)
|
||||||
|
|
||||||
def test_split(self):
|
def test_split(self):
|
||||||
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
|
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
|
||||||
data = dataset.Dataset.from_folder(
|
data = dataset.Dataset.from_folder(
|
||||||
|
@ -135,8 +149,9 @@ class DatasetTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
handedness=[[1]], hand_landmarks=[[2]], hand_world_landmarks=[])),
|
handedness=[[1]], hand_landmarks=[[2]], hand_world_landmarks=[])),
|
||||||
)
|
)
|
||||||
def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple):
|
def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple):
|
||||||
with unittest.mock.patch.object(
|
with unittest_mock.patch.object(
|
||||||
hand_landmarker.HandLandmarker, 'detect', return_value=hand):
|
hand_landmarker.HandLandmarker, 'detect', return_value=hand
|
||||||
|
):
|
||||||
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
|
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
|
||||||
with self.assertRaisesRegex(ValueError, 'No valid hand is detected'):
|
with self.assertRaisesRegex(ValueError, 'No valid hand is detected'):
|
||||||
dataset.Dataset.from_folder(
|
dataset.Dataset.from_folder(
|
||||||
|
|
|
@ -197,13 +197,17 @@ class GestureRecognizer(classifier.Classifier):
|
||||||
"""
|
"""
|
||||||
# TODO: Convert keras embedder model instead of using tflite
|
# TODO: Convert keras embedder model instead of using tflite
|
||||||
gesture_embedding_model_buffer = model_util.load_tflite_model_buffer(
|
gesture_embedding_model_buffer = model_util.load_tflite_model_buffer(
|
||||||
constants.GESTURE_EMBEDDER_TFLITE_MODEL_FILE)
|
constants.GESTURE_EMBEDDER_TFLITE_MODEL_FILE.get_path()
|
||||||
|
)
|
||||||
hand_detector_model_buffer = model_util.load_tflite_model_buffer(
|
hand_detector_model_buffer = model_util.load_tflite_model_buffer(
|
||||||
constants.HAND_DETECTOR_TFLITE_MODEL_FILE)
|
constants.HAND_DETECTOR_TFLITE_MODEL_FILE.get_path()
|
||||||
|
)
|
||||||
hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
|
hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
|
||||||
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE)
|
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE.get_path()
|
||||||
|
)
|
||||||
canned_gesture_model_buffer = model_util.load_tflite_model_buffer(
|
canned_gesture_model_buffer = model_util.load_tflite_model_buffer(
|
||||||
constants.CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE)
|
constants.CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE.get_path()
|
||||||
|
)
|
||||||
|
|
||||||
if not tf.io.gfile.exists(self._hparams.export_dir):
|
if not tf.io.gfile.exists(self._hparams.export_dir):
|
||||||
tf.io.gfile.makedirs(self._hparams.export_dir)
|
tf.io.gfile.makedirs(self._hparams.export_dir)
|
||||||
|
|
|
@ -48,6 +48,16 @@ class GestureRecognizerTest(tf.test.TestCase):
|
||||||
all_data = self._load_data()
|
all_data = self._load_data()
|
||||||
# Splits data, 90% data for training, 10% for validation
|
# Splits data, 90% data for training, 10% for validation
|
||||||
self._train_data, self._validation_data = all_data.split(0.9)
|
self._train_data, self._validation_data = all_data.split(0.9)
|
||||||
|
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||||
|
# condition when downloading model since these tests may run in parallel.
|
||||||
|
mock_gettempdir = unittest_mock.patch.object(
|
||||||
|
tempfile,
|
||||||
|
'gettempdir',
|
||||||
|
return_value=self.create_tempdir(),
|
||||||
|
autospec=True,
|
||||||
|
)
|
||||||
|
self.mock_gettempdir = mock_gettempdir.start()
|
||||||
|
self.addCleanup(mock_gettempdir.stop)
|
||||||
|
|
||||||
def test_gesture_recognizer_model(self):
|
def test_gesture_recognizer_model(self):
|
||||||
mo = gesture_recognizer.ModelOptions()
|
mo = gesture_recognizer.ModelOptions()
|
||||||
|
|
|
@ -82,7 +82,6 @@ def _setup_build_dir():
|
||||||
|
|
||||||
# Use bazel to download GCS model files
|
# Use bazel to download GCS model files
|
||||||
model_build_files = [
|
model_build_files = [
|
||||||
'models/gesture_recognizer/BUILD',
|
|
||||||
'models/text_classifier/BUILD',
|
'models/text_classifier/BUILD',
|
||||||
]
|
]
|
||||||
for model_build_file in model_build_files:
|
for model_build_file in model_build_files:
|
||||||
|
@ -90,14 +89,6 @@ def _setup_build_dir():
|
||||||
os.makedirs(os.path.dirname(build_target_file), exist_ok=True)
|
os.makedirs(os.path.dirname(build_target_file), exist_ok=True)
|
||||||
shutil.copy(model_build_file, build_target_file)
|
shutil.copy(model_build_file, build_target_file)
|
||||||
external_files = [
|
external_files = [
|
||||||
'models/gesture_recognizer/canned_gesture_classifier.tflite',
|
|
||||||
'models/gesture_recognizer/gesture_embedder.tflite',
|
|
||||||
'models/gesture_recognizer/hand_landmark_full.tflite',
|
|
||||||
'models/gesture_recognizer/palm_detection_full.tflite',
|
|
||||||
'models/gesture_recognizer/gesture_embedder/keras_metadata.pb',
|
|
||||||
'models/gesture_recognizer/gesture_embedder/saved_model.pb',
|
|
||||||
'models/gesture_recognizer/gesture_embedder/variables/variables.data-00000-of-00001',
|
|
||||||
'models/gesture_recognizer/gesture_embedder/variables/variables.index',
|
|
||||||
'models/text_classifier/mobilebert_tiny/keras_metadata.pb',
|
'models/text_classifier/mobilebert_tiny/keras_metadata.pb',
|
||||||
'models/text_classifier/mobilebert_tiny/saved_model.pb',
|
'models/text_classifier/mobilebert_tiny/saved_model.pb',
|
||||||
'models/text_classifier/mobilebert_tiny/assets/vocab.txt',
|
'models/text_classifier/mobilebert_tiny/assets/vocab.txt',
|
||||||
|
|
Loading…
Reference in New Issue
Block a user