Internal Changes

PiperOrigin-RevId: 487673720
Authored by MediaPipe Team on 2022-11-10 16:46:12 -08:00; committed by Copybara-Service
parent 2ea5184c51
commit ec327cedcb
61 changed files with 1741 additions and 10 deletions


@@ -750,19 +750,12 @@ objc_library(
],
)
proto_library(
mediapipe_proto_library(
name = "scale_mode_proto",
srcs = ["scale_mode.proto"],
visibility = ["//visibility:public"],
)
mediapipe_cc_proto_library(
name = "scale_mode_cc_proto",
srcs = ["scale_mode.proto"],
visibility = ["//visibility:public"],
deps = [":scale_mode_proto"],
)
cc_library(
name = "gl_quad_renderer",
srcs = ["gl_quad_renderer.cc"],


@@ -0,0 +1,51 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"//mediapipe/framework/tool:mediapipe_files.bzl",
"mediapipe_files",
)
licenses(["notice"])
package(
default_visibility = ["//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__"],
)
mediapipe_files(
srcs = [
"canned_gesture_classifier.tflite",
"gesture_embedder.tflite",
"gesture_embedder/keras_metadata.pb",
"gesture_embedder/saved_model.pb",
"gesture_embedder/variables/variables.data-00000-of-00001",
"gesture_embedder/variables/variables.index",
"hand_landmark_full.tflite",
"palm_detection_full.tflite",
],
)
filegroup(
name = "models",
srcs = [
"canned_gesture_classifier.tflite",
"gesture_embedder.tflite",
"gesture_embedder/keras_metadata.pb",
"gesture_embedder/saved_model.pb",
"gesture_embedder/variables/variables.data-00000-of-00001",
"gesture_embedder/variables/variables.index",
"hand_landmark_full.tflite",
"palm_detection_full.tflite",
],
)


@@ -0,0 +1,165 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict test compatibility macro.
# Placeholder for internal Python strict library and test compatibility macro.
licenses(["notice"])
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
# TODO: Remove the unnecessary test data once the demo data are moved to an open-sourced
# directory.
filegroup(
name = "test_data",
srcs = glob([
"test_data/**",
]),
)
py_library(
name = "constants",
srcs = ["constants.py"],
)
# TODO: Change to py_library after migrating the MediaPipe hand solution
# library to MediaPipe hand task library.
py_library(
name = "dataset",
srcs = ["dataset.py"],
deps = [
":constants",
"//mediapipe/model_maker/python/core/data:classification_dataset",
"//mediapipe/model_maker/python/core/data:data_util",
"//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/python/solutions:hands",
],
)
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
data = [
":test_data",
"//mediapipe/model_maker/models/gesture_recognizer:models",
],
deps = [
":dataset",
"//mediapipe/python/solutions:hands",
"//mediapipe/tasks/python/test:test_utils",
],
)
py_library(
name = "hyperparameters",
srcs = ["hyperparameters.py"],
deps = [
"//mediapipe/model_maker/python/core:hyperparameters",
],
)
py_library(
name = "model_options",
srcs = ["model_options.py"],
)
py_library(
name = "gesture_recognizer_options",
srcs = ["gesture_recognizer_options.py"],
deps = [
":hyperparameters",
":model_options",
],
)
py_library(
name = "gesture_recognizer",
srcs = ["gesture_recognizer.py"],
data = ["//mediapipe/model_maker/models/gesture_recognizer:models"],
deps = [
":constants",
":gesture_recognizer_options",
":hyperparameters",
":metadata_writer",
":model_options",
"//mediapipe/model_maker/python/core/data:classification_dataset",
"//mediapipe/model_maker/python/core/tasks:classifier",
"//mediapipe/model_maker/python/core/utils:loss_functions",
"//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
],
)
py_library(
name = "gesture_recognizer_import",
srcs = ["__init__.py"],
deps = [
":dataset",
":gesture_recognizer",
":gesture_recognizer_options",
":hyperparameters",
":model_options",
],
)
py_library(
name = "metadata_writer",
srcs = ["metadata_writer.py"],
deps = [
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
"//mediapipe/tasks/python/metadata/metadata_writers:writer_utils",
],
)
py_test(
name = "gesture_recognizer_test",
size = "large",
srcs = ["gesture_recognizer_test.py"],
data = [
":test_data",
"//mediapipe/model_maker/models/gesture_recognizer:models",
],
shard_count = 2,
deps = [
":gesture_recognizer_import",
"//mediapipe/model_maker/python/core/utils:test_util",
"//mediapipe/tasks/python/test:test_utils",
],
)
py_test(
name = "metadata_writer_test",
srcs = ["metadata_writer_test.py"],
data = [
":test_data",
],
deps = [
":metadata_writer",
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
"//mediapipe/tasks/python/test:test_utils",
],
)
py_binary(
name = "gesture_recognizer_demo",
srcs = ["gesture_recognizer_demo.py"],
data = [
":test_data",
"//mediapipe/model_maker/models/gesture_recognizer:models",
],
python_version = "PY3",
deps = [":gesture_recognizer_import"],
)


@@ -0,0 +1,27 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe Model Maker Python Public API For Gesture Recognizer."""
from mediapipe.model_maker.python.vision.gesture_recognizer import dataset
from mediapipe.model_maker.python.vision.gesture_recognizer import gesture_recognizer
from mediapipe.model_maker.python.vision.gesture_recognizer import gesture_recognizer_options
from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters
from mediapipe.model_maker.python.vision.gesture_recognizer import model_options
GestureRecognizer = gesture_recognizer.GestureRecognizer
ModelOptions = model_options.GestureRecognizerModelOptions
HParams = hyperparameters.HParams
Dataset = dataset.Dataset
HandDataPreprocessingParams = dataset.HandDataPreprocessingParams
GestureRecognizerOptions = gesture_recognizer_options.GestureRecognizerOptions
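
For orientation, a minimal end-to-end sketch of how these public aliases fit together, mirroring the demo and tests later in this change; the dataset directory and export directory below are hypothetical:

# Minimal usage sketch of the public API above; directory paths are hypothetical.
from mediapipe.model_maker.python.vision import gesture_recognizer

data = gesture_recognizer.Dataset.from_folder(
    dirname='path/to/gesture_images',  # expects <dir>/<gesture_name>/*.jpg, including a 'none' class
    hparams=gesture_recognizer.HandDataPreprocessingParams(shuffle=True))
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)

model = gesture_recognizer.GestureRecognizer.create(
    train_data=train_data,
    validation_data=validation_data,
    options=gesture_recognizer.GestureRecognizerOptions(
        hparams=gesture_recognizer.HParams(export_dir='exported_model')))
loss, accuracy = model.evaluate(test_data, batch_size=2)
model.export_model()  # writes gesture_recognizer.task and metadata.json under export_dir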


@@ -0,0 +1,20 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gesture recognition constants."""
GESTURE_EMBEDDER_KERAS_MODEL_PATH = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder'
GESTURE_EMBEDDER_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder.tflite'
HAND_DETECTOR_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/palm_detection_full.tflite'
HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/hand_landmark_full.tflite'
CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/canned_gesture_classifier.tflite'


@@ -0,0 +1,238 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gesture recognition dataset library."""
import dataclasses
import os
import random
from typing import List, NamedTuple, Optional
import cv2
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.core.data import data_util
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.vision.gesture_recognizer import constants
from mediapipe.python.solutions import hands as mp_hands
@dataclasses.dataclass
class HandDataPreprocessingParams:
"""A dataclass wraps the hand data preprocessing hyperparameters.
Attributes:
shuffle: A boolean controlling if shuffle the dataset. Default to true.
min_detection_confidence: confidence threshold for hand detection.
"""
shuffle: bool = True
min_detection_confidence: float = 0.7
@dataclasses.dataclass
class HandData:
"""A dataclass represents hand data for training gesture recognizer model.
See https://google.github.io/mediapipe/solutions/hands#mediapipe-hands for
more details of the hand gesture data API.
Attributes:
hand: Normalized hand landmarks of shape 21x3 from the screen-based
hand landmark model.
world_hand: Hand landmarks of shape 21x3 in world coordinates.
handedness: Collection of handedness confidence scores for the detected
hands (i.e. whether each is a left or right hand).
"""
hand: List[List[float]]
world_hand: List[List[float]]
handedness: List[float]
def _validate_data_sample(data: NamedTuple) -> bool:
"""Validates the input hand data sample.
Args:
data: input hand data sample.
Returns:
False if the input data namedtuple is missing any of the fields
'multi_hand_landmarks', 'multi_hand_world_landmarks', or 'multi_handedness',
or if any of these fields is None. True otherwise.
"""
if (not hasattr(data, 'multi_hand_landmarks') or
data.multi_hand_landmarks is None):
return False
if (not hasattr(data, 'multi_hand_world_landmarks') or
data.multi_hand_world_landmarks is None):
return False
if not hasattr(data, 'multi_handedness') or data.multi_handedness is None:
return False
return True
def _get_hand_data(all_image_paths: List[str],
min_detection_confidence: float) -> List[Optional[HandData]]:
"""Computes hand data (landmarks and handedness) for the input images.
Args:
all_image_paths: all input image paths.
min_detection_confidence: hand detection confidence threshold.
Returns:
A list with one entry per input image: a HandData object, or None if no
hand is detected in that image.
"""
hand_data_result = []
with mp_hands.Hands(
static_image_mode=True,
max_num_hands=1,
min_detection_confidence=min_detection_confidence) as hands:
for path in all_image_paths:
tf.compat.v1.logging.info('Loading image %s', path)
image = data_util.load_image(path)
# Flip image around y-axis for correct handedness output
image = cv2.flip(image, 1)
data = hands.process(image)
if not _validate_data_sample(data):
hand_data_result.append(None)
continue
hand_landmarks = [[
hand_landmark.x, hand_landmark.y, hand_landmark.z
] for hand_landmark in data.multi_hand_landmarks[0].landmark]
hand_world_landmarks = [[
hand_landmark.x, hand_landmark.y, hand_landmark.z
] for hand_landmark in data.multi_hand_world_landmarks[0].landmark]
handedness_scores = [
handedness.score
for handedness in data.multi_handedness[0].classification
]
hand_data_result.append(
HandData(
hand=hand_landmarks,
world_hand=hand_world_landmarks,
handedness=handedness_scores))
return hand_data_result
class Dataset(classification_dataset.ClassificationDataset):
"""Dataset library for hand gesture recognizer."""
@classmethod
def from_folder(
cls,
dirname: str,
hparams: Optional[HandDataPreprocessingParams] = None
) -> classification_dataset.ClassificationDataset:
"""Loads images and labels from the given directory.
Directory contents are expected to be in the format:
<root_dir>/<gesture_name>/*.jpg. One of the `gesture_name` sub-directories
must be `none` (case insensitive). The `none` sub-directory is expected to
contain images of hands that don't belong to any other gesture class in
<root_dir>. Images with the same label are assumed to be in the same
sub-directory.
Args:
dirname: Name of the directory containing the data files.
hparams: Optional hyperparameters for processing input hand gesture
images.
Returns:
Dataset containing landmarks, labels, and other related info.
Raises:
ValueError: if the input data directory is empty or the label set does not
contain label 'none' (case insensitive).
"""
data_root = os.path.abspath(dirname)
# Assumes images with the same label are in the same subdirectory;
# gets image paths and label names.
all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
if not all_image_paths:
raise ValueError('Image dataset directory is empty.')
if not hparams:
hparams = HandDataPreprocessingParams()
if hparams.shuffle:
# Random shuffle data.
random.shuffle(all_image_paths)
label_names = sorted(
name for name in os.listdir(data_root)
if os.path.isdir(os.path.join(data_root, name)))
if 'none' not in [v.lower() for v in label_names]:
raise ValueError('Label set does not contain label "None".')
# Move label 'none' to the front of label list.
none_idx = [v.lower() for v in label_names].index('none')
none_value = label_names.pop(none_idx)
label_names.insert(0, none_value)
index_by_label = dict(
(name, index) for index, name in enumerate(label_names))
all_gesture_indices = [
index_by_label[os.path.basename(os.path.dirname(path))]
for path in all_image_paths
]
# Compute hand data (including local hand landmark, world hand landmark, and
# handedness) for all the input images.
hand_data = _get_hand_data(
all_image_paths=all_image_paths,
min_detection_confidence=hparams.min_detection_confidence)
# Get a list of the valid hand landmark sample in the hand data list.
valid_indices = [
i for i in range(len(hand_data)) if hand_data[i] is not None
]
# Remove 'None' element from the hand data and label list.
valid_hand_data = [dataclasses.asdict(hand_data[i]) for i in valid_indices]
if not valid_hand_data:
raise ValueError('No valid hand is detected.')
valid_label = [all_gesture_indices[i] for i in valid_indices]
# Convert list of dictionaries to dictionary of lists.
hand_data_dict = {
k: [lm[k] for lm in valid_hand_data] for k in valid_hand_data[0]
}
hand_ds = tf.data.Dataset.from_tensor_slices(hand_data_dict)
embedder_model = model_util.load_keras_model(
constants.GESTURE_EMBEDDER_KERAS_MODEL_PATH)
hand_ds = hand_ds.batch(batch_size=1)
hand_embedding_ds = hand_ds.map(
map_func=lambda feature: embedder_model(dict(feature)),
num_parallel_calls=tf.data.experimental.AUTOTUNE)
hand_embedding_ds = hand_embedding_ds.unbatch()
# Create label dataset
label_ds = tf.data.Dataset.from_tensor_slices(
tf.cast(valid_label, tf.int64))
label_one_hot_ds = label_ds.map(
map_func=lambda index: tf.one_hot(index, len(label_names)),
num_parallel_calls=tf.data.experimental.AUTOTUNE)
# Create a dataset with (hand_embedding, one_hot_label) pairs
hand_embedding_label_ds = tf.data.Dataset.zip(
(hand_embedding_ds, label_one_hot_ds))
tf.compat.v1.logging.info(
'Load valid hands with size: {}, num_label: {}, labels: {}.'.format(
len(valid_hand_data), len(label_names), ','.join(label_names)))
return Dataset(
dataset=hand_embedding_label_ds,
size=len(valid_hand_data),
label_names=label_names)
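
To illustrate the directory layout that `from_folder` expects and the preprocessing knobs it exposes, here is a small sketch; the folder and file names are made up:

# Hypothetical directory layout expected by Dataset.from_folder:
#   raw_data/
#     none/  img_000.jpg ...   # hands that belong to no gesture class (required)
#     call/  img_100.jpg ...
#     four/  img_200.jpg ...
#     rock/  img_300.jpg ...
from mediapipe.model_maker.python.vision.gesture_recognizer import dataset

data = dataset.Dataset.from_folder(
    dirname='raw_data',
    hparams=dataset.HandDataPreprocessingParams(
        shuffle=True,                  # shuffle image paths before processing
        min_detection_confidence=0.7,  # MediaPipe Hands detection threshold
    ))
print(len(data), data.num_classes, data.label_names)  # 'none' is always label 0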


@@ -0,0 +1,161 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import os
import shutil
from typing import NamedTuple
import unittest
from absl import flags
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.vision.gesture_recognizer import dataset
from mediapipe.python.solutions import hands as mp_hands
from mediapipe.tasks.python.test import test_utils
FLAGS = flags.FLAGS
_TEST_DATA_DIRNAME = 'raw_data'
class DatasetTest(tf.test.TestCase, parameterized.TestCase):
def test_split(self):
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
data = dataset.Dataset.from_folder(
dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams())
train_data, test_data = data.split(0.5)
self.assertLen(train_data, 17)
for _, elem in enumerate(train_data.gen_tf_dataset(is_training=True)):
self.assertEqual(elem[0].shape, (1, 128))
self.assertEqual(elem[1].shape, ([1, 4]))
self.assertEqual(train_data.num_classes, 4)
self.assertEqual(train_data.label_names, ['none', 'call', 'four', 'rock'])
self.assertLen(test_data, 18)
for _, elem in enumerate(test_data.gen_tf_dataset(is_training=True)):
self.assertEqual(elem[0].shape, (1, 128))
self.assertEqual(elem[1].shape, ([1, 4]))
self.assertEqual(test_data.num_classes, 4)
self.assertEqual(test_data.label_names, ['none', 'call', 'four', 'rock'])
def test_from_folder(self):
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
data = dataset.Dataset.from_folder(
dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams())
for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
self.assertEqual(elem[0].shape, (1, 128))
self.assertEqual(elem[1].shape, ([1, 4]))
self.assertLen(data, 35)
self.assertEqual(data.num_classes, 4)
self.assertEqual(data.label_names, ['none', 'call', 'four', 'rock'])
def test_create_dataset_from_empty_folder_raise_value_error(self):
with self.assertRaisesRegex(ValueError, 'Image dataset directory is empty'):
dataset.Dataset.from_folder(
dirname=self.get_temp_dir(),
hparams=dataset.HandDataPreprocessingParams())
def test_create_dataset_from_folder_without_none_raise_value_error(self):
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
tmp_dir = self.create_tempdir()
# Copy the input dataset to a temporary directory, skipping the 'none' directory.
for name in os.listdir(input_data_dir):
if name == 'none':
continue
src_dir = os.path.join(input_data_dir, name)
dst_dir = os.path.join(tmp_dir, name)
shutil.copytree(src_dir, dst_dir)
with self.assertRaisesRegex(ValueError,
'Label set does not contain label "None"'):
dataset.Dataset.from_folder(
dirname=tmp_dir, hparams=dataset.HandDataPreprocessingParams())
def test_create_dataset_from_folder_with_capital_letter_in_folder_name(self):
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
tmp_dir = self.create_tempdir()
# Copy the input dataset to a temporary directory and change each base folder
# name to upper case, e.g. 'none' -> 'NONE'.
for name in os.listdir(input_data_dir):
src_dir = os.path.join(input_data_dir, name)
dst_dir = os.path.join(tmp_dir, name.upper())
shutil.copytree(src_dir, dst_dir)
upper_base_folder_name = list(os.listdir(tmp_dir))
self.assertCountEqual(upper_base_folder_name,
['CALL', 'FOUR', 'NONE', 'ROCK'])
data = dataset.Dataset.from_folder(
dirname=tmp_dir, hparams=dataset.HandDataPreprocessingParams())
for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
self.assertEqual(elem[0].shape, (1, 128))
self.assertEqual(elem[1].shape, ([1, 4]))
self.assertLen(data, 35)
self.assertEqual(data.num_classes, 4)
self.assertEqual(data.label_names, ['NONE', 'CALL', 'FOUR', 'ROCK'])
@parameterized.named_parameters(
dict(
testcase_name='invalid_field_name_multi_hand_landmark',
hand=collections.namedtuple('Hand', [
'multi_hand_landmark', 'multi_hand_world_landmarks',
'multi_handedness'
])(1, 2, 3)),
dict(
testcase_name='invalid_field_name_multi_hand_world_landmarks',
hand=collections.namedtuple('Hand', [
'multi_hand_landmarks', 'multi_hand_world_landmark',
'multi_handedness'
])(1, 2, 3)),
dict(
testcase_name='invalid_field_name_multi_handed',
hand=collections.namedtuple('Hand', [
'multi_hand_landmarks', 'multi_hand_world_landmarks',
'multi_handed'
])(1, 2, 3)),
dict(
testcase_name='multi_hand_landmarks_is_none',
hand=collections.namedtuple('Hand', [
'multi_hand_landmarks', 'multi_hand_world_landmarks',
'multi_handedness'
])(None, 2, 3)),
dict(
testcase_name='multi_hand_world_landmarks_is_none',
hand=collections.namedtuple('Hand', [
'multi_hand_landmarks', 'multi_hand_world_landmarks',
'multi_handedness'
])(1, None, 3)),
dict(
testcase_name='multi_handedness_is_none',
hand=collections.namedtuple('Hand', [
'multi_hand_landmarks', 'multi_hand_world_landmarks',
'multi_handedness'
])(1, 2, None)),
)
def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple):
with unittest.mock.patch.object(
mp_hands.Hands, 'process', return_value=hand):
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
with self.assertRaisesRegex(ValueError, 'No valid hand is detected'):
dataset.Dataset.from_folder(
dirname=input_data_dir,
hparams=dataset.HandDataPreprocessingParams())
if __name__ == '__main__':
tf.test.main()


@@ -0,0 +1,239 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""APIs to train gesture recognizer model."""
import os
from typing import List
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import loss_functions
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.vision.gesture_recognizer import constants
from mediapipe.model_maker.python.vision.gesture_recognizer import gesture_recognizer_options
from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters as hp
from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer
from mediapipe.model_maker.python.vision.gesture_recognizer import model_options as model_opt
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer as base_metadata_writer
_EMBEDDING_SIZE = 128
class GestureRecognizer(classifier.Classifier):
"""GestureRecognizer for building hand gesture recognizer model.
Attributes:
embedding_size: Size of the input gesture embedding vector.
"""
def __init__(self, label_names: List[str],
model_options: model_opt.GestureRecognizerModelOptions,
hparams: hp.HParams):
"""Initializes GestureRecognizer class.
Args:
label_names: A list of label names for the classes.
model_options: options to create gesture recognizer model.
hparams: The hyperparameters for training hand gesture recognizer model.
"""
super().__init__(
model_spec=None, label_names=label_names, shuffle=hparams.shuffle)
self._model_options = model_options
self._hparams = hparams
self._history = None
self.embedding_size = _EMBEDDING_SIZE
@classmethod
def create(
cls,
train_data: classification_ds.ClassificationDataset,
validation_data: classification_ds.ClassificationDataset,
options: gesture_recognizer_options.GestureRecognizerOptions,
) -> 'GestureRecognizer':
"""Creates and trains a hand gesture recognizer with input datasets.
If checkpoint files exist in the {options.hparams.export_dir}/epoch_models/
directory, the training process will load the weights from the latest
checkpoint to continue training.
Args:
train_data: Training data.
validation_data: Validation data.
options: options for creating and training gesture recognizer model.
Returns:
An instance of GestureRecognizer.
"""
if options.model_options is None:
options.model_options = model_opt.GestureRecognizerModelOptions()
if options.hparams is None:
options.hparams = hp.HParams()
gesture_recognizer = cls(
label_names=train_data.label_names,
model_options=options.model_options,
hparams=options.hparams)
gesture_recognizer._create_model()
train_dataset = train_data.gen_tf_dataset(
batch_size=options.hparams.batch_size,
is_training=True,
shuffle=options.hparams.shuffle)
options.hparams.steps_per_epoch = model_util.get_steps_per_epoch(
steps_per_epoch=options.hparams.steps_per_epoch,
batch_size=options.hparams.batch_size,
train_data=train_data)
train_dataset = train_dataset.take(count=options.hparams.steps_per_epoch)
validation_dataset = validation_data.gen_tf_dataset(
batch_size=options.hparams.batch_size, is_training=False)
tf.compat.v1.logging.info('Training the gesture recognizer model...')
gesture_recognizer._train(
train_data=train_dataset, validation_data=validation_dataset)
return gesture_recognizer
def _train(self, train_data: tf.data.Dataset,
validation_data: tf.data.Dataset):
"""Trains the model with input train_data.
The training history is stored in self._history as the History object
returned by tf.keras.Model.fit().
Args:
train_data: Training data.
validation_data: Validation data.
"""
hparams = self._hparams
scheduler = lambda epoch: hparams.learning_rate * (hparams.lr_decay**epoch)
scheduler_callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
job_dir = hparams.export_dir
checkpoint_path = os.path.join(job_dir, 'epoch_models')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
os.path.join(checkpoint_path, 'model-{epoch:04d}'),
save_weights_only=True)
best_model_path = os.path.join(job_dir, 'best_model_weights')
best_model_callback = tf.keras.callbacks.ModelCheckpoint(
best_model_path,
monitor='val_loss',
mode='min',
save_best_only=True,
save_weights_only=True)
tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=os.path.join(job_dir, 'logs'))
self._model.compile(
optimizer='adam',
loss=loss_functions.FocalLoss(gamma=self._hparams.gamma),
metrics=['categorical_accuracy'])
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
if latest_checkpoint:
print(f'Resuming from {latest_checkpoint}')
self._model.load_weights(latest_checkpoint)
self._history = self._model.fit(
x=train_data,
epochs=hparams.epochs,
validation_data=validation_data,
validation_freq=1,
callbacks=[
checkpoint_callback, best_model_callback, scheduler_callback,
tensorboard_callback
],
)
def _create_model(self):
"""Creates the hand gesture recognizer model.
The gesture embedding model is pretrained and loaded from a tf.saved_model.
"""
inputs = tf.keras.Input(
shape=[self.embedding_size],
batch_size=None,
dtype=tf.float32,
name='hand_embedding')
x = tf.keras.layers.BatchNormalization()(inputs)
x = tf.keras.layers.ReLU()(x)
dropout_rate = self._model_options.dropout_rate
x = tf.keras.layers.Dropout(rate=dropout_rate, name='dropout')(x)
outputs = tf.keras.layers.Dense(
self._num_classes,
activation='softmax',
name='custom_gesture_recognizer')(
x)
self._model = tf.keras.Model(inputs=inputs, outputs=outputs)
print(self._model.summary())
def export_model(self, model_name: str = 'gesture_recognizer.task'):
"""Converts the model to TFLite and exports as a model bundle file.
Saves a model bundle file and metadata json file to hparams.export_dir. The
resulting model bundle file will contain necessary models for hand
detection, canned gesture classification, and customized gesture
classification. Only the model bundle file is needed for the downstream
gesture recognition task. The metadata.json file is saved only to
interpret the contents of the model bundle file.
The customized gesture model is exported in float, without quantization; the
model is lightweight, so there is no need to trade accuracy for efficiency
via quantization. The default score threshold is set to 0.5, and it can be
adjusted at inference time.
Args:
model_name: File name to save model bundle file. The full export path is
{export_dir}/{model_name}.
"""
# TODO: Convert keras embedder model instead of using tflite
gesture_embedding_model_buffer = model_util.load_tflite_model_buffer(
constants.GESTURE_EMBEDDER_TFLITE_MODEL_FILE)
hand_detector_model_buffer = model_util.load_tflite_model_buffer(
constants.HAND_DETECTOR_TFLITE_MODEL_FILE)
hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE)
canned_gesture_model_buffer = model_util.load_tflite_model_buffer(
constants.CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE)
if not tf.io.gfile.exists(self._hparams.export_dir):
tf.io.gfile.makedirs(self._hparams.export_dir)
model_bundle_file = os.path.join(self._hparams.export_dir, model_name)
metadata_file = os.path.join(self._hparams.export_dir, 'metadata.json')
gesture_classifier_options = metadata_writer.GestureClassifierOptions(
model_buffer=model_util.convert_to_tflite(self._model),
labels=base_metadata_writer.Labels().add(list(self._label_names)),
score_thresholding=base_metadata_writer.ScoreThresholding(
global_score_threshold=0.5))
writer = metadata_writer.MetadataWriter.create(
hand_detector_model_buffer, hand_landmarks_detector_model_buffer,
gesture_embedding_model_buffer, canned_gesture_model_buffer,
gesture_classifier_options)
model_bundle_content, metadata_json = writer.populate()
with open(model_bundle_file, 'wb') as f:
f.write(model_bundle_content)
with open(metadata_file, 'w') as f:
f.write(metadata_json)
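
A short sketch of inspecting the bundle that export_model() writes, following the layout described in the docstring above; the export directory name is hypothetical:

# Sketch: unpack the exported model bundle to verify its contents.
import os
import zipfile

export_dir = 'exported_model'  # hypothetical; matches HParams(export_dir=...)
with zipfile.ZipFile(os.path.join(export_dir, 'gesture_recognizer.task')) as zf:
    print(zf.namelist())  # expected: hand_landmarker.task, hand_gesture_recognizer.task
    zf.extractall(export_dir)
with zipfile.ZipFile(os.path.join(export_dir, 'hand_gesture_recognizer.task')) as zf:
    # expected: gesture_embedder.tflite, canned_gesture_classifier.tflite,
    # custom_gesture_classifier.tflite
    print(zf.namelist())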


@@ -0,0 +1,78 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Demo for making an gesture recognizer model by Mediapipe Model Maker."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# Dependency imports
from absl import app
from absl import flags
from absl import logging
from mediapipe.model_maker.python.vision import gesture_recognizer
FLAGS = flags.FLAGS
# TODO: Move hand gesture recognizer demo dataset to an
# open-sourced directory.
TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data'
def define_flags():
flags.DEFINE_string('export_dir', None,
'The directory to save exported files.')
flags.DEFINE_string('input_data_dir', None,
'The directory with input training data.')
flags.mark_flag_as_required('export_dir')
def run(data_dir: str, export_dir: str):
"""Runs demo."""
data = gesture_recognizer.Dataset.from_folder(dirname=data_dir)
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)
model = gesture_recognizer.GestureRecognizer.create(
train_data=train_data,
validation_data=validation_data,
options=gesture_recognizer.GestureRecognizerOptions(
hparams=gesture_recognizer.HParams(export_dir=export_dir)))
metric = model.evaluate(test_data, batch_size=2)
print('Evaluation metric')
print(metric)
model.export_model()
def main(_):
logging.set_verbosity(logging.INFO)
if FLAGS.input_data_dir is None:
data_dir = os.path.join(FLAGS.test_srcdir, TEST_DATA_DIR)
else:
data_dir = FLAGS.input_data_dir
export_dir = os.path.expanduser(FLAGS.export_dir)
run(data_dir=data_dir, export_dir=export_dir)
if __name__ == '__main__':
define_flags()
app.run(main)


@@ -0,0 +1,32 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Options for building gesture recognizer."""
import dataclasses
from typing import Optional
from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters
from mediapipe.model_maker.python.vision.gesture_recognizer import model_options as model_opt
@dataclasses.dataclass
class GestureRecognizerOptions:
"""Configurable options for building gesture recognizer.
Attributes:
model_options: A set of options for configuring the selected model.
hparams: A set of hyperparameters used to train the gesture recognizer.
"""
model_options: Optional[model_opt.GestureRecognizerModelOptions] = None
hparams: Optional[hyperparameters.HParams] = None
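
A small sketch of how these options compose; the values are illustrative only, and leaving either field as None lets GestureRecognizer.create() fill in the defaults:

# Illustrative values only: combine model options and hyperparameters for training.
from mediapipe.model_maker.python.vision.gesture_recognizer import gesture_recognizer_options
from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters
from mediapipe.model_maker.python.vision.gesture_recognizer import model_options

options = gesture_recognizer_options.GestureRecognizerOptions(
    model_options=model_options.GestureRecognizerModelOptions(dropout_rate=0.1),
    hparams=hyperparameters.HParams(epochs=20, export_dir='exported_model'))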


@@ -0,0 +1,132 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
from unittest import mock as unittest_mock
import zipfile
import mock
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import test_util
from mediapipe.model_maker.python.vision import gesture_recognizer
from mediapipe.tasks.python.test import test_utils
_TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data'
class GestureRecognizerTest(tf.test.TestCase):
def _load_data(self):
input_data_dir = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, 'raw_data'))
data = gesture_recognizer.Dataset.from_folder(
dirname=input_data_dir,
hparams=gesture_recognizer.HandDataPreprocessingParams(shuffle=True))
return data
def setUp(self):
super().setUp()
self._model_options = gesture_recognizer.ModelOptions()
self._hparams = gesture_recognizer.HParams(epochs=2)
self._gesture_recognizer_options = (
gesture_recognizer.GestureRecognizerOptions(
model_options=self._model_options, hparams=self._hparams))
all_data = self._load_data()
# Splits data, 90% data for training, 10% for testing
self._train_data, self._test_data = all_data.split(0.9)
def test_gesture_recognizer_model(self):
model = gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data,
validation_data=self._test_data,
options=self._gesture_recognizer_options)
self._test_accuracy(model)
def test_export_gesture_recognizer_model(self):
model = gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data,
validation_data=self._test_data,
options=self._gesture_recognizer_options)
model.export_model()
model_bundle_file = os.path.join(self._hparams.export_dir,
'gesture_recognizer.task')
with zipfile.ZipFile(model_bundle_file) as zf:
self.assertEqual(
set(zf.namelist()),
set(['hand_landmarker.task', 'hand_gesture_recognizer.task']))
zf.extractall(self.get_temp_dir())
hand_gesture_recognizer_bundle_file = os.path.join(
self.get_temp_dir(), 'hand_gesture_recognizer.task')
with zipfile.ZipFile(hand_gesture_recognizer_bundle_file) as zf:
self.assertEqual(
set(zf.namelist()),
set([
'canned_gesture_classifier.tflite',
'custom_gesture_classifier.tflite', 'gesture_embedder.tflite'
]))
zf.extractall(self.get_temp_dir())
gesture_classifier_tflite_file = os.path.join(
self.get_temp_dir(), 'custom_gesture_classifier.tflite')
test_util.test_tflite_file(
keras_model=model._model,
tflite_file=gesture_classifier_tflite_file,
size=[1, model.embedding_size])
def _test_accuracy(self, model, threshold=0.5):
_, accuracy = model.evaluate(self._test_data)
tf.compat.v1.logging.info(f'accuracy: {accuracy}')
self.assertGreaterEqual(accuracy, threshold)
@unittest_mock.patch.object(
gesture_recognizer.hyperparameters,
'HParams',
autospec=True,
return_value=gesture_recognizer.HParams(epochs=1))
@unittest_mock.patch.object(
gesture_recognizer.model_options,
'GestureRecognizerModelOptions',
autospec=True,
return_value=gesture_recognizer.ModelOptions())
def test_create_hparams_and_model_options_if_none_in_gesture_recognizer_options(
self, mock_model_options, mock_hparams):
options = gesture_recognizer.GestureRecognizerOptions()
gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data,
validation_data=self._test_data,
options=options)
mock_hparams.assert_called_once()
mock_model_options.assert_called_once()
def test_continual_training_by_loading_checkpoint(self):
mock_stdout = io.StringIO()
with mock.patch('sys.stdout', mock_stdout):
model = gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data,
validation_data=self._test_data,
options=self._gesture_recognizer_options)
model = gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data,
validation_data=self._test_data,
options=self._gesture_recognizer_options)
self._test_accuracy(model)
self.assertRegex(mock_stdout.getvalue(), 'Resuming from')
if __name__ == '__main__':
tf.test.main()


@@ -0,0 +1,40 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hyperparameters for training customized gesture recognizer models."""
import dataclasses
from mediapipe.model_maker.python.core import hyperparameters as hp
@dataclasses.dataclass
class HParams(hp.BaseHParams):
"""The hyperparameters for training gesture recognizer.
Attributes:
learning_rate: Learning rate to use for gradient descent training.
batch_size: Batch size for training.
epochs: Number of training iterations over the dataset.
lr_decay: Learning rate decay to use for gradient descent training.
gamma: Gamma parameter for focal loss.
"""
# Parameters from BaseHParams class.
learning_rate: float = 0.001
batch_size: int = 2
epochs: int = 10
# Parameters about training configuration
# TODO: Move lr_decay to hp.BaseHParams.
lr_decay: float = 0.99
gamma: int = 2
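
For intuition, lr_decay defines the exponential per-epoch schedule that gesture_recognizer.py installs via a LearningRateScheduler; a minimal sketch, assuming the BaseHParams defaults for the remaining fields:

# Sketch of the learning-rate schedule implied by these hyperparameters:
# lr(epoch) = learning_rate * lr_decay ** epoch.
from mediapipe.model_maker.python.vision.gesture_recognizer.hyperparameters import HParams

hparams = HParams()  # assumes BaseHParams supplies a default export_dir
for epoch in range(hparams.epochs):
    print(f'epoch {epoch}: lr = {hparams.learning_rate * hparams.lr_decay ** epoch:.6f}')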


@@ -0,0 +1,193 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Writes metadata and creates model asset bundle for gesture recognizer."""
import dataclasses
import os
import tempfile
from typing import Union
import tensorflow as tf
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.metadata.metadata_writers import writer_utils
_HAND_DETECTOR_TFLITE_NAME = "hand_detector.tflite"
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME = "hand_landmarks_detector.tflite"
_HAND_LANDMARKER_BUNDLE_NAME = "hand_landmarker.task"
_HAND_GESTURE_RECOGNIZER_BUNDLE_NAME = "hand_gesture_recognizer.task"
_GESTURE_EMBEDDER_TFLITE_NAME = "gesture_embedder.tflite"
_CANNED_GESTURE_CLASSIFIER_TFLITE_NAME = "canned_gesture_classifier.tflite"
_CUSTOM_GESTURE_CLASSIFIER_TFLITE_NAME = "custom_gesture_classifier.tflite"
_MODEL_NAME = "HandGestureRecognition"
_MODEL_DESCRIPTION = "Recognize the hand gesture in the image."
_INPUT_NAME = "embedding"
_INPUT_DESCRIPTION = "Embedding feature vector from gesture embedder."
_OUTPUT_NAME = "scores"
_OUTPUT_DESCRIPTION = "Hand gesture category scores."
@dataclasses.dataclass
class GestureClassifierOptions:
"""Options to write metadata for gesture classifier.
Attributes:
model_buffer: Gesture classifier TFLite model buffer.
labels: Labels for the gesture classifier.
score_thresholding: Parameters to perform thresholding on the output tensor
values [1].
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
"""
model_buffer: bytearray
labels: metadata_writer.Labels
score_thresholding: metadata_writer.ScoreThresholding
def read_file(file_path: str, mode: str = "rb") -> Union[str, bytes]:
with tf.io.gfile.GFile(file_path, mode) as f:
return f.read()
class MetadataWriter:
"""MetadataWriter to write the metadata and the model asset bundle."""
def __init__(
self, hand_detector_model_buffer: bytearray,
hand_landmarks_detector_model_buffer: bytearray,
gesture_embedder_model_buffer: bytearray,
canned_gesture_classifier_model_buffer: bytearray,
custom_gesture_classifier_metadata_writer: metadata_writer.MetadataWriter
) -> None:
"""Initialize MetadataWriter to write the metadata and model asset bundle.
Args:
hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from
the TFLite hand detector model file.
hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata
loaded from the TFLite hand landmarks detector model file.
gesture_embedder_model_buffer: A valid flatbuffer *with* metadata loaded
from the TFLite gesture embedder model file.
canned_gesture_classifier_model_buffer: A valid flatbuffer *with* metadata
loaded from the TFLite canned gesture classifier model file.
custom_gesture_classifier_metadata_writer: Metadata writer to write custom
gesture classifier metadata into the TFLite file.
"""
self._hand_detector_model_buffer = hand_detector_model_buffer
self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer
self._gesture_embedder_model_buffer = gesture_embedder_model_buffer
self._canned_gesture_classifier_model_buffer = canned_gesture_classifier_model_buffer
self._custom_gesture_classifier_metadata_writer = custom_gesture_classifier_metadata_writer
self._temp_folder = tempfile.TemporaryDirectory()
def __del__(self):
if os.path.exists(self._temp_folder.name):
self._temp_folder.cleanup()
@classmethod
def create(
cls,
hand_detector_model_buffer: bytearray,
hand_landmarks_detector_model_buffer: bytearray,
gesture_embedder_model_buffer: bytearray,
canned_gesture_classifier_model_buffer: bytearray,
custom_gesture_classifier_options: GestureClassifierOptions,
) -> "MetadataWriter":
"""Creates MetadataWriter to write the metadata for gesture recognizer.
Args:
hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from
the TFLite hand detector model file.
hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata
loaded from the TFLite hand landmarks detector model file.
gesture_embedder_model_buffer: A valid flatbuffer *with* metadata loaded
from the TFLite gesture embedder model file.
canned_gesture_classifier_model_buffer: A valid flatbuffer *with* metadata
loaded from the TFLite canned gesture classifier model file.
custom_gesture_classifier_options: Custom gesture classifier options to
write custom gesture classifier metadata into the TFLite file.
Returns:
A MetadataWriter object.
"""
writer = metadata_writer.MetadataWriter.create(
custom_gesture_classifier_options.model_buffer)
writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
writer.add_feature_input(name=_INPUT_NAME, description=_INPUT_DESCRIPTION)
writer.add_classification_output(
labels=custom_gesture_classifier_options.labels,
score_thresholding=custom_gesture_classifier_options.score_thresholding,
name=_OUTPUT_NAME,
description=_OUTPUT_DESCRIPTION)
return cls(hand_detector_model_buffer, hand_landmarks_detector_model_buffer,
gesture_embedder_model_buffer,
canned_gesture_classifier_model_buffer, writer)
def populate(self):
"""Populates the metadata and creates model asset bundle.
Note that only the output model asset bundle is used for deployment.
The output JSON content is used to interpret the custom gesture classifier
metadata content.
Returns:
A tuple of (model_asset_bundle_in_bytes, metadata_json_content)
"""
# Creates the model asset bundle for hand landmarker task.
landmark_models = {
_HAND_DETECTOR_TFLITE_NAME:
self._hand_detector_model_buffer,
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME:
self._hand_landmarks_detector_model_buffer
}
output_hand_landmarker_path = os.path.join(self._temp_folder.name,
_HAND_LANDMARKER_BUNDLE_NAME)
writer_utils.create_model_asset_bundle(landmark_models,
output_hand_landmarker_path)
# Write metadata into custom gesture classifier model.
self._custom_gesture_classifier_model_buffer, custom_gesture_classifier_metadata_json = self._custom_gesture_classifier_metadata_writer.populate(
)
# Creates the model asset bundle for hand gesture recognizer sub graph.
hand_gesture_recognizer_models = {
_GESTURE_EMBEDDER_TFLITE_NAME:
self._gesture_embedder_model_buffer,
_CANNED_GESTURE_CLASSIFIER_TFLITE_NAME:
self._canned_gesture_classifier_model_buffer,
_CUSTOM_GESTURE_CLASSIFIER_TFLITE_NAME:
self._custom_gesture_classifier_model_buffer
}
output_hand_gesture_recognizer_path = os.path.join(
self._temp_folder.name, _HAND_GESTURE_RECOGNIZER_BUNDLE_NAME)
writer_utils.create_model_asset_bundle(hand_gesture_recognizer_models,
output_hand_gesture_recognizer_path)
# Creates the model asset bundle for end-to-end hand gesture recognizer
# graph.
gesture_recognizer_models = {
_HAND_LANDMARKER_BUNDLE_NAME:
read_file(output_hand_landmarker_path),
_HAND_GESTURE_RECOGNIZER_BUNDLE_NAME:
read_file(output_hand_gesture_recognizer_path),
}
output_file_path = os.path.join(self._temp_folder.name,
"gesture_recognizer.task")
writer_utils.create_model_asset_bundle(gesture_recognizer_models,
output_file_path)
with open(output_file_path, "rb") as f:
gesture_recognizer_model_buffer = f.read()
return gesture_recognizer_model_buffer, custom_gesture_classifier_metadata_json
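
For orientation, a condensed sketch of driving this writer end to end; the local .tflite paths below are placeholders for model files whose metadata is already populated:

# Sketch: assemble a gesture_recognizer.task bundle from local model files (placeholder paths).
from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer as base_metadata_writer

options = metadata_writer.GestureClassifierOptions(
    model_buffer=metadata_writer.read_file('custom_gesture_classifier.tflite'),
    labels=base_metadata_writer.Labels().add(['none', 'call', 'four', 'rock']),
    score_thresholding=base_metadata_writer.ScoreThresholding(global_score_threshold=0.5))
writer = metadata_writer.MetadataWriter.create(
    metadata_writer.read_file('palm_detection_full.tflite'),        # hand detector
    metadata_writer.read_file('hand_landmark_full.tflite'),         # hand landmarks detector
    metadata_writer.read_file('gesture_embedder.tflite'),           # gesture embedder
    metadata_writer.read_file('canned_gesture_classifier.tflite'),  # canned classifier
    options)
model_bundle_content, metadata_json = writer.populate()
with open('gesture_recognizer.task', 'wb') as f:
    f.write(model_bundle_content)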


@@ -0,0 +1,90 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for metadata_writer."""
import os
import zipfile
import tensorflow as tf
from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer as base_metadata_writer
from mediapipe.tasks.python.test import test_utils
_TEST_DATA_DIR = "mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata"
_EXPECTED_JSON = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "custom_gesture_classifier_meta.json"))
_CUSTOM_GESTURE_CLASSIFIER_PATH = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "custom_gesture_classifier.tflite"))
class MetadataWriterTest(tf.test.TestCase):
def test_write_metadata_and_create_model_asset_bundle_successful(self):
# Use dummy model buffer for unit test only.
hand_detector_model_buffer = b"\x11\x12"
hand_landmarks_detector_model_buffer = b"\x22"
gesture_embedder_model_buffer = b"\x33"
canned_gesture_classifier_model_buffer = b"\x44"
custom_gesture_classifier_metadata_writer = metadata_writer.GestureClassifierOptions(
model_buffer=metadata_writer.read_file(_CUSTOM_GESTURE_CLASSIFIER_PATH),
labels=base_metadata_writer.Labels().add(
["None", "Paper", "Rock", "Scissors"]),
score_thresholding=base_metadata_writer.ScoreThresholding(
global_score_threshold=0.5))
writer = metadata_writer.MetadataWriter.create(
hand_detector_model_buffer, hand_landmarks_detector_model_buffer,
gesture_embedder_model_buffer, canned_gesture_classifier_model_buffer,
custom_gesture_classifier_metadata_writer)
model_bundle_content, metadata_json = writer.populate()
with open(_EXPECTED_JSON, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
# Checks the top-level model bundle can be extracted successfully.
model_bundle_filepath = os.path.join(self.get_temp_dir(),
"gesture_recognition.task")
with open(model_bundle_filepath, "wb") as f:
f.write(model_bundle_content)
with zipfile.ZipFile(model_bundle_filepath) as zf:
self.assertEqual(
set(zf.namelist()),
set(["hand_landmarker.task", "hand_gesture_recognizer.task"]))
zf.extractall(self.get_temp_dir())
# Checks the model bundles for sub-task can be extracted successfully.
hand_landmarker_bundle_filepath = os.path.join(self.get_temp_dir(),
"hand_landmarker.task")
with zipfile.ZipFile(hand_landmarker_bundle_filepath) as zf:
self.assertEqual(
set(zf.namelist()),
set(["hand_landmarks_detector.tflite", "hand_detector.tflite"]))
hand_gesture_recognizer_bundle_filepath = os.path.join(
self.get_temp_dir(), "hand_gesture_recognizer.task")
with zipfile.ZipFile(hand_gesture_recognizer_bundle_filepath) as zf:
self.assertEqual(
set(zf.namelist()),
set([
"canned_gesture_classifier.tflite",
"custom_gesture_classifier.tflite", "gesture_embedder.tflite"
]))
if __name__ == "__main__":
tf.test.main()


@@ -0,0 +1,27 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configurable model options for gesture recognizer models."""
import dataclasses
@dataclasses.dataclass
class GestureRecognizerModelOptions:
"""Configurable options for gesture recognizer model.
Attributes:
dropout_rate: The fraction of the input units to drop, used in dropout
layer.
"""
dropout_rate: float = 0.05


@@ -0,0 +1,56 @@
{
"name": "HandGestureRecognition",
"description": "Recognize the hand gesture in the image.",
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "embedding",
"description": "Embedding feature vector from gesture embedder.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
}
],
"output_tensor_metadata": [
{
"name": "scores",
"description": "Hand gesture category scores.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"process_units": [
{
"options_type": "ScoreThresholdingOptions",
"options": {
"global_score_threshold": 0.5
}
}
],
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS"
}
]
}
]
}
],
"min_parser_version": "1.0.0"
}


@@ -120,6 +120,34 @@ py_library(
],
)
py_library(
name = "solution_base",
srcs = ["solution_base.py"],
srcs_version = "PY3",
visibility = [
"//mediapipe/python:__subpackages__",
],
deps = [
":_framework_bindings",
":packet_creator",
":packet_getter",
"//mediapipe/calculators/core:constant_side_packet_calculator_py_pb2",
"//mediapipe/calculators/image:image_transformation_calculator_py_pb2",
"//mediapipe/calculators/tensor:tensors_to_detections_calculator_py_pb2",
"//mediapipe/calculators/util:landmarks_smoothing_calculator_py_pb2",
"//mediapipe/calculators/util:logic_calculator_py_pb2",
"//mediapipe/calculators/util:thresholding_calculator_py_pb2",
"//mediapipe/framework:calculator_py_pb2",
"//mediapipe/framework/formats:classification_py_pb2",
"//mediapipe/framework/formats:detection_py_pb2",
"//mediapipe/framework/formats:landmark_py_pb2",
"//mediapipe/framework/formats:rect_py_pb2",
"//mediapipe/modules/objectron/calculators:annotation_py_pb2",
"//mediapipe/modules/objectron/calculators:lift_2d_frame_annotation_to_3d_calculator_py_pb2",
"@com_google_protobuf//:protobuf_python",
],
)
py_test(
name = "calculator_graph_test",
srcs = ["calculator_graph_test.py"],
@@ -175,3 +203,24 @@ py_test(
":_framework_bindings",
],
)
py_test(
name = "solution_base_test",
srcs = ["solution_base_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":solution_base",
"//file/google_src",
"//file/localfile",
"//mediapipe/calculators/core:gate_calculator",
"//mediapipe/calculators/core:pass_through_calculator",
"//mediapipe/calculators/core:side_packet_to_stream_calculator",
"//mediapipe/calculators/image:image_transformation_calculator",
"//mediapipe/calculators/util:detection_unique_id_calculator",
"//mediapipe/calculators/util:to_image_calculator",
"//mediapipe/framework:calculator_py_pb2",
"//mediapipe/framework/formats:detection_py_pb2",
"@com_google_protobuf//:protobuf_python",
],
)


@@ -40,7 +40,6 @@ from mediapipe.calculators.util import landmarks_smoothing_calculator_pb2
from mediapipe.calculators.util import logic_calculator_pb2
from mediapipe.calculators.util import thresholding_calculator_pb2
from mediapipe.framework import calculator_pb2
from mediapipe.framework.formats import body_rig_pb2
from mediapipe.framework.formats import classification_pb2
from mediapipe.framework.formats import detection_pb2
from mediapipe.framework.formats import landmark_pb2


@@ -0,0 +1,105 @@
# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:public"])
py_library(
name = "hands",
srcs = [
"hands.py",
"hands_connections.py",
],
data = [
"//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
"//mediapipe/modules/hand_landmark:hand_landmark_lite.tflite",
"//mediapipe/modules/hand_landmark:hand_landmark_tracking_cpu_graph",
"//mediapipe/modules/hand_landmark:handedness.txt",
"//mediapipe/modules/palm_detection:palm_detection_full.tflite",
"//mediapipe/modules/palm_detection:palm_detection_lite.tflite",
],
srcs_version = "PY3",
deps = [
"//mediapipe/calculators/core:constant_side_packet_calculator_py_pb2",
"//mediapipe/calculators/core:gate_calculator_py_pb2",
"//mediapipe/calculators/core:split_vector_calculator_py_pb2",
"//mediapipe/calculators/tensor:image_to_tensor_calculator_py_pb2",
"//mediapipe/calculators/tensor:inference_calculator_py_pb2",
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_py_pb2",
"//mediapipe/calculators/tensor:tensors_to_detections_calculator_py_pb2",
"//mediapipe/calculators/tensor:tensors_to_landmarks_calculator_py_pb2",
"//mediapipe/calculators/tflite:ssd_anchors_calculator_py_pb2",
"//mediapipe/calculators/util:association_calculator_py_pb2",
"//mediapipe/calculators/util:detections_to_rects_calculator_py_pb2",
"//mediapipe/calculators/util:logic_calculator_py_pb2",
"//mediapipe/calculators/util:non_max_suppression_calculator_py_pb2",
"//mediapipe/calculators/util:rect_transformation_calculator_py_pb2",
"//mediapipe/calculators/util:thresholding_calculator_py_pb2",
"//mediapipe/python:solution_base",
],
)
py_library(
name = "drawing_styles",
srcs = ["drawing_styles.py"],
srcs_version = "PY3",
deps = [
"drawing_utils",
"face_mesh",
"hands",
"pose",
],
)
py_library(
name = "drawing_utils",
srcs = ["drawing_utils.py"],
srcs_version = "PY3",
deps = [
"//mediapipe/framework/formats:detection_py_pb2",
"//mediapipe/framework/formats:landmark_py_pb2",
"//mediapipe/framework/formats:location_data_py_pb2",
],
)
py_test(
name = "drawing_utils_test",
srcs = ["drawing_utils_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":drawing_utils",
"//mediapipe/framework/formats:detection_py_pb2",
"//mediapipe/framework/formats:landmark_py_pb2",
"@com_google_protobuf//:protobuf_python",
],
)
py_test(
name = "hands_test",
srcs = ["hands_test.py"],
data = [
":testdata/asl_hand.25fps.mp4",
":testdata/asl_hand.full.npz",
":testdata/hands.jpg",
],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":drawing_styles",
":drawing_utils",
":hands",
],
)


@@ -88,6 +88,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/burger_rotated.jpg?generation=1665065843774448"],
)
http_file(
name = "com_google_mediapipe_canned_gesture_classifier_tflite",
sha256 = "2fc7e279966a7a9e15fc869223793e390791fc61fdc0062f9bc7d0eef6be98a2",
urls = ["https://storage.googleapis.com/mediapipe-assets/canned_gesture_classifier.tflite?generation=1668124189331326"],
)
http_file(
name = "com_google_mediapipe_cat_jpg",
sha256 = "2533197401eebe9410ea4d063f86c43fbd2666f3e8165a38aca155c0d09c21be",
@@ -286,6 +292,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/general_meta.json?generation=1665422822603848"],
)
http_file(
name = "com_google_mediapipe_gesture_embedder_tflite",
sha256 = "54abe78de1d1cd5e3cdaa0dab01db18e3ec7e09a76e7c3b5fa278572f7a60977",
urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder.tflite?generation=1668124192126494"],
)
http_file(
name = "com_google_mediapipe_gesture_recognizer_task",
sha256 = "a966b1d4e774e0423c19c8aa71f070e5a72fe7a03c2663dd2f3cb0b0095ee3e1",
@@ -721,7 +733,7 @@ def external_files():
http_file(
name = "com_google_mediapipe_README_md",
sha256 = "a96d08c9c70cd9717207ed72c926e02e5eada751f00bdc5d3a7e82e3492b72cb",
urls = ["https://storage.googleapis.com/mediapipe-assets/README.md?generation=1661875904887163"],
urls = ["https://storage.googleapis.com/mediapipe-assets/README.md?generation=1668124179156767"],
)
http_file(
@@ -970,6 +982,18 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/yamnet_audio_classifier_with_metadata.tflite?generation=1661875980774466"],
)
http_file(
name = "com_google_mediapipe_gesture_embedder_keras_metadata_pb",
sha256 = "24268b69429be4e307f9ab099ba20d1de7c40e4191a53f6a92dcbbd97a7047d3",
urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/keras_metadata.pb?generation=1668124196996131"],
)
http_file(
name = "com_google_mediapipe_gesture_embedder_saved_model_pb",
sha256 = "f3a2870ba3ef537a4f6a5889ffc5b7061ad98f9fd96ec431a62116892f100659",
urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/saved_model.pb?generation=1668124199460071"],
)
http_file(
name = "com_google_mediapipe_object_detection_saved_model_model_ckpt_data-00000-of-00001",
sha256 = "ad2f733f271dd5000a8c7f926bfea1083e6408b34d4f3b60679e5a6f96251c97",
@@ -1005,3 +1029,15 @@ def external_files():
sha256 = "f29606cf218397d5580c496e50fd28cddf66e2f59b819ab9c761b72270a5adf3",
urls = ["https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/saved_model.pb?generation=1661875999264354"],
)
http_file(
name = "com_google_mediapipe_gesture_embedder_variables_variables_data-00000-of-00001",
sha256 = "9fdb750c4bac67afb9c0f61916510930b496cc47e7f89449aee2bec6b6ed0af8",
urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.data-00000-of-00001?generation=1668124201918980"],
)
http_file(
name = "com_google_mediapipe_gesture_embedder_variables_variables_index",
sha256 = "3ccbcee9488fec4627d496abd9837997276b32b839a4d0ae434bd806fe380b86",
urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668124204353848"],
)