Internal Changes
PiperOrigin-RevId: 487673720
@@ -750,19 +750,12 @@ objc_library(
    ],
)

proto_library(
mediapipe_proto_library(
    name = "scale_mode_proto",
    srcs = ["scale_mode.proto"],
    visibility = ["//visibility:public"],
)

mediapipe_cc_proto_library(
    name = "scale_mode_cc_proto",
    srcs = ["scale_mode.proto"],
    visibility = ["//visibility:public"],
    deps = [":scale_mode_proto"],
)

cc_library(
    name = "gl_quad_renderer",
    srcs = ["gl_quad_renderer.cc"],
mediapipe/model_maker/models/gesture_recognizer/BUILD (new file, +51 lines)
@@ -0,0 +1,51 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load(
    "//mediapipe/framework/tool:mediapipe_files.bzl",
    "mediapipe_files",
)

licenses(["notice"])

package(
    default_visibility = ["//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__"],
)

mediapipe_files(
    srcs = [
        "canned_gesture_classifier.tflite",
        "gesture_embedder.tflite",
        "gesture_embedder/keras_metadata.pb",
        "gesture_embedder/saved_model.pb",
        "gesture_embedder/variables/variables.data-00000-of-00001",
        "gesture_embedder/variables/variables.index",
        "hand_landmark_full.tflite",
        "palm_detection_full.tflite",
    ],
)

filegroup(
    name = "models",
    srcs = [
        "canned_gesture_classifier.tflite",
        "gesture_embedder.tflite",
        "gesture_embedder/keras_metadata.pb",
        "gesture_embedder/saved_model.pb",
        "gesture_embedder/variables/variables.data-00000-of-00001",
        "gesture_embedder/variables/variables.index",
        "hand_landmark_full.tflite",
        "palm_detection_full.tflite",
    ],
)
mediapipe/model_maker/python/vision/gesture_recognizer/BUILD (new file, +165 lines)
@@ -0,0 +1,165 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Placeholder for internal Python strict test compatibility macro.
# Placeholder for internal Python strict library and test compatibility macro.

licenses(["notice"])

package(
    default_visibility = ["//mediapipe:__subpackages__"],
)

# TODO: Remove the unnecessary test data once the demo data are moved to an
# open-sourced directory.
filegroup(
    name = "test_data",
    srcs = glob([
        "test_data/**",
    ]),
)

py_library(
    name = "constants",
    srcs = ["constants.py"],
)

# TODO: Change to py_library after migrating the MediaPipe hand solution
# library to MediaPipe hand task library.
py_library(
    name = "dataset",
    srcs = ["dataset.py"],
    deps = [
        ":constants",
        "//mediapipe/model_maker/python/core/data:classification_dataset",
        "//mediapipe/model_maker/python/core/data:data_util",
        "//mediapipe/model_maker/python/core/utils:model_util",
        "//mediapipe/python/solutions:hands",
    ],
)

py_test(
    name = "dataset_test",
    srcs = ["dataset_test.py"],
    data = [
        ":test_data",
        "//mediapipe/model_maker/models/gesture_recognizer:models",
    ],
    deps = [
        ":dataset",
        "//mediapipe/python/solutions:hands",
        "//mediapipe/tasks/python/test:test_utils",
    ],
)

py_library(
    name = "hyperparameters",
    srcs = ["hyperparameters.py"],
    deps = [
        "//mediapipe/model_maker/python/core:hyperparameters",
    ],
)

py_library(
    name = "model_options",
    srcs = ["model_options.py"],
)

py_library(
    name = "gesture_recognizer_options",
    srcs = ["gesture_recognizer_options.py"],
    deps = [
        ":hyperparameters",
        ":model_options",
    ],
)

py_library(
    name = "gesture_recognizer",
    srcs = ["gesture_recognizer.py"],
    data = ["//mediapipe/model_maker/models/gesture_recognizer:models"],
    deps = [
        ":constants",
        ":gesture_recognizer_options",
        ":hyperparameters",
        ":metadata_writer",
        ":model_options",
        "//mediapipe/model_maker/python/core/data:classification_dataset",
        "//mediapipe/model_maker/python/core/tasks:classifier",
        "//mediapipe/model_maker/python/core/utils:loss_functions",
        "//mediapipe/model_maker/python/core/utils:model_util",
        "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
    ],
)

py_library(
    name = "gesture_recognizer_import",
    srcs = ["__init__.py"],
    deps = [
        ":dataset",
        ":gesture_recognizer",
        ":gesture_recognizer_options",
        ":hyperparameters",
        ":model_options",
    ],
)

py_library(
    name = "metadata_writer",
    srcs = ["metadata_writer.py"],
    deps = [
        "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
        "//mediapipe/tasks/python/metadata/metadata_writers:writer_utils",
    ],
)

py_test(
    name = "gesture_recognizer_test",
    size = "large",
    srcs = ["gesture_recognizer_test.py"],
    data = [
        ":test_data",
        "//mediapipe/model_maker/models/gesture_recognizer:models",
    ],
    shard_count = 2,
    deps = [
        ":gesture_recognizer_import",
        "//mediapipe/model_maker/python/core/utils:test_util",
        "//mediapipe/tasks/python/test:test_utils",
    ],
)

py_test(
    name = "metadata_writer_test",
    srcs = ["metadata_writer_test.py"],
    data = [
        ":test_data",
    ],
    deps = [
        ":metadata_writer",
        "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
        "//mediapipe/tasks/python/test:test_utils",
    ],
)

py_binary(
    name = "gesture_recognizer_demo",
    srcs = ["gesture_recognizer_demo.py"],
    data = [
        ":test_data",
        "//mediapipe/model_maker/models/gesture_recognizer:models",
    ],
    python_version = "PY3",
    deps = [":gesture_recognizer_import"],
)
mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py (new file)
@@ -0,0 +1,27 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe Model Maker Python Public API For Gesture Recognizer."""

from mediapipe.model_maker.python.vision.gesture_recognizer import dataset
from mediapipe.model_maker.python.vision.gesture_recognizer import gesture_recognizer
from mediapipe.model_maker.python.vision.gesture_recognizer import gesture_recognizer_options
from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters
from mediapipe.model_maker.python.vision.gesture_recognizer import model_options

GestureRecognizer = gesture_recognizer.GestureRecognizer
ModelOptions = model_options.GestureRecognizerModelOptions
HParams = hyperparameters.HParams
Dataset = dataset.Dataset
HandDataPreprocessingParams = dataset.HandDataPreprocessingParams
GestureRecognizerOptions = gesture_recognizer_options.GestureRecognizerOptions
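For reference, a minimal usage sketch of the public API re-exported by this `__init__.py`; it mirrors the demo and tests added later in this change. The data and export directories are placeholder paths, and the dataset folder is assumed to contain one sub-folder per gesture plus a `none` class:

```python
from mediapipe.model_maker.python.vision import gesture_recognizer

# Placeholder directories; layout is <raw_data>/<gesture_name>/*.jpg,
# including a mandatory `none` sub-folder.
data = gesture_recognizer.Dataset.from_folder(
    dirname='/tmp/raw_data',
    hparams=gesture_recognizer.HandDataPreprocessingParams(shuffle=True))
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)

model = gesture_recognizer.GestureRecognizer.create(
    train_data=train_data,
    validation_data=validation_data,
    options=gesture_recognizer.GestureRecognizerOptions(
        hparams=gesture_recognizer.HParams(export_dir='/tmp/exported_model')))

print(model.evaluate(test_data, batch_size=2))
model.export_model()  # writes gesture_recognizer.task and metadata.json
```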
mediapipe/model_maker/python/vision/gesture_recognizer/constants.py (new file)
@@ -0,0 +1,20 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gesture recognition constants."""

GESTURE_EMBEDDER_KERAS_MODEL_PATH = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder'
GESTURE_EMBEDDER_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder.tflite'
HAND_DETECTOR_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/palm_detection_full.tflite'
HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/hand_landmark_full.tflite'
CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/canned_gesture_classifier.tflite'
mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py (new file)
@@ -0,0 +1,238 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gesture recognition dataset library."""

import dataclasses
import os
import random
from typing import List, NamedTuple, Optional

import cv2
import tensorflow as tf

from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.core.data import data_util
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.vision.gesture_recognizer import constants
from mediapipe.python.solutions import hands as mp_hands


@dataclasses.dataclass
class HandDataPreprocessingParams:
  """A dataclass that wraps the hand data preprocessing hyperparameters.

  Attributes:
    shuffle: A boolean controlling whether to shuffle the dataset. Defaults to
      True.
    min_detection_confidence: Confidence threshold for hand detection.
  """
  shuffle: bool = True
  min_detection_confidence: float = 0.7


@dataclasses.dataclass
class HandData:
  """A dataclass that represents hand data for training the gesture recognizer.

  See https://google.github.io/mediapipe/solutions/hands#mediapipe-hands for
  more details of the hand gesture data API.

  Attributes:
    hand: Normalized hand landmarks of shape 21x3 from the screen-based
      hand-landmark model.
    world_hand: Hand landmarks of shape 21x3 in world coordinates.
    handedness: Collection of handedness confidences of the detected hands
      (i.e. is it a left or right hand).
  """
  hand: List[List[float]]
  world_hand: List[List[float]]
  handedness: List[float]


def _validate_data_sample(data: NamedTuple) -> bool:
  """Validates the input hand data sample.

  Args:
    data: Input hand data sample.

  Returns:
    False if the input data namedtuple does not contain the fields
    'multi_hand_landmarks', 'multi_hand_world_landmarks', and
    'multi_handedness', or if any of these attributes' values are None.
    Otherwise, True.
  """
  if (not hasattr(data, 'multi_hand_landmarks') or
      data.multi_hand_landmarks is None):
    return False
  if (not hasattr(data, 'multi_hand_world_landmarks') or
      data.multi_hand_world_landmarks is None):
    return False
  if not hasattr(data, 'multi_handedness') or data.multi_handedness is None:
    return False
  return True


def _get_hand_data(all_image_paths: List[str],
                   min_detection_confidence: float) -> List[Optional[HandData]]:
  """Computes hand data (landmarks and handedness) for the input images.

  Args:
    all_image_paths: All input image paths.
    min_detection_confidence: Hand detection confidence threshold.

  Returns:
    A list with one entry per input image: a HandData object, or None if no
    hand is detected in that image.
  """
  hand_data_result = []
  with mp_hands.Hands(
      static_image_mode=True,
      max_num_hands=1,
      min_detection_confidence=min_detection_confidence) as hands:
    for path in all_image_paths:
      tf.compat.v1.logging.info('Loading image %s', path)
      image = data_util.load_image(path)
      # Flip image around y-axis for correct handedness output.
      image = cv2.flip(image, 1)
      data = hands.process(image)
      if not _validate_data_sample(data):
        hand_data_result.append(None)
        continue
      hand_landmarks = [[
          hand_landmark.x, hand_landmark.y, hand_landmark.z
      ] for hand_landmark in data.multi_hand_landmarks[0].landmark]
      hand_world_landmarks = [[
          hand_landmark.x, hand_landmark.y, hand_landmark.z
      ] for hand_landmark in data.multi_hand_world_landmarks[0].landmark]
      handedness_scores = [
          handedness.score
          for handedness in data.multi_handedness[0].classification
      ]
      hand_data_result.append(
          HandData(
              hand=hand_landmarks,
              world_hand=hand_world_landmarks,
              handedness=handedness_scores))
  return hand_data_result


class Dataset(classification_dataset.ClassificationDataset):
  """Dataset library for hand gesture recognizer."""

  @classmethod
  def from_folder(
      cls,
      dirname: str,
      hparams: Optional[HandDataPreprocessingParams] = None
  ) -> classification_dataset.ClassificationDataset:
    """Loads images and labels from the given directory.

    Directory contents are expected to be in the format
    <root_dir>/<gesture_name>/*.jpg. One of the `gesture_name` subdirectories
    must be `none` (case insensitive); it is expected to contain images of
    hands that don't belong to any other gesture class in <root_dir>. Assumes
    that the image data of the same label are in the same subdirectory.

    Args:
      dirname: Name of the directory containing the data files.
      hparams: Optional hyperparameters for processing input hand gesture
        images.

    Returns:
      Dataset containing landmarks, labels, and other related info.

    Raises:
      ValueError: if the input data directory is empty or the label set does
        not contain label 'none' (case insensitive).
    """
    data_root = os.path.abspath(dirname)

    # Assumes the image data of the same label are in the same subdirectory;
    # gets image paths and label names.
    all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
    if not all_image_paths:
      raise ValueError('Image dataset directory is empty.')

    if not hparams:
      hparams = HandDataPreprocessingParams()

    if hparams.shuffle:
      # Randomly shuffle data.
      random.shuffle(all_image_paths)

    label_names = sorted(
        name for name in os.listdir(data_root)
        if os.path.isdir(os.path.join(data_root, name)))
    if 'none' not in [v.lower() for v in label_names]:
      raise ValueError('Label set does not contain label "None".')
    # Move label 'none' to the front of the label list.
    none_idx = [v.lower() for v in label_names].index('none')
    none_value = label_names.pop(none_idx)
    label_names.insert(0, none_value)

    index_by_label = dict(
        (name, index) for index, name in enumerate(label_names))
    all_gesture_indices = [
        index_by_label[os.path.basename(os.path.dirname(path))]
        for path in all_image_paths
    ]

    # Compute hand data (including local hand landmarks, world hand landmarks,
    # and handedness) for all the input images.
    hand_data = _get_hand_data(
        all_image_paths=all_image_paths,
        min_detection_confidence=hparams.min_detection_confidence)

    # Get the indices of the valid hand landmark samples in the hand data list.
    valid_indices = [
        i for i in range(len(hand_data)) if hand_data[i] is not None
    ]
    # Remove 'None' elements from the hand data and label lists.
    valid_hand_data = [dataclasses.asdict(hand_data[i]) for i in valid_indices]
    if not valid_hand_data:
      raise ValueError('No valid hand is detected.')

    valid_label = [all_gesture_indices[i] for i in valid_indices]

    # Convert list of dictionaries to dictionary of lists.
    hand_data_dict = {
        k: [lm[k] for lm in valid_hand_data] for k in valid_hand_data[0]
    }
    hand_ds = tf.data.Dataset.from_tensor_slices(hand_data_dict)

    embedder_model = model_util.load_keras_model(
        constants.GESTURE_EMBEDDER_KERAS_MODEL_PATH)

    hand_ds = hand_ds.batch(batch_size=1)
    hand_embedding_ds = hand_ds.map(
        map_func=lambda feature: embedder_model(dict(feature)),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    hand_embedding_ds = hand_embedding_ds.unbatch()

    # Create label dataset.
    label_ds = tf.data.Dataset.from_tensor_slices(
        tf.cast(valid_label, tf.int64))

    label_one_hot_ds = label_ds.map(
        map_func=lambda index: tf.one_hot(index, len(label_names)),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # Create a dataset with (hand_embedding, one_hot_label) pairs.
    hand_embedding_label_ds = tf.data.Dataset.zip(
        (hand_embedding_ds, label_one_hot_ds))

    tf.compat.v1.logging.info(
        'Load valid hands with size: {}, num_label: {}, labels: {}.'.format(
            len(valid_hand_data), len(label_names), ','.join(label_names)))
    return Dataset(
        dataset=hand_embedding_label_ds,
        size=len(valid_hand_data),
        label_names=label_names)
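As the docstrings above describe, each dataset element ends up as a (hand-embedding, one-hot label) pair. A short sketch of sanity-checking a loaded dataset; the directory is a placeholder path, and the 128-dimensional embedding matches `_EMBEDDING_SIZE` used by the recognizer below:

```python
from mediapipe.model_maker.python.vision.gesture_recognizer import dataset

# Placeholder directory laid out as <raw_data>/<gesture_name>/*.jpg with a
# mandatory `none` sub-folder.
data = dataset.Dataset.from_folder(
    dirname='/tmp/raw_data',
    hparams=dataset.HandDataPreprocessingParams(min_detection_confidence=0.7))

# Each element is a (hand_embedding, one_hot_label) pair; the embedding has
# 128 features, matching _EMBEDDING_SIZE in gesture_recognizer.py.
for embedding, one_hot_label in data.gen_tf_dataset(is_training=False).take(1):
  print(embedding.shape, one_hot_label.shape)
```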
mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py (new file)
@@ -0,0 +1,161 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import os
import shutil
from typing import NamedTuple
import unittest

from absl import flags
from absl.testing import parameterized
import tensorflow as tf

from mediapipe.model_maker.python.vision.gesture_recognizer import dataset
from mediapipe.python.solutions import hands as mp_hands
from mediapipe.tasks.python.test import test_utils

FLAGS = flags.FLAGS

_TEST_DATA_DIRNAME = 'raw_data'


class DatasetTest(tf.test.TestCase, parameterized.TestCase):

  def test_split(self):
    input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
    data = dataset.Dataset.from_folder(
        dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams())
    train_data, test_data = data.split(0.5)

    self.assertLen(train_data, 17)
    for _, elem in enumerate(train_data.gen_tf_dataset(is_training=True)):
      self.assertEqual(elem[0].shape, (1, 128))
      self.assertEqual(elem[1].shape, ([1, 4]))
    self.assertEqual(train_data.num_classes, 4)
    self.assertEqual(train_data.label_names, ['none', 'call', 'four', 'rock'])

    self.assertLen(test_data, 18)
    for _, elem in enumerate(test_data.gen_tf_dataset(is_training=True)):
      self.assertEqual(elem[0].shape, (1, 128))
      self.assertEqual(elem[1].shape, ([1, 4]))
    self.assertEqual(test_data.num_classes, 4)
    self.assertEqual(test_data.label_names, ['none', 'call', 'four', 'rock'])

  def test_from_folder(self):
    input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
    data = dataset.Dataset.from_folder(
        dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams())
    for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
      self.assertEqual(elem[0].shape, (1, 128))
      self.assertEqual(elem[1].shape, ([1, 4]))
    self.assertLen(data, 35)
    self.assertEqual(data.num_classes, 4)
    self.assertEqual(data.label_names, ['none', 'call', 'four', 'rock'])

  def test_create_dataset_from_empty_folder_raise_value_error(self):
    with self.assertRaisesRegex(ValueError, 'Image dataset directory is empty'):
      dataset.Dataset.from_folder(
          dirname=self.get_temp_dir(),
          hparams=dataset.HandDataPreprocessingParams())

  def test_create_dataset_from_folder_without_none_raise_value_error(self):
    input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
    tmp_dir = self.create_tempdir()
    # Copy the input dataset to a temporary directory, skipping the 'none'
    # directory.
    for name in os.listdir(input_data_dir):
      if name == 'none':
        continue
      src_dir = os.path.join(input_data_dir, name)
      dst_dir = os.path.join(tmp_dir, name)
      shutil.copytree(src_dir, dst_dir)

    with self.assertRaisesRegex(ValueError,
                                'Label set does not contain label "None"'):
      dataset.Dataset.from_folder(
          dirname=tmp_dir, hparams=dataset.HandDataPreprocessingParams())

  def test_create_dataset_from_folder_with_capital_letter_in_folder_name(self):
    input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
    tmp_dir = self.create_tempdir()
    # Copy the input dataset to a temporary directory and change the base
    # folder names to upper case letters, e.g. 'none' -> 'NONE'.
    for name in os.listdir(input_data_dir):
      src_dir = os.path.join(input_data_dir, name)
      dst_dir = os.path.join(tmp_dir, name.upper())
      shutil.copytree(src_dir, dst_dir)

    upper_base_folder_name = list(os.listdir(tmp_dir))
    self.assertCountEqual(upper_base_folder_name,
                          ['CALL', 'FOUR', 'NONE', 'ROCK'])

    data = dataset.Dataset.from_folder(
        dirname=tmp_dir, hparams=dataset.HandDataPreprocessingParams())
    for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
      self.assertEqual(elem[0].shape, (1, 128))
      self.assertEqual(elem[1].shape, ([1, 4]))
    self.assertLen(data, 35)
    self.assertEqual(data.num_classes, 4)
    self.assertEqual(data.label_names, ['NONE', 'CALL', 'FOUR', 'ROCK'])

  @parameterized.named_parameters(
      dict(
          testcase_name='invalid_field_name_multi_hand_landmark',
          hand=collections.namedtuple('Hand', [
              'multi_hand_landmark', 'multi_hand_world_landmarks',
              'multi_handedness'
          ])(1, 2, 3)),
      dict(
          testcase_name='invalid_field_name_multi_hand_world_landmarks',
          hand=collections.namedtuple('Hand', [
              'multi_hand_landmarks', 'multi_hand_world_landmark',
              'multi_handedness'
          ])(1, 2, 3)),
      dict(
          testcase_name='invalid_field_name_multi_handed',
          hand=collections.namedtuple('Hand', [
              'multi_hand_landmarks', 'multi_hand_world_landmarks',
              'multi_handed'
          ])(1, 2, 3)),
      dict(
          testcase_name='multi_hand_landmarks_is_none',
          hand=collections.namedtuple('Hand', [
              'multi_hand_landmarks', 'multi_hand_world_landmarks',
              'multi_handedness'
          ])(None, 2, 3)),
      dict(
          testcase_name='multi_hand_world_landmarks_is_none',
          hand=collections.namedtuple('Hand', [
              'multi_hand_landmarks', 'multi_hand_world_landmarks',
              'multi_handedness'
          ])(1, None, 3)),
      dict(
          testcase_name='multi_handedness_is_none',
          hand=collections.namedtuple('Hand', [
              'multi_hand_landmarks', 'multi_hand_world_landmarks',
              'multi_handedness'
          ])(1, 2, None)),
  )
  def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple):
    with unittest.mock.patch.object(
        mp_hands.Hands, 'process', return_value=hand):
      input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
      with self.assertRaisesRegex(ValueError, 'No valid hand is detected'):
        dataset.Dataset.from_folder(
            dirname=input_data_dir,
            hparams=dataset.HandDataPreprocessingParams())


if __name__ == '__main__':
  tf.test.main()
mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py (new file)
@@ -0,0 +1,239 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""APIs to train gesture recognizer model."""

import os
from typing import List

import tensorflow as tf

from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import loss_functions
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.vision.gesture_recognizer import constants
from mediapipe.model_maker.python.vision.gesture_recognizer import gesture_recognizer_options
from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters as hp
from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer
from mediapipe.model_maker.python.vision.gesture_recognizer import model_options as model_opt
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer as base_metadata_writer

_EMBEDDING_SIZE = 128


class GestureRecognizer(classifier.Classifier):
  """GestureRecognizer for building hand gesture recognizer model.

  Attributes:
    embedding_size: Size of the input gesture embedding vector.
  """

  def __init__(self, label_names: List[str],
               model_options: model_opt.GestureRecognizerModelOptions,
               hparams: hp.HParams):
    """Initializes the GestureRecognizer class.

    Args:
      label_names: A list of label names for the classes.
      model_options: Options for creating the gesture recognizer model.
      hparams: The hyperparameters for training the hand gesture recognizer
        model.
    """
    super().__init__(
        model_spec=None, label_names=label_names, shuffle=hparams.shuffle)
    self._model_options = model_options
    self._hparams = hparams
    self._history = None
    self.embedding_size = _EMBEDDING_SIZE

  @classmethod
  def create(
      cls,
      train_data: classification_ds.ClassificationDataset,
      validation_data: classification_ds.ClassificationDataset,
      options: gesture_recognizer_options.GestureRecognizerOptions,
  ) -> 'GestureRecognizer':
    """Creates and trains a hand gesture recognizer with input datasets.

    If a checkpoint file exists in the {options.hparams.export_dir}/epoch_models/
    directory, the training process will load the weights from the checkpoint
    file for continual training.

    Args:
      train_data: Training data.
      validation_data: Validation data. If None, skips the validation process.
      options: Options for creating and training the gesture recognizer model.

    Returns:
      An instance of GestureRecognizer.
    """
    if options.model_options is None:
      options.model_options = model_opt.GestureRecognizerModelOptions()

    if options.hparams is None:
      options.hparams = hp.HParams()

    gesture_recognizer = cls(
        label_names=train_data.label_names,
        model_options=options.model_options,
        hparams=options.hparams)

    gesture_recognizer._create_model()

    train_dataset = train_data.gen_tf_dataset(
        batch_size=options.hparams.batch_size,
        is_training=True,
        shuffle=options.hparams.shuffle)
    options.hparams.steps_per_epoch = model_util.get_steps_per_epoch(
        steps_per_epoch=options.hparams.steps_per_epoch,
        batch_size=options.hparams.batch_size,
        train_data=train_data)
    train_dataset = train_dataset.take(count=options.hparams.steps_per_epoch)

    validation_dataset = validation_data.gen_tf_dataset(
        batch_size=options.hparams.batch_size, is_training=False)

    tf.compat.v1.logging.info('Training the gesture recognizer model...')
    gesture_recognizer._train(
        train_data=train_dataset, validation_data=validation_dataset)

    return gesture_recognizer

  def _train(self, train_data: tf.data.Dataset,
             validation_data: tf.data.Dataset):
    """Trains the model with the input train_data.

    The training results are recorded in a History object returned by
    tf.keras.Model.fit().

    Args:
      train_data: Training data.
      validation_data: Validation data.
    """
    hparams = self._hparams

    scheduler = lambda epoch: hparams.learning_rate * (hparams.lr_decay**epoch)
    scheduler_callback = tf.keras.callbacks.LearningRateScheduler(scheduler)

    job_dir = hparams.export_dir
    checkpoint_path = os.path.join(job_dir, 'epoch_models')
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        os.path.join(checkpoint_path, 'model-{epoch:04d}'),
        save_weights_only=True)

    best_model_path = os.path.join(job_dir, 'best_model_weights')
    best_model_callback = tf.keras.callbacks.ModelCheckpoint(
        best_model_path,
        monitor='val_loss',
        mode='min',
        save_best_only=True,
        save_weights_only=True)

    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=os.path.join(job_dir, 'logs'))

    self._model.compile(
        optimizer='adam',
        loss=loss_functions.FocalLoss(gamma=self._hparams.gamma),
        metrics=['categorical_accuracy'])

    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
    if latest_checkpoint:
      print(f'Resuming from {latest_checkpoint}')
      self._model.load_weights(latest_checkpoint)

    self._history = self._model.fit(
        x=train_data,
        epochs=hparams.epochs,
        validation_data=validation_data,
        validation_freq=1,
        callbacks=[
            checkpoint_callback, best_model_callback, scheduler_callback,
            tensorboard_callback
        ],
    )

  def _create_model(self):
    """Creates the hand gesture recognizer model.

    The gesture embedding model is pretrained and loaded from a tf.saved_model.
    """
    inputs = tf.keras.Input(
        shape=[self.embedding_size],
        batch_size=None,
        dtype=tf.float32,
        name='hand_embedding')

    x = tf.keras.layers.BatchNormalization()(inputs)
    x = tf.keras.layers.ReLU()(x)
    dropout_rate = self._model_options.dropout_rate
    x = tf.keras.layers.Dropout(rate=dropout_rate, name='dropout')(x)
    outputs = tf.keras.layers.Dense(
        self._num_classes,
        activation='softmax',
        name='custom_gesture_recognizer')(
            x)

    self._model = tf.keras.Model(inputs=inputs, outputs=outputs)

    print(self._model.summary())

  def export_model(self, model_name: str = 'gesture_recognizer.task'):
    """Converts the model to TFLite and exports it as a model bundle file.

    Saves a model bundle file and a metadata JSON file to hparams.export_dir.
    The resulting model bundle file contains the models necessary for hand
    detection, canned gesture classification, and customized gesture
    classification. Only the model bundle file is needed for the downstream
    gesture recognition task; the metadata.json file is saved only to help
    interpret the contents of the model bundle file.

    The customized gesture model is kept in float without quantization: the
    model is lightweight, so there is no need to trade accuracy for efficiency
    through quantization. The default score_thresholding is set to 0.5 and can
    be adjusted during inference.

    Args:
      model_name: File name to save the model bundle file. The full export
        path is {export_dir}/{model_name}.
    """
    # TODO: Convert keras embedder model instead of using tflite
    gesture_embedding_model_buffer = model_util.load_tflite_model_buffer(
        constants.GESTURE_EMBEDDER_TFLITE_MODEL_FILE)
    hand_detector_model_buffer = model_util.load_tflite_model_buffer(
        constants.HAND_DETECTOR_TFLITE_MODEL_FILE)
    hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
        constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE)
    canned_gesture_model_buffer = model_util.load_tflite_model_buffer(
        constants.CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE)

    if not tf.io.gfile.exists(self._hparams.export_dir):
      tf.io.gfile.makedirs(self._hparams.export_dir)
    model_bundle_file = os.path.join(self._hparams.export_dir, model_name)
    metadata_file = os.path.join(self._hparams.export_dir, 'metadata.json')

    gesture_classifier_options = metadata_writer.GestureClassifierOptions(
        model_buffer=model_util.convert_to_tflite(self._model),
        labels=base_metadata_writer.Labels().add(list(self._label_names)),
        score_thresholding=base_metadata_writer.ScoreThresholding(
            global_score_threshold=0.5))

    writer = metadata_writer.MetadataWriter.create(
        hand_detector_model_buffer, hand_landmarks_detector_model_buffer,
        gesture_embedding_model_buffer, canned_gesture_model_buffer,
        gesture_classifier_options)
    model_bundle_content, metadata_json = writer.populate()
    with open(model_bundle_file, 'wb') as f:
      f.write(model_bundle_content)
    with open(metadata_file, 'w') as f:
      f.write(metadata_json)
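The `export_model` docstring above describes the bundle layout; a hedged sketch of inspecting the exported `.task` archive, mirroring the checks in `gesture_recognizer_test.py` below (the paths are placeholders):

```python
import os
import zipfile

export_dir = '/tmp/exported_model'  # placeholder; matches HParams(export_dir=...)
bundle_path = os.path.join(export_dir, 'gesture_recognizer.task')

# The .task bundle is a zip archive: a hand_landmarker.task bundle plus a
# hand_gesture_recognizer.task bundle holding the three TFLite models.
with zipfile.ZipFile(bundle_path) as zf:
  print(zf.namelist())  # ['hand_landmarker.task', 'hand_gesture_recognizer.task']
  zf.extract('hand_gesture_recognizer.task', '/tmp/inspect')

with zipfile.ZipFile('/tmp/inspect/hand_gesture_recognizer.task') as zf:
  print(zf.namelist())
  # ['gesture_embedder.tflite', 'canned_gesture_classifier.tflite',
  #  'custom_gesture_classifier.tflite']
```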
mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py (new file)
@@ -0,0 +1,78 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Demo for making a gesture recognizer model with MediaPipe Model Maker."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

# Dependency imports

from absl import app
from absl import flags
from absl import logging

from mediapipe.model_maker.python.vision import gesture_recognizer

FLAGS = flags.FLAGS

# TODO: Move hand gesture recognizer demo dataset to an
# open-sourced directory.
TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data'


def define_flags():
  flags.DEFINE_string('export_dir', None,
                      'The directory to save exported files.')
  flags.DEFINE_string('input_data_dir', None,
                      'The directory with input training data.')
  flags.mark_flag_as_required('export_dir')


def run(data_dir: str, export_dir: str):
  """Runs the demo."""
  data = gesture_recognizer.Dataset.from_folder(dirname=data_dir)
  train_data, rest_data = data.split(0.8)
  validation_data, test_data = rest_data.split(0.5)

  model = gesture_recognizer.GestureRecognizer.create(
      train_data=train_data,
      validation_data=validation_data,
      options=gesture_recognizer.GestureRecognizerOptions(
          hparams=gesture_recognizer.HParams(export_dir=export_dir)))

  metric = model.evaluate(test_data, batch_size=2)
  print('Evaluation metric')
  print(metric)

  model.export_model()


def main(_):
  logging.set_verbosity(logging.INFO)

  if FLAGS.input_data_dir is None:
    data_dir = os.path.join(FLAGS.test_srcdir, TEST_DATA_DIR)
  else:
    data_dir = FLAGS.input_data_dir

  export_dir = os.path.expanduser(FLAGS.export_dir)
  run(data_dir=data_dir, export_dir=export_dir)


if __name__ == '__main__':
  define_flags()
  app.run(main)
mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_options.py (new file)
@@ -0,0 +1,32 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Options for building gesture recognizer."""

import dataclasses
from typing import Optional

from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters
from mediapipe.model_maker.python.vision.gesture_recognizer import model_options as model_opt


@dataclasses.dataclass
class GestureRecognizerOptions:
  """Configurable options for building gesture recognizer.

  Attributes:
    model_options: A set of options for configuring the selected model.
    hparams: A set of hyperparameters used to train the gesture recognizer.
  """
  model_options: Optional[model_opt.GestureRecognizerModelOptions] = None
  hparams: Optional[hyperparameters.HParams] = None
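Both fields are optional; `GestureRecognizer.create` falls back to default `HParams` and `GestureRecognizerModelOptions` when they are `None`. A short sketch of overriding them explicitly (the values are illustrative only):

```python
from mediapipe.model_maker.python.vision import gesture_recognizer

# Unset fields are filled in by GestureRecognizer.create() with the library
# defaults; the values below are illustrative placeholders.
options = gesture_recognizer.GestureRecognizerOptions(
    model_options=gesture_recognizer.ModelOptions(),
    hparams=gesture_recognizer.HParams(
        epochs=20,
        batch_size=2,
        learning_rate=0.001,
        export_dir='/tmp/exported_model'))
```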
mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py (new file)
@@ -0,0 +1,132 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
from unittest import mock as unittest_mock
import zipfile

import mock
import tensorflow as tf

from mediapipe.model_maker.python.core.utils import test_util
from mediapipe.model_maker.python.vision import gesture_recognizer
from mediapipe.tasks.python.test import test_utils

_TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data'


class GestureRecognizerTest(tf.test.TestCase):

  def _load_data(self):
    input_data_dir = test_utils.get_test_data_path(
        os.path.join(_TEST_DATA_DIR, 'raw_data'))

    data = gesture_recognizer.Dataset.from_folder(
        dirname=input_data_dir,
        hparams=gesture_recognizer.HandDataPreprocessingParams(shuffle=True))
    return data

  def setUp(self):
    super().setUp()
    self._model_options = gesture_recognizer.ModelOptions()
    self._hparams = gesture_recognizer.HParams(epochs=2)
    self._gesture_recognizer_options = (
        gesture_recognizer.GestureRecognizerOptions(
            model_options=self._model_options, hparams=self._hparams))
    all_data = self._load_data()
    # Splits data, 90% data for training, 10% for testing
    self._train_data, self._test_data = all_data.split(0.9)

  def test_gesture_recognizer_model(self):
    model = gesture_recognizer.GestureRecognizer.create(
        train_data=self._train_data,
        validation_data=self._test_data,
        options=self._gesture_recognizer_options)

    self._test_accuracy(model)

  def test_export_gesture_recognizer_model(self):
    model = gesture_recognizer.GestureRecognizer.create(
        train_data=self._train_data,
        validation_data=self._test_data,
        options=self._gesture_recognizer_options)
    model.export_model()
    model_bundle_file = os.path.join(self._hparams.export_dir,
                                     'gesture_recognizer.task')
    with zipfile.ZipFile(model_bundle_file) as zf:
      self.assertEqual(
          set(zf.namelist()),
          set(['hand_landmarker.task', 'hand_gesture_recognizer.task']))
      zf.extractall(self.get_temp_dir())
    hand_gesture_recognizer_bundle_file = os.path.join(
        self.get_temp_dir(), 'hand_gesture_recognizer.task')
    with zipfile.ZipFile(hand_gesture_recognizer_bundle_file) as zf:
      self.assertEqual(
          set(zf.namelist()),
          set([
              'canned_gesture_classifier.tflite',
              'custom_gesture_classifier.tflite', 'gesture_embedder.tflite'
          ]))
      zf.extractall(self.get_temp_dir())
    gesture_classifier_tflite_file = os.path.join(
        self.get_temp_dir(), 'custom_gesture_classifier.tflite')
    test_util.test_tflite_file(
        keras_model=model._model,
        tflite_file=gesture_classifier_tflite_file,
        size=[1, model.embedding_size])

  def _test_accuracy(self, model, threshold=0.5):
    _, accuracy = model.evaluate(self._test_data)
    tf.compat.v1.logging.info(f'accuracy: {accuracy}')
    self.assertGreaterEqual(accuracy, threshold)

  @unittest_mock.patch.object(
      gesture_recognizer.hyperparameters,
      'HParams',
      autospec=True,
      return_value=gesture_recognizer.HParams(epochs=1))
  @unittest_mock.patch.object(
      gesture_recognizer.model_options,
      'GestureRecognizerModelOptions',
      autospec=True,
      return_value=gesture_recognizer.ModelOptions())
  def test_create_hparams_and_model_options_if_none_in_image_classifier_options(
      self, mock_hparams, mock_model_options):
    options = gesture_recognizer.GestureRecognizerOptions()
    gesture_recognizer.GestureRecognizer.create(
        train_data=self._train_data,
        validation_data=self._test_data,
        options=options)
    mock_hparams.assert_called_once()
    mock_model_options.assert_called_once()

  def test_continual_training_by_loading_checkpoint(self):
    mock_stdout = io.StringIO()
    with mock.patch('sys.stdout', mock_stdout):
      model = gesture_recognizer.GestureRecognizer.create(
          train_data=self._train_data,
          validation_data=self._test_data,
          options=self._gesture_recognizer_options)
      model = gesture_recognizer.GestureRecognizer.create(
          train_data=self._train_data,
          validation_data=self._test_data,
          options=self._gesture_recognizer_options)
      self._test_accuracy(model)

    self.assertRegex(mock_stdout.getvalue(), 'Resuming from')


if __name__ == '__main__':
  tf.test.main()
mediapipe/model_maker/python/vision/gesture_recognizer/hyperparameters.py (new file)
@@ -0,0 +1,40 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hyperparameters for training customized gesture recognizer models."""

import dataclasses

from mediapipe.model_maker.python.core import hyperparameters as hp


@dataclasses.dataclass
class HParams(hp.BaseHParams):
  """The hyperparameters for training gesture recognizer.

  Attributes:
    learning_rate: Learning rate to use for gradient descent training.
    batch_size: Batch size for training.
    epochs: Number of training iterations over the dataset.
    lr_decay: Learning rate decay to use for gradient descent training.
    gamma: Gamma parameter for focal loss.
  """
  # Parameters from BaseHParams class.
  learning_rate: float = 0.001
  batch_size: int = 2
  epochs: int = 10

  # Parameters about training configuration
  # TODO: Move lr_decay to hp.baseHParams.
  lr_decay: float = 0.99
  gamma: int = 2
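For intuition, `_train` in `gesture_recognizer.py` schedules the learning rate as `learning_rate * lr_decay**epoch`; a tiny sketch of the schedule produced by these defaults:

```python
# Exponential decay used by the LearningRateScheduler callback in
# gesture_recognizer.py: lr(epoch) = learning_rate * lr_decay**epoch.
learning_rate = 0.001
lr_decay = 0.99
epochs = 10

for epoch in range(epochs):
  print(f'epoch {epoch}: lr = {learning_rate * lr_decay**epoch:.6f}')
# epoch 0: lr = 0.001000 ... epoch 9: lr = 0.000914
```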
@ -0,0 +1,193 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Writes metadata and creates model asset bundle for gesture recognizer."""
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Union
|
||||
|
||||
import tensorflow as tf
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import writer_utils
|
||||
|
||||
_HAND_DETECTOR_TFLITE_NAME = "hand_detector.tflite"
|
||||
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME = "hand_landmarks_detector.tflite"
|
||||
_HAND_LANDMARKER_BUNDLE_NAME = "hand_landmarker.task"
|
||||
_HAND_GESTURE_RECOGNIZER_BUNDLE_NAME = "hand_gesture_recognizer.task"
|
||||
_GESTURE_EMBEDDER_TFLITE_NAME = "gesture_embedder.tflite"
|
||||
_CANNED_GESTURE_CLASSIFIER_TFLITE_NAME = "canned_gesture_classifier.tflite"
|
||||
_CUSTOM_GESTURE_CLASSIFIER_TFLITE_NAME = "custom_gesture_classifier.tflite"
|
||||
|
||||
_MODEL_NAME = "HandGestureRecognition"
|
||||
_MODEL_DESCRIPTION = "Recognize the hand gesture in the image."
|
||||
|
||||
_INPUT_NAME = "embedding"
|
||||
_INPUT_DESCRIPTION = "Embedding feature vector from gesture embedder."
|
||||
_OUTPUT_NAME = "scores"
|
||||
_OUTPUT_DESCRIPTION = "Hand gesture category scores."
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class GestureClassifierOptions:
|
||||
"""Options to write metadata for gesture classifier.
|
||||
|
||||
Attributes:
|
||||
model_buffer: Gesture classifier TFLite model buffer.
|
||||
labels: Labels for the gesture classifier.
|
||||
score_thresholding: Parameters to performs thresholding on output tensor
|
||||
values [1].
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
|
||||
"""
|
||||
model_buffer: bytearray
|
||||
labels: metadata_writer.Labels
|
||||
score_thresholding: metadata_writer.ScoreThresholding
|
||||
|
||||
|
||||
def read_file(file_path: str, mode: str = "rb") -> Union[str, bytes]:
|
||||
with tf.io.gfile.GFile(file_path, mode) as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
class MetadataWriter:
|
||||
"""MetadataWriter to write the metadata and the model asset bundle."""
|
||||
|
||||
def __init__(
|
||||
self, hand_detector_model_buffer: bytearray,
|
||||
hand_landmarks_detector_model_buffer: bytearray,
|
||||
gesture_embedder_model_buffer: bytearray,
|
||||
canned_gesture_classifier_model_buffer: bytearray,
|
||||
custom_gesture_classifier_metadata_writer: metadata_writer.MetadataWriter
|
||||
) -> None:
|
||||
"""Initialize MetadataWriter to write the metadata and model asset bundle.
|
||||
|
||||
Args:
|
||||
hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from
|
||||
the TFLite hand detector model file.
|
||||
hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata
|
||||
loaded from the TFLite hand landmarks detector model file.
|
||||
gesture_embedder_model_buffer: A valid flatbuffer *with* metadata loaded
|
||||
from the TFLite gesture embedder model file.
|
||||
canned_gesture_classifier_model_buffer: A valid flatbuffer *with* metadata
|
||||
loaded from the TFLite canned gesture classifier model file.
|
||||
custom_gesture_classifier_metadata_writer: Metadata writer to write custom
|
||||
gesture classifier metadata into the TFLite file.
|
||||
"""
|
||||
self._hand_detector_model_buffer = hand_detector_model_buffer
|
||||
self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer
|
||||
self._gesture_embedder_model_buffer = gesture_embedder_model_buffer
|
||||
self._canned_gesture_classifier_model_buffer = canned_gesture_classifier_model_buffer
|
||||
self._custom_gesture_classifier_metadata_writer = custom_gesture_classifier_metadata_writer
|
||||
    self._temp_folder = tempfile.TemporaryDirectory()

  def __del__(self):
    if os.path.exists(self._temp_folder.name):
      self._temp_folder.cleanup()

  @classmethod
  def create(
      cls,
      hand_detector_model_buffer: bytearray,
      hand_landmarks_detector_model_buffer: bytearray,
      gesture_embedder_model_buffer: bytearray,
      canned_gesture_classifier_model_buffer: bytearray,
      custom_gesture_classifier_options: GestureClassifierOptions,
  ) -> "MetadataWriter":
    """Creates a MetadataWriter to write the metadata for the gesture recognizer.

    Args:
      hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from
        the TFLite hand detector model file.
      hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata
        loaded from the TFLite hand landmarks detector model file.
      gesture_embedder_model_buffer: A valid flatbuffer *with* metadata loaded
        from the TFLite gesture embedder model file.
      canned_gesture_classifier_model_buffer: A valid flatbuffer *with* metadata
        loaded from the TFLite canned gesture classifier model file.
      custom_gesture_classifier_options: Custom gesture classifier options to
        write custom gesture classifier metadata into the TFLite file.

    Returns:
      A MetadataWriter object.
    """
    writer = metadata_writer.MetadataWriter.create(
        custom_gesture_classifier_options.model_buffer)
    writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
    writer.add_feature_input(name=_INPUT_NAME, description=_INPUT_DESCRIPTION)
    writer.add_classification_output(
        labels=custom_gesture_classifier_options.labels,
        score_thresholding=custom_gesture_classifier_options.score_thresholding,
        name=_OUTPUT_NAME,
        description=_OUTPUT_DESCRIPTION)
    return cls(hand_detector_model_buffer, hand_landmarks_detector_model_buffer,
               gesture_embedder_model_buffer,
               canned_gesture_classifier_model_buffer, writer)

  def populate(self):
    """Populates the metadata and creates the model asset bundle.

    Note that only the output model asset bundle is used for deployment.
    The output JSON content is used to interpret the custom gesture classifier
    metadata content.

    Returns:
      A tuple of (model_asset_bundle_in_bytes, metadata_json_content).
    """
    # Creates the model asset bundle for the hand landmarker task.
    landmark_models = {
        _HAND_DETECTOR_TFLITE_NAME:
            self._hand_detector_model_buffer,
        _HAND_LANDMARKS_DETECTOR_TFLITE_NAME:
            self._hand_landmarks_detector_model_buffer
    }
    output_hand_landmarker_path = os.path.join(self._temp_folder.name,
                                               _HAND_LANDMARKER_BUNDLE_NAME)
    writer_utils.create_model_asset_bundle(landmark_models,
                                           output_hand_landmarker_path)

    # Writes metadata into the custom gesture classifier model.
    self._custom_gesture_classifier_model_buffer, custom_gesture_classifier_metadata_json = self._custom_gesture_classifier_metadata_writer.populate(
    )

    # Creates the model asset bundle for the hand gesture recognizer subgraph.
    hand_gesture_recognizer_models = {
        _GESTURE_EMBEDDER_TFLITE_NAME:
            self._gesture_embedder_model_buffer,
        _CANNED_GESTURE_CLASSIFIER_TFLITE_NAME:
            self._canned_gesture_classifier_model_buffer,
        _CUSTOM_GESTURE_CLASSIFIER_TFLITE_NAME:
            self._custom_gesture_classifier_model_buffer
    }
    output_hand_gesture_recognizer_path = os.path.join(
        self._temp_folder.name, _HAND_GESTURE_RECOGNIZER_BUNDLE_NAME)
    writer_utils.create_model_asset_bundle(hand_gesture_recognizer_models,
                                           output_hand_gesture_recognizer_path)

    # Creates the model asset bundle for the end-to-end hand gesture recognizer
    # graph.
    gesture_recognizer_models = {
        _HAND_LANDMARKER_BUNDLE_NAME:
            read_file(output_hand_landmarker_path),
        _HAND_GESTURE_RECOGNIZER_BUNDLE_NAME:
            read_file(output_hand_gesture_recognizer_path),
    }

    output_file_path = os.path.join(self._temp_folder.name,
                                    "gesture_recognizer.task")
    writer_utils.create_model_asset_bundle(gesture_recognizer_models,
                                           output_file_path)
    with open(output_file_path, "rb") as f:
      gesture_recognizer_model_buffer = f.read()
    return gesture_recognizer_model_buffer, custom_gesture_classifier_metadata_json
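
For context, a minimal usage sketch of the writer above, assembled from the unit test later in this change; the .tflite paths here are placeholders rather than files shipped in this diff:

from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer as base_metadata_writer

# Placeholder model paths; real buffers come from the gesture recognizer assets.
custom_options = metadata_writer.GestureClassifierOptions(
    model_buffer=metadata_writer.read_file("custom_gesture_classifier.tflite"),
    labels=base_metadata_writer.Labels().add(
        ["None", "Paper", "Rock", "Scissors"]),
    score_thresholding=base_metadata_writer.ScoreThresholding(
        global_score_threshold=0.5))

writer = metadata_writer.MetadataWriter.create(
    metadata_writer.read_file("hand_detector.tflite"),
    metadata_writer.read_file("hand_landmarks_detector.tflite"),
    metadata_writer.read_file("gesture_embedder.tflite"),
    metadata_writer.read_file("canned_gesture_classifier.tflite"),
    custom_options)

# populate() returns the .task bundle bytes plus the custom classifier JSON.
task_bundle, metadata_json = writer.populate()
with open("gesture_recognizer.task", "wb") as f:
  f.write(task_bundle)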
@ -0,0 +1,90 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for metadata_writer."""

import os
import zipfile

import tensorflow as tf

from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer as base_metadata_writer
from mediapipe.tasks.python.test import test_utils

_TEST_DATA_DIR = "mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata"

_EXPECTED_JSON = test_utils.get_test_data_path(
    os.path.join(_TEST_DATA_DIR, "custom_gesture_classifier_meta.json"))
_CUSTOM_GESTURE_CLASSIFIER_PATH = test_utils.get_test_data_path(
    os.path.join(_TEST_DATA_DIR, "custom_gesture_classifier.tflite"))


class MetadataWriterTest(tf.test.TestCase):

  def test_write_metadata_and_create_model_asset_bundle_successful(self):
    # Use dummy model buffers for the unit test only.
    hand_detector_model_buffer = b"\x11\x12"
    hand_landmarks_detector_model_buffer = b"\x22"
    gesture_embedder_model_buffer = b"\x33"
    canned_gesture_classifier_model_buffer = b"\x44"
    custom_gesture_classifier_metadata_writer = metadata_writer.GestureClassifierOptions(
        model_buffer=metadata_writer.read_file(_CUSTOM_GESTURE_CLASSIFIER_PATH),
        labels=base_metadata_writer.Labels().add(
            ["None", "Paper", "Rock", "Scissors"]),
        score_thresholding=base_metadata_writer.ScoreThresholding(
            global_score_threshold=0.5))
    writer = metadata_writer.MetadataWriter.create(
        hand_detector_model_buffer, hand_landmarks_detector_model_buffer,
        gesture_embedder_model_buffer, canned_gesture_classifier_model_buffer,
        custom_gesture_classifier_metadata_writer)
    model_bundle_content, metadata_json = writer.populate()
    with open(_EXPECTED_JSON, "r") as f:
      expected_json = f.read()
    self.assertEqual(metadata_json, expected_json)

    # Checks that the top-level model bundle can be extracted successfully.
    model_bundle_filepath = os.path.join(self.get_temp_dir(),
                                         "gesture_recognition.task")

    with open(model_bundle_filepath, "wb") as f:
      f.write(model_bundle_content)

    with zipfile.ZipFile(model_bundle_filepath) as zf:
      self.assertEqual(
          set(zf.namelist()),
          set(["hand_landmarker.task", "hand_gesture_recognizer.task"]))
      zf.extractall(self.get_temp_dir())

    # Checks that the model bundles for each sub-task can be extracted
    # successfully.
    hand_landmarker_bundle_filepath = os.path.join(self.get_temp_dir(),
                                                   "hand_landmarker.task")
    with zipfile.ZipFile(hand_landmarker_bundle_filepath) as zf:
      self.assertEqual(
          set(zf.namelist()),
          set(["hand_landmarks_detector.tflite", "hand_detector.tflite"]))

    hand_gesture_recognizer_bundle_filepath = os.path.join(
        self.get_temp_dir(), "hand_gesture_recognizer.task")
    with zipfile.ZipFile(hand_gesture_recognizer_bundle_filepath) as zf:
      self.assertEqual(
          set(zf.namelist()),
          set([
              "canned_gesture_classifier.tflite",
              "custom_gesture_classifier.tflite", "gesture_embedder.tflite"
          ]))


if __name__ == "__main__":
  tf.test.main()
@ -0,0 +1,27 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configurable model options for gesture recognizer models."""

import dataclasses


@dataclasses.dataclass
class GestureRecognizerModelOptions:
  """Configurable options for the gesture recognizer model.

  Attributes:
    dropout_rate: The fraction of the input units to drop, used in the dropout
      layer.
  """
  dropout_rate: float = 0.05
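
For illustration, a minimal sketch of overriding the default dropout rate; how the options object is consumed by the trainer is assumed here and not shown in this hunk:

# Hypothetical usage; assumes GestureRecognizerModelOptions is imported from
# this module (the module name is not shown in this hunk).
options = GestureRecognizerModelOptions(dropout_rate=0.1)
assert options.dropout_rate == 0.1  # the default would be 0.05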
@ -0,0 +1,56 @@
{
  "name": "HandGestureRecognition",
  "description": "Recognize the hand gesture in the image.",
  "subgraph_metadata": [
    {
      "input_tensor_metadata": [
        {
          "name": "embedding",
          "description": "Embedding feature vector from gesture embedder.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {
            }
          },
          "stats": {
          }
        }
      ],
      "output_tensor_metadata": [
        {
          "name": "scores",
          "description": "Hand gesture category scores.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {
            }
          },
          "process_units": [
            {
              "options_type": "ScoreThresholdingOptions",
              "options": {
                "global_score_threshold": 0.5
              }
            }
          ],
          "stats": {
            "max": [
              1.0
            ],
            "min": [
              0.0
            ]
          },
          "associated_files": [
            {
              "name": "labels.txt",
              "description": "Labels for categories that the model can recognize.",
              "type": "TENSOR_AXIS_LABELS"
            }
          ]
        }
      ]
    }
  ],
  "min_parser_version": "1.0.0"
}
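
For reference, a small sketch of reading the score threshold back out of this metadata JSON; it assumes the JSON above has been saved as custom_gesture_classifier_meta.json, the test-data name used elsewhere in this change:

import json

with open("custom_gesture_classifier_meta.json") as f:
  meta = json.load(f)

# Navigate to the output tensor metadata and its ScoreThresholdingOptions.
output_meta = meta["subgraph_metadata"][0]["output_tensor_metadata"][0]
threshold = output_meta["process_units"][0]["options"]["global_score_threshold"]
print(threshold)  # 0.5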
[40 binary image files added in this change, 11 to 43 KiB each; image diff previews ("After Width | Height | Size") omitted]
@ -120,6 +120,34 @@ py_library(
    ],
)

py_library(
    name = "solution_base",
    srcs = ["solution_base.py"],
    srcs_version = "PY3",
    visibility = [
        "//mediapipe/python:__subpackages__",
    ],
    deps = [
        ":_framework_bindings",
        ":packet_creator",
        ":packet_getter",
        "//mediapipe/calculators/core:constant_side_packet_calculator_py_pb2",
        "//mediapipe/calculators/image:image_transformation_calculator_py_pb2",
        "//mediapipe/calculators/tensor:tensors_to_detections_calculator_py_pb2",
        "//mediapipe/calculators/util:landmarks_smoothing_calculator_py_pb2",
        "//mediapipe/calculators/util:logic_calculator_py_pb2",
        "//mediapipe/calculators/util:thresholding_calculator_py_pb2",
        "//mediapipe/framework:calculator_py_pb2",
        "//mediapipe/framework/formats:classification_py_pb2",
        "//mediapipe/framework/formats:detection_py_pb2",
        "//mediapipe/framework/formats:landmark_py_pb2",
        "//mediapipe/framework/formats:rect_py_pb2",
        "//mediapipe/modules/objectron/calculators:annotation_py_pb2",
        "//mediapipe/modules/objectron/calculators:lift_2d_frame_annotation_to_3d_calculator_py_pb2",
        "@com_google_protobuf//:protobuf_python",
    ],
)

py_test(
    name = "calculator_graph_test",
    srcs = ["calculator_graph_test.py"],
@ -175,3 +203,24 @@ py_test(
        ":_framework_bindings",
    ],
)

py_test(
    name = "solution_base_test",
    srcs = ["solution_base_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":solution_base",
        "//file/google_src",
        "//file/localfile",
        "//mediapipe/calculators/core:gate_calculator",
        "//mediapipe/calculators/core:pass_through_calculator",
        "//mediapipe/calculators/core:side_packet_to_stream_calculator",
        "//mediapipe/calculators/image:image_transformation_calculator",
        "//mediapipe/calculators/util:detection_unique_id_calculator",
        "//mediapipe/calculators/util:to_image_calculator",
        "//mediapipe/framework:calculator_py_pb2",
        "//mediapipe/framework/formats:detection_py_pb2",
        "@com_google_protobuf//:protobuf_python",
    ],
)
@ -40,7 +40,6 @@ from mediapipe.calculators.util import landmarks_smoothing_calculator_pb2
from mediapipe.calculators.util import logic_calculator_pb2
from mediapipe.calculators.util import thresholding_calculator_pb2
from mediapipe.framework import calculator_pb2
from mediapipe.framework.formats import body_rig_pb2
from mediapipe.framework.formats import classification_pb2
from mediapipe.framework.formats import detection_pb2
from mediapipe.framework.formats import landmark_pb2
105
mediapipe/python/solutions/BUILD
Normal file
@ -0,0 +1,105 @@
# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

licenses(["notice"])

package(default_visibility = ["//visibility:public"])

py_library(
    name = "hands",
    srcs = [
        "hands.py",
        "hands_connections.py",
    ],
    data = [
        "//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
        "//mediapipe/modules/hand_landmark:hand_landmark_lite.tflite",
        "//mediapipe/modules/hand_landmark:hand_landmark_tracking_cpu_graph",
        "//mediapipe/modules/hand_landmark:handedness.txt",
        "//mediapipe/modules/palm_detection:palm_detection_full.tflite",
        "//mediapipe/modules/palm_detection:palm_detection_lite.tflite",
    ],
    srcs_version = "PY3",
    deps = [
        "//mediapipe/calculators/core:constant_side_packet_calculator_py_pb2",
        "//mediapipe/calculators/core:gate_calculator_py_pb2",
        "//mediapipe/calculators/core:split_vector_calculator_py_pb2",
        "//mediapipe/calculators/tensor:image_to_tensor_calculator_py_pb2",
        "//mediapipe/calculators/tensor:inference_calculator_py_pb2",
        "//mediapipe/calculators/tensor:tensors_to_classification_calculator_py_pb2",
        "//mediapipe/calculators/tensor:tensors_to_detections_calculator_py_pb2",
        "//mediapipe/calculators/tensor:tensors_to_landmarks_calculator_py_pb2",
        "//mediapipe/calculators/tflite:ssd_anchors_calculator_py_pb2",
        "//mediapipe/calculators/util:association_calculator_py_pb2",
        "//mediapipe/calculators/util:detections_to_rects_calculator_py_pb2",
        "//mediapipe/calculators/util:logic_calculator_py_pb2",
        "//mediapipe/calculators/util:non_max_suppression_calculator_py_pb2",
        "//mediapipe/calculators/util:rect_transformation_calculator_py_pb2",
        "//mediapipe/calculators/util:thresholding_calculator_py_pb2",
        "//mediapipe/python:solution_base",
    ],
)
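
For reference, a minimal sketch of consuming this target through the public Solutions API; the input image path and the OpenCV dependency are assumptions, not part of this change:

import cv2
import mediapipe as mp

mp_hands = mp.solutions.hands

# Run hand landmark detection on a single image (path is a placeholder).
image = cv2.imread("hands.jpg")
with mp_hands.Hands(static_image_mode=True, max_num_hands=2) as hands:
  results = hands.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
if results.multi_hand_landmarks:
  print(len(results.multi_hand_landmarks), "hand(s) detected")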

py_library(
    name = "drawing_styles",
    srcs = ["drawing_styles.py"],
    srcs_version = "PY3",
    deps = [
        "drawing_utils",
        "face_mesh",
        "hands",
        "pose",
    ],
)

py_library(
    name = "drawing_utils",
    srcs = ["drawing_utils.py"],
    srcs_version = "PY3",
    deps = [
        "//mediapipe/framework/formats:detection_py_pb2",
        "//mediapipe/framework/formats:landmark_py_pb2",
        "//mediapipe/framework/formats:location_data_py_pb2",
    ],
)

py_test(
    name = "drawing_utils_test",
    srcs = ["drawing_utils_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":drawing_utils",
        "//mediapipe/framework/formats:detection_py_pb2",
        "//mediapipe/framework/formats:landmark_py_pb2",
        "@com_google_protobuf//:protobuf_python",
    ],
)

py_test(
    name = "hands_test",
    srcs = ["hands_test.py"],
    data = [
        ":testdata/asl_hand.25fps.mp4",
        ":testdata/asl_hand.full.npz",
        ":testdata/hands.jpg",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":drawing_styles",
        ":drawing_utils",
        ":hands",
    ],
)
38
third_party/external_files.bzl
vendored
@ -88,6 +88,12 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/burger_rotated.jpg?generation=1665065843774448"],
    )

    http_file(
        name = "com_google_mediapipe_canned_gesture_classifier_tflite",
        sha256 = "2fc7e279966a7a9e15fc869223793e390791fc61fdc0062f9bc7d0eef6be98a2",
        urls = ["https://storage.googleapis.com/mediapipe-assets/canned_gesture_classifier.tflite?generation=1668124189331326"],
    )

    http_file(
        name = "com_google_mediapipe_cat_jpg",
        sha256 = "2533197401eebe9410ea4d063f86c43fbd2666f3e8165a38aca155c0d09c21be",
@ -286,6 +292,12 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/general_meta.json?generation=1665422822603848"],
    )

    http_file(
        name = "com_google_mediapipe_gesture_embedder_tflite",
        sha256 = "54abe78de1d1cd5e3cdaa0dab01db18e3ec7e09a76e7c3b5fa278572f7a60977",
        urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder.tflite?generation=1668124192126494"],
    )

    http_file(
        name = "com_google_mediapipe_gesture_recognizer_task",
        sha256 = "a966b1d4e774e0423c19c8aa71f070e5a72fe7a03c2663dd2f3cb0b0095ee3e1",
@ -721,7 +733,7 @@ def external_files():
    http_file(
        name = "com_google_mediapipe_README_md",
        sha256 = "a96d08c9c70cd9717207ed72c926e02e5eada751f00bdc5d3a7e82e3492b72cb",
        urls = ["https://storage.googleapis.com/mediapipe-assets/README.md?generation=1661875904887163"],
        urls = ["https://storage.googleapis.com/mediapipe-assets/README.md?generation=1668124179156767"],
    )

    http_file(
@ -970,6 +982,18 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/yamnet_audio_classifier_with_metadata.tflite?generation=1661875980774466"],
    )

    http_file(
        name = "com_google_mediapipe_gesture_embedder_keras_metadata_pb",
        sha256 = "24268b69429be4e307f9ab099ba20d1de7c40e4191a53f6a92dcbbd97a7047d3",
        urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/keras_metadata.pb?generation=1668124196996131"],
    )

    http_file(
        name = "com_google_mediapipe_gesture_embedder_saved_model_pb",
        sha256 = "f3a2870ba3ef537a4f6a5889ffc5b7061ad98f9fd96ec431a62116892f100659",
        urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/saved_model.pb?generation=1668124199460071"],
    )

    http_file(
        name = "com_google_mediapipe_object_detection_saved_model_model_ckpt_data-00000-of-00001",
        sha256 = "ad2f733f271dd5000a8c7f926bfea1083e6408b34d4f3b60679e5a6f96251c97",
@ -1005,3 +1029,15 @@ def external_files():
        sha256 = "f29606cf218397d5580c496e50fd28cddf66e2f59b819ab9c761b72270a5adf3",
        urls = ["https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/saved_model.pb?generation=1661875999264354"],
    )

    http_file(
        name = "com_google_mediapipe_gesture_embedder_variables_variables_data-00000-of-00001",
        sha256 = "9fdb750c4bac67afb9c0f61916510930b496cc47e7f89449aee2bec6b6ed0af8",
        urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.data-00000-of-00001?generation=1668124201918980"],
    )

    http_file(
        name = "com_google_mediapipe_gesture_embedder_variables_variables_index",
        sha256 = "3ccbcee9488fec4627d496abd9837997276b32b839a4d0ae434bd806fe380b86",
        urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668124204353848"],
    )
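
Each http_file entry above pins its asset with a sha256 digest; a quick sketch for checking a locally downloaded copy against the digest listed above (the local path is a placeholder):

import hashlib

# Placeholder path to a locally downloaded copy of the asset.
path = "variables.index"
expected = "3ccbcee9488fec4627d496abd9837997276b32b839a4d0ae434bd806fe380b86"

with open(path, "rb") as f:
  digest = hashlib.sha256(f.read()).hexdigest()
print("OK" if digest == expected else "checksum mismatch")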